当前位置: 首页>>代码示例>>Python>>正文


Python wmt_utils.EOS_ID属性代码示例

本文整理汇总了Python中wmt_utils.EOS_ID属性的典型用法代码示例。如果您正苦于以下问题:Python wmt_utils.EOS_ID属性的具体用法?Python wmt_utils.EOS_ID怎么用?Python wmt_utils.EOS_ID使用的例子?那么恭喜您,这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在wmt_utils的用法示例。


在下文中一共展示了wmt_utils.EOS_ID属性的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: read_data

# 需要导入模块: import wmt_utils [as 别名]
# 或者: from wmt_utils import EOS_ID [as 别名]
def read_data(source_path, target_path, buckets, max_size=None, print_out=True):
  """Read data from source and target files and put into buckets.

  Args:
    source_path: path to the files with token-ids for the source language.
    target_path: path to the file with token-ids for the target language;
      it must be aligned with the source file: n-th line contains the desired
      output for n-th line from the source_path.
    buckets: the buckets to use; each entry is a single size that both the
      source length and the target length are compared against.
    max_size: maximum number of lines to read, all other will be ignored;
      if 0 or None, data files will be read completely (no limit).
      If set to 1, no data will be returned (empty lists of the right form).
    print_out: whether to print out status or not.

  Returns:
    data_set: a list of length len(buckets); data_set[n] contains a list of
      [source, target] pairs read from the provided data files whose lengths
      (as reported by zero_split) are both <= buckets[n] and that did not fit
      any earlier bucket. Pairs that fit no bucket are dropped.
  """
  data_set = [[] for _ in buckets]
  counter = 0
  # max_size == 1 is a special value: the files are not even opened.
  if max_size != 1:
    with tf.gfile.GFile(source_path, mode="r") as source_file:
      with tf.gfile.GFile(target_path, mode="r") as target_file:
        source, target = source_file.readline(), target_file.readline()
        # Stop at the end of either file or once max_size lines were consumed.
        while source and target and (not max_size or counter < max_size):
          counter += 1
          if counter % 100000 == 0 and print_out:
            # Function-call form works on both Python 2 and Python 3.
            print("  reading data line %d" % counter)
            sys.stdout.flush()
          source_ids = [int(x) for x in source.split()]
          target_ids = [int(x) for x in target.split()]
          source_ids, source_len = zero_split(source_ids)
          target_ids, target_len = zero_split(target_ids, append=wmt.EOS_ID)
          # Store the pair in the first bucket large enough for both sides.
          for bucket_id, size in enumerate(buckets):
            if source_len <= size and target_len <= size:
              data_set[bucket_id].append([source_ids, target_ids])
              break
          source, target = source_file.readline(), target_file.readline()
  return data_set
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:43,代码来源:neural_gpu_trainer.py

示例2: score_beams

# 需要导入模块: import wmt_utils [as 别名]
# 或者: from wmt_utils import EOS_ID [as 别名]
def score_beams(beams, target, inp, history, p,
                print_out=False, test_mode=False):
  """Score beams and return the best one together with its score.

  Args:
    beams: list of candidate output sequences.
    target: the reference output sequence.
    inp: the input; only forwarded to score_beams_prog.
    history: previously seen sequences; a beam whose positive entries equal
      one of them (compared through str()) is penalized by 20.0 unless it
      already matches the target perfectly.
    p: problem name. "progsynth" delegates to score_beams_prog; "wmt" cuts
      the target and every beam at the first wmt.EOS_ID before comparing.
    print_out: only forwarded to score_beams_prog.
    test_mode: if True (and p is not "progsynth"), only the first beam is
      checked: it scores 10.0 when its prefix of len(target) equals target,
      else 0.0.

  Returns:
    A (best_beam, best_score) pair. Outside test mode the score of a beam is
    the fraction of target positions it reproduces, minus the history
    penalty; if beams is empty the result is (None, -1000.0).
  """
  if p == "progsynth":
    return score_beams_prog(beams, target, inp, history, print_out, test_mode)
  if test_mode:
    matched = str(beams[0][:len(target)]) == str(target)
    return beams[0], (10.0 if matched else 0.0)

  # A set gives O(1) membership tests for the history check below.
  history_s = set(str(h) for h in history)
  eos_id = wmt.EOS_ID if p == "wmt" else None
  tgt = target
  if eos_id and eos_id in target:
    tgt = target[:target.index(eos_id)]
  tgt_len = len(tgt)

  best, best_score = None, -1000.0
  for beam in beams:
    if eos_id and eos_id in beam:
      beam = beam[:beam.index(eos_id)]
    if tgt_len:
      # zip stops at the shorter sequence, so only the common prefix counts.
      matches = sum(1 for t, b in zip(tgt, beam) if t == b)
      score = matches / float(tgt_len)
    else:
      # An empty target has no position to miss; treat it as fully matched
      # (as the test_mode branch does) instead of dividing by zero.
      score = 1.0
    # Imperfect beams that merely repeat an earlier output are pushed down.
    if score < 1.0 and str([b for b in beam if b > 0]) in history_s:
      score -= 20.0
    if score > best_score:
      best = beam
      best_score = score
  return best, best_score
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:28,代码来源:neural_gpu_trainer.py

示例3: read_data

# 需要导入模块: import wmt_utils [as 别名]
# 或者: from wmt_utils import EOS_ID [as 别名]
def read_data(source_path, target_path, buckets, max_size=None, print_out=True):
  """Load aligned source/target token-id files and sort the pairs into buckets.

  Args:
    source_path: path to the file with token-ids for the source language.
    target_path: path to the file with token-ids for the target language;
      line n of this file is the desired output for line n of source_path.
    buckets: the bucket sizes to use.
    max_size: maximum number of lines to read; 0 or None means no limit.
      The value 1 is special: nothing is read and empty buckets are returned.
    print_out: whether to print a progress line every 100000 lines.

  Returns:
    A list with one entry per bucket; entry n holds the [source, target]
    pairs whose lengths (as reported by zero_split) are both <= buckets[n]
    and that did not fit an earlier bucket. Pairs fitting no bucket are
    skipped.
  """
  data_set = [[] for _ in buckets]
  if max_size == 1:
    return data_set

  lines_seen = 0
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      while True:
        source = source_file.readline()
        target = target_file.readline()
        # Either file is exhausted.
        if not source or not target:
          break
        # The requested number of lines has been consumed.
        if max_size and lines_seen >= max_size:
          break
        lines_seen += 1
        if print_out and lines_seen % 100000 == 0:
          print("  reading data line %d" % lines_seen)
          sys.stdout.flush()
        source_ids, source_len = zero_split(
            [int(tok) for tok in source.split()])
        target_ids, target_len = zero_split(
            [int(tok) for tok in target.split()], append=wmt.EOS_ID)
        # Index of the first bucket that can hold both sides, if any.
        fit = next(
            (idx for idx, size in enumerate(buckets)
             if source_len <= size and target_len <= size),
            None)
        if fit is not None:
          data_set[fit].append([source_ids, target_ids])
  return data_set
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:43,代码来源:neural_gpu_trainer.py

示例4: linearize

# 需要导入模块: import wmt_utils [as 别名]
# 或者: from wmt_utils import EOS_ID [as 别名]
def linearize(output, rev_fr_vocab, simple_tokenizer=None, eos_id=wmt.EOS_ID):
  """Turn a sequence of output ids into a single string.

  Args:
    output: sequence of ids to convert; it is cut at the first eos_id, if
      present.
    rev_fr_vocab: indexable mapping from id to token string.
    simple_tokenizer: if truthy (or if FLAGS.simple_tokenizer is truthy),
      tokens are joined with single spaces and ids >= len(rev_fr_vocab)
      become "UNK"; otherwise wmt.basic_detokenizer builds the string.
    eos_id: id that marks the end of the sequence.

  Returns:
    The string built from the (possibly truncated) output.
  """
  # If there is an EOS symbol in outputs, cut them at that point (WMT).
  if eos_id in output:
    eos_pos = output.index(eos_id)
    output = output[:eos_pos]
  if not (simple_tokenizer or FLAGS.simple_tokenizer):
    tokens = [rev_fr_vocab[o] for o in output]
    return wmt.basic_detokenizer(tokens)
  # Simple tokenizer: ids past the end of the vocabulary map to "UNK".
  vlen = len(rev_fr_vocab)
  words = [rev_fr_vocab[o] if o < vlen else "UNK" for o in output]
  return " ".join(words)
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:16,代码来源:neural_gpu_trainer.py


注:本文中的wmt_utils.EOS_ID属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。