

Python data_iterator.TextIterator Code Examples

This article collects typical usage examples of Python's data_iterator.TextIterator. If you are unsure what data_iterator.TextIterator does, how to call it, or what working code looks like, the curated examples below may help. You can also explore further usage examples from the data_iterator module.


The sections below present 4 code examples of data_iterator.TextIterator, ordered by popularity by default. Note that the examples come from two different projects (NMT_GAN and nematus), whose TextIterator constructors take different arguments.

Example 1: gen_force_train_iter

# Required module import: import data_iterator [as alias]
# Or: from data_iterator import TextIterator [as alias]
# This example additionally uses the standard-library os and time modules.
def gen_force_train_iter(source_data, target_data, reshuffle, source_dict, target_dict,
                         batch_size, maxlen, n_words_src, n_words_trg):
    epoch = 0  # renamed from `iter` to avoid shadowing the built-in
    while True:
        if reshuffle:
            # Shuffle the parallel corpus in place before each pass.
            os.popen('python shuffle.py ' + source_data + ' ' + target_data)
            os.popen('mv ' + source_data + '.shuf ' + source_data)
            os.popen('mv ' + target_data + '.shuf ' + target_data)
        gen_force_train = TextIterator(source_data, target_data, source_dict, target_dict,
                                       batch_size, maxlen, n_words_src, n_words_trg)
        ExampleNum = 0
        EpochStart = time.time()
        for x, y in gen_force_train:
            # Skip incomplete final batches.
            if len(x) < batch_size and len(y) < batch_size:
                continue
            ExampleNum += len(x)
            yield x, y, epoch
        TimeCost = time.time() - EpochStart
        epoch += 1
        print('Seen', ExampleNum, 'generator samples. Time cost is', TimeCost)
Author: ZhenYangIACAS · Project: NMT_GAN · Source file: share_function.py
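
The generator above loops over the corpus indefinitely, so a training loop pulls from it with next() rather than iterating to exhaustion. A rough usage sketch (the file paths, dictionary files, and sizes below are hypothetical placeholders, not taken from the project):

# Hypothetical arguments -- adjust to your own corpus and vocabularies.
train_gen = gen_force_train_iter(
    source_data='train.src', target_data='train.trg', reshuffle=True,
    source_dict='vocab.src.json', target_dict='vocab.trg.json',
    batch_size=64, maxlen=50, n_words_src=30000, n_words_trg=30000)

for step in range(10000):
    x, y, epoch = next(train_gen)  # epoch counts completed passes over the data
    # ... feed x and y to the training update here ...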

Example 2: train_iter

# Required module import: import data_iterator [as alias]
# Or: from data_iterator import TextIterator [as alias]
# This example additionally uses the standard-library os and time modules.
def train_iter(self):
    Epoch = 0
    while True:
        if self.reshuffle:
            # Shuffle the parallel corpus in place before each epoch.
            os.popen('python shuffle.py ' + self.train_data_source + ' ' + self.train_data_target)
            os.popen('mv ' + self.train_data_source + '.shuf ' + self.train_data_source)
            os.popen('mv ' + self.train_data_target + '.shuf ' + self.train_data_target)
        train = TextIterator(self.train_data_source, self.train_data_target,
                             self.dictionaries[0], self.dictionaries[1],
                             n_words_source=self.n_words_src, n_words_target=self.n_words_trg,
                             batch_size=self.batch_size * self.gpu_num,
                             maxlen=self.max_len)
        ExamplesNum = 0
        print('Epoch :', Epoch)
        EpochStart = time.time()
        for x, y in train:
            # Skip batches too small to split across all GPUs.
            if len(x) < self.gpu_num * self.batch_size:
                continue
            ExamplesNum += len(x)
            yield x, y, Epoch
        TimeCost = time.time() - EpochStart
        Epoch += 1
        print('Seen', ExamplesNum, 'examples. Time cost:', TimeCost)
Author: ZhenYangIACAS · Project: NMT_GAN · Source file: nmt_generator.py
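
Because each yielded batch holds batch_size * gpu_num examples, the consumer is expected to slice it into one sub-batch per GPU. A minimal sketch of that split (hypothetical consumer code, assuming a generator instance named model):

x, y, epoch = next(model.train_iter())
shard = len(x) // model.gpu_num  # batch was sized batch_size * gpu_num
x_shards = [x[i * shard:(i + 1) * shard] for i in range(model.gpu_num)]
y_shards = [y[i * shard:(i + 1) * shard] for i in range(model.gpu_num)]
# ... run one training step per device on (x_shards[i], y_shards[i]) ...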

Example 3: load_data

# Required module import: import data_iterator [as alias]
# Or: from data_iterator import TextIterator [as alias]
# This example additionally uses the standard-library logging module. Note that
# this is the Nematus version of TextIterator, whose constructor takes keyword arguments.
def load_data(config):
    logging.info('Reading data...')
    text_iterator = TextIterator(
                        source=config.source_dataset,
                        target=config.target_dataset,
                        source_dicts=config.source_dicts,
                        target_dict=config.target_dict,
                        model_type=config.model_type,
                        batch_size=config.batch_size,
                        maxlen=config.maxlen,
                        source_vocab_sizes=config.source_vocab_sizes,
                        target_vocab_size=config.target_vocab_size,
                        skip_empty=True,
                        shuffle_each_epoch=config.shuffle_each_epoch,
                        sort_by_length=config.sort_by_length,
                        use_factor=(config.factors > 1),
                        maxibatch_size=config.maxibatch_size,
                        token_batch_size=config.token_batch_size,
                        keep_data_in_memory=config.keep_train_set_in_memory,
                        preprocess_script=config.preprocess_script)

    if config.valid_freq and config.valid_source_dataset and config.valid_target_dataset:
        valid_text_iterator = TextIterator(
                            source=config.valid_source_dataset,
                            target=config.valid_target_dataset,
                            source_dicts=config.source_dicts,
                            target_dict=config.target_dict,
                            model_type=config.model_type,
                            batch_size=config.valid_batch_size,
                            maxlen=config.maxlen,
                            source_vocab_sizes=config.source_vocab_sizes,
                            target_vocab_size=config.target_vocab_size,
                            shuffle_each_epoch=False,
                            sort_by_length=True,
                            use_factor=(config.factors > 1),
                            maxibatch_size=config.maxibatch_size,
                            token_batch_size=config.valid_token_batch_size)
    else:
        logging.info('no validation set loaded')
        valid_text_iterator = None
    logging.info('Done')
    return text_iterator, valid_text_iterator 
Author: EdinburghNLP · Project: nematus · Source file: nmt.py
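
load_data only reads attributes from the config object, so it can be exercised with a plain namespace. The field names below are exactly those accessed above; the values are hypothetical placeholders:

from types import SimpleNamespace

# Hypothetical config -- field names match the attributes read by load_data.
config = SimpleNamespace(
    source_dataset='train.src', target_dataset='train.trg',
    source_dicts=['vocab.src.json'], target_dict='vocab.trg.json',
    model_type='rnn', batch_size=80, maxlen=100,
    source_vocab_sizes=[30000], target_vocab_size=30000,
    shuffle_each_epoch=True, sort_by_length=True, factors=1,
    maxibatch_size=20, token_batch_size=0,
    keep_train_set_in_memory=False, preprocess_script=None,
    valid_freq=0, valid_source_dataset=None, valid_target_dataset=None)

train_iter, valid_iter = load_data(config)  # valid_iter is None here: valid_freq is 0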

Example 4: calc_cross_entropy_per_sentence

# Required module import: import data_iterator [as alias]
# Or: from data_iterator import TextIterator [as alias]
# This example additionally uses logging, sys, numpy, and the project's util module.
def calc_cross_entropy_per_sentence(session, model, config, text_iterator,
                                    normalization_alpha=0.0):
    """Calculates cross entropy values for a parallel corpus.

    By default (when normalization_alpha is 0.0), the sentence-level cross
    entropy is calculated. If normalization_alpha is 1.0 then the per-token
    cross entropy is calculated. Other values of normalization_alpha may be
    useful if the cross entropy value will be used as a score for selecting
    between translation candidates (e.g. in reranking an n-best list). Using
    a different (empirically determined) alpha value can help correct a model
    bias toward too-short / too-long sentences.

    TODO Support for multiple GPUs

    Args:
        session: TensorFlow session.
        model: an RNNModel object.
        config: model config.
        text_iterator: TextIterator.
        normalization_alpha: length normalization hyperparameter.

    Returns:
        A pair of lists. The first contains the (possibly normalized) cross
        entropy value for each sentence pair. The second contains the
        target-side token count for each pair (including the terminating
        <EOS> symbol).
    """
    ce_vals, token_counts = [], []
    for xx, yy in text_iterator:
        if len(xx[0][0]) != config.factors:
            logging.error('Mismatch between number of factors in settings ' \
                          '({0}) and number present in data ({1})'.format(
                          config.factors, len(xx[0][0])))
            sys.exit(1)
        x, x_mask, y, y_mask = util.prepare_data(xx, yy, config.factors,
                                                 maxlen=None)

        # Run the minibatch through the model to get the sentence-level cross
        # entropy values.
        feeds = {model.inputs.x: x,
                 model.inputs.x_mask: x_mask,
                 model.inputs.y: y,
                 model.inputs.y_mask: y_mask,
                 model.inputs.training: False}
        batch_ce_vals = session.run(model.loss_per_sentence, feed_dict=feeds)

        # Optionally, do length normalization.
        batch_token_counts = [numpy.count_nonzero(s) for s in y_mask.T]
        if normalization_alpha:
            adjusted_lens = [n**normalization_alpha for n in batch_token_counts]
            batch_ce_vals /= numpy.array(adjusted_lens)

        ce_vals += list(batch_ce_vals)
        token_counts += batch_token_counts
        logging.info("Seen {}".format(len(ce_vals)))

    assert len(ce_vals) == len(token_counts)
    return ce_vals, token_counts 
Author: EdinburghNLP · Project: nematus · Source file: nmt.py
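
Dividing the unnormalized (alpha = 0.0) values by the returned token counts reproduces the per-token cross entropy that alpha = 1.0 would give directly; intermediate alphas rescale the scores for reranking. A small sketch using the two returned lists (session, model, config, and text_iterator are assumed to exist):

ce_vals, token_counts = calc_cross_entropy_per_sentence(
    session, model, config, text_iterator, normalization_alpha=0.0)

# Per-token cross entropy, equivalent to calling with normalization_alpha=1.0.
per_token_ce = [ce / n for ce, n in zip(ce_vals, token_counts)]

# For reranking, an intermediate, empirically tuned alpha softens the length bias.
alpha = 0.6  # hypothetical value
scores = [ce / (n ** alpha) for ce, n in zip(ce_vals, token_counts)]
best = min(range(len(scores)), key=scores.__getitem__)  # lowest score wins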


Note: The data_iterator.TextIterator examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers; copyright remains with the original authors, and distribution and use are subject to each project's License. Do not reproduce without permission.