本文整理汇总了C++中AnyType::setUserFuncContext方法的典型用法代码示例。如果您正苦于以下问题：C++ AnyType::setUserFuncContext方法的具体用法？C++ AnyType::setUserFuncContext怎么用？C++ AnyType::setUserFuncContext使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类AnyType的用法示例。
在下文中一共展示了AnyType::setUserFuncContext方法的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的C++代码示例。
示例1: run
/**
* @brief This function learns the topics of words in a document and is the
* main step of a Gibbs sampling iteration. The word topic counts and
* corpus topic counts are passed to this function in the first call and
* then transferred to the subsequent calls through args.mSysInfo->user_fctx for
* efficiency.
* @param args[0] The unique words in the documents
* @param args[1] The count of each unique word
* @param args[2] The topic counts and topic assignments in the document
* @param args[3] The model (word topic counts and corpus topic
* counts)
* @param args[4] The Dirichlet parameter for per-document topic
* multinomial, i.e. alpha
* @param args[5] The Dirichlet parameter for per-topic word
* multinomial, i.e. beta
* @param args[6] The size of vocabulary
* @param args[7] The number of topics
* @param args[8] The number of iterations (=1:training, >1:prediction)
* @return The updated topic counts and topic assignments for
* the document
**/
AnyType lda_gibbs_sample::run(AnyType & args)
{
ArrayHandle<int32_t> words = args[0].getAs<ArrayHandle<int32_t> >();
ArrayHandle<int32_t> counts = args[1].getAs<ArrayHandle<int32_t> >();
MutableArrayHandle<int32_t> doc_topic = args[2].getAs<MutableArrayHandle<int32_t> >();
double alpha = args[4].getAs<double>();
double beta = args[5].getAs<double>();
int32_t voc_size = args[6].getAs<int32_t>();
int32_t topic_num = args[7].getAs<int32_t>();
int32_t iter_num = args[8].getAs<int32_t>();
size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);
if(alpha <= 0)
throw std::invalid_argument("invalid argument - alpha");
if(beta <= 0)
throw std::invalid_argument("invalid argument - beta");
if(voc_size <= 0)
throw std::invalid_argument(
"invalid argument - voc_size");
if(topic_num <= 0)
throw std::invalid_argument(
"invalid argument - topic_num");
if(iter_num <= 0)
throw std::invalid_argument(
"invalid argument - iter_num");
if(words.size() != counts.size())
throw std::invalid_argument(
"dimensions mismatch: words.size() != counts.size()");
if(__min(words) < 0 || __max(words) >= voc_size)
throw std::invalid_argument(
"invalid values in words");
if(__min(counts) <= 0)
throw std::invalid_argument(
"invalid values in counts");
int32_t word_count = __sum(counts);
if(doc_topic.size() != (size_t)(word_count + topic_num))
throw std::invalid_argument(
"invalid dimension - doc_topic.size() != word_count + topic_num");
if(__min(doc_topic, 0, topic_num) < 0)
throw std::invalid_argument("invalid values in topic_count");
if(
__min(doc_topic, topic_num, word_count) < 0 ||
__max(doc_topic, topic_num, word_count) >= topic_num)
throw std::invalid_argument( "invalid values in topic_assignment");
if (!args.getUserFuncContext()) {
ArrayHandle<int64_t> model64 = args[3].getAs<ArrayHandle<int64_t> >();
if (model64.size() != model64_size) {
std::stringstream ss;
ss << "invalid dimension: model64.size() = " << model64.size();
throw std::invalid_argument(ss.str());
}
if (__min(model64) < 0) {
throw std::invalid_argument("invalid topic counts in model");
}
int32_t *context =
static_cast<int32_t *>(
MemoryContextAllocZero(
args.getCacheMemoryContext(),
model64.size() * sizeof(int64_t)
+ topic_num * sizeof(int64_t)));
memcpy(context, model64.ptr(), model64.size() * sizeof(int64_t));
int32_t *model = context;
int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
context + model64_size * sizeof(int64_t) / sizeof(int32_t));
for (int i = 0; i < voc_size; i ++) {
for (int j = 0; j < topic_num; j ++) {
running_topic_counts[j] += model[i * (topic_num + 1) + j];
}
}
args.setUserFuncContext(context);
}
int32_t *context = static_cast<int32_t *>(args.getUserFuncContext());
//.........这里部分代码省略.........
示例2: run
/**
* @brief This function learns the topics of words in a document and is the
* main step of a Gibbs sampling iteration. The word topic counts and
* corpus topic counts are passed to this function in the first call and
* then transferred to the subsequent calls through args.mSysInfo->user_fctx for
* efficiency.
* @param args[0] The unique words in the documents
* @param args[1] The count of each unique word
* @param args[2] The topic counts and topic assignments in the document
* @param args[3] The model (word topic counts and corpus topic
* counts)
* @param args[4] The Dirichlet parameter for per-document topic
* multinomial, i.e. alpha
* @param args[5] The Dirichlet parameter for per-topic word
* multinomial, i.e. beta
* @param args[6] The size of vocabulary
* @param args[7] The number of topics
* @param args[8] The number of iterations (=1:training, >1:prediction)
* @return The updated topic counts and topic assignments for
* the document
**/
AnyType lda_gibbs_sample::run(AnyType & args)
{
ArrayHandle<int32_t> words = args[0].getAs<ArrayHandle<int32_t> >();
ArrayHandle<int32_t> counts = args[1].getAs<ArrayHandle<int32_t> >();
MutableArrayHandle<int32_t> doc_topic = args[2].getAs<MutableArrayHandle<int32_t> >();
double alpha = args[4].getAs<double>();
double beta = args[5].getAs<double>();
int32_t voc_size = args[6].getAs<int32_t>();
int32_t topic_num = args[7].getAs<int32_t>();
int32_t iter_num = args[8].getAs<int32_t>();
if(alpha <= 0)
throw std::invalid_argument("invalid argument - alpha");
if(beta <= 0)
throw std::invalid_argument("invalid argument - beta");
if(voc_size <= 0)
throw std::invalid_argument(
"invalid argument - voc_size");
if(topic_num <= 0)
throw std::invalid_argument(
"invalid argument - topic_num");
if(iter_num <= 0)
throw std::invalid_argument(
"invalid argument - iter_num");
if(words.size() != counts.size())
throw std::invalid_argument(
"dimensions mismatch: words.size() != counts.size()");
if(__min(words) < 0 || __max(words) >= voc_size)
throw std::invalid_argument(
"invalid values in words");
if(__min(counts) <= 0)
throw std::invalid_argument(
"invalid values in counts");
int32_t word_count = __sum(counts);
if(doc_topic.size() != (size_t)(word_count + topic_num))
throw std::invalid_argument(
"invalid dimension - doc_topic.size() != word_count + topic_num");
if(__min(doc_topic, 0, topic_num) < 0)
throw std::invalid_argument("invalid values in topic_count");
if(
__min(doc_topic, topic_num, word_count) < 0 ||
__max(doc_topic, topic_num, word_count) >= topic_num)
throw std::invalid_argument( "invalid values in topic_assignment");
if (!args.getUserFuncContext())
{
if(args[3].isNull())
throw std::invalid_argument("invalid argument - the model \
parameter should not be null for the first call");
ArrayHandle<int64_t> model = args[3].getAs<ArrayHandle<int64_t> >();
if(model.size() != (size_t)((voc_size + 1) * topic_num))
throw std::invalid_argument(
"invalid dimension - model.size() != (voc_size + 1) * topic_num");
if(__min(model) < 0)
throw std::invalid_argument("invalid topic counts in model");
int64_t * state =
static_cast<int64_t *>(
MemoryContextAllocZero(
args.getCacheMemoryContext(),
model.size() * sizeof(int64_t)));
memcpy(state, model.ptr(), model.size() * sizeof(int64_t));
args.setUserFuncContext(state);
}
int64_t * state = static_cast<int64_t *>(args.getUserFuncContext());
if(NULL == state){
throw std::runtime_error("args.mSysInfo->user_fctx is null");
}
int32_t unique_word_count = static_cast<int32_t>(words.size());
for(int32_t it = 0; it < iter_num; it++){
int32_t word_index = topic_num;
for(int32_t i = 0; i < unique_word_count; i++) {
int32_t wordid = words[i];
for(int32_t j = 0; j < counts[i]; j++){
int32_t topic = doc_topic[word_index];
//.........这里部分代码省略.........