

Python data_utils.EOS_ID Attribute Code Examples

This article collects typical usage examples of the data_utils.EOS_ID attribute in Python. If you are unsure what data_utils.EOS_ID does, how to use it, or what it looks like in real code, the curated examples below should help. You can also explore further usage examples of the data_utils module itself.


Below are 14 code examples of the data_utils.EOS_ID attribute, sorted by popularity by default.
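For context, most of the snippets below appear to derive from the TensorFlow seq2seq tutorial's data_utils.py, which defines a handful of special vocabulary symbols at the top of the module; EOS_ID is the integer id of the end-of-sequence marker. The following is a minimal sketch of that convention; the symbols and ids shown are the tutorial's defaults, and a specific project's data_utils may use different values.

# Sketch of the special-symbol convention assumed by these examples
# (tutorial defaults; a given fork of data_utils may differ).
_PAD = b"_PAD"   # padding symbol
_GO = b"_GO"     # decoder start symbol
_EOS = b"_EOS"   # end-of-sequence symbol
_UNK = b"_UNK"   # out-of-vocabulary symbol
_START_VOCAB = [_PAD, _GO, _EOS, _UNK]

PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3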

Example 1: read_data

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def read_data(path, max_size=None):
    data_set = [[] for _ in _buckets]
    data = json.load(open(path,'r'))
    counter = 0
    size_max = 0
    for pair in data:
        post = pair[0]
        responses = pair[1]
        source_ids = [int(x) for x in post[0]]
        for response in responses:
            if not max_size or counter < max_size:
                counter += 1
                if counter % 100000 == 0:
                    print("    reading data pair %d" % counter)
                    sys.stdout.flush()
                target_ids = [int(x) for x in response[0]]
                target_ids.append(data_utils.EOS_ID)
                #size_max = len(source_ids) if len(source_ids) > size_max else size_max
                #size_max = len(target_ids) if len(target_ids) > size_max else size_max
                for bucket_id, (source_size, target_size) in enumerate(_buckets):
                    if len(source_ids) < source_size and len(target_ids) < target_size:
                        data_set[bucket_id].append([source_ids, target_ids, int(post[1]), int(response[1])])
                        break
    return data_set 
Developer: thu-coai, Project: ecm, Lines: 26, Source: baseline.py

Example 2: run

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def run(self, sentence):
    # Get token-ids for the input sentence.
    token_ids = data_utils.sentence_to_token_ids(sentence, self.en_vocab)
    # Which bucket does it belong to?
    bucket_id = min([b for b in xrange(len(_buckets))
                     if _buckets[b][0] > len(token_ids)])
    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights = self.model.get_batch(
        {bucket_id: [(token_ids, [])]}, bucket_id)
    # Get output logits for the sentence.
    _, _, output_logits = self.model.step(self.sess, encoder_inputs, decoder_inputs,
                                     target_weights, bucket_id, True)
    # This is a greedy decoder - outputs are just argmaxes of output_logits.
    outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
    # If there is an EOS symbol in outputs, cut them at that point.
    if data_utils.EOS_ID in outputs:
      outputs = outputs[:outputs.index(data_utils.EOS_ID)]
    # Print out French sentence corresponding to outputs.
    return "".join([self.rev_fr_vocab[output] for output in outputs]) 
Developer: muik, Project: transliteration, Lines: 21, Source: translate.py
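The last few lines of run show the post-processing pattern that almost every example on this page repeats: take the argmax of each per-step logit vector as the greedy prediction, then cut the sequence at the first EOS_ID. Below is a self-contained toy illustration of just that step; the logits and the EOS id are made up for demonstration and do not come from any of the projects above.

import numpy as np

EOS_ID = 2  # assumed end-of-sequence id, as in a tutorial-style data_utils

# Fake per-step output logits: four decoding steps, batch size 1, vocab size 5.
output_logits = [np.array([[0.1, 0.2, 0.0, 3.0, 0.5]]),  # argmax -> 3
                 np.array([[0.0, 0.1, 0.2, 0.3, 2.5]]),  # argmax -> 4
                 np.array([[0.2, 0.1, 4.0, 0.3, 0.0]]),  # argmax -> 2 (EOS)
                 np.array([[1.5, 0.1, 0.0, 0.3, 0.0]])]  # argmax -> 0 (dropped)

# Greedy decoding: one token id per step. With batch size 1 this is
# equivalent to the int(np.argmax(logit, axis=1)) used in the examples.
outputs = [int(np.argmax(logit, axis=1)[0]) for logit in output_logits]

# Truncate at the first EOS symbol, exactly as the examples do.
if EOS_ID in outputs:
    outputs = outputs[:outputs.index(EOS_ID)]

print(outputs)  # [3, 4]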

Example 3: read_data

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def read_data(source_path, target_path, max_size=None):
  """Read data from source and target files and put into buckets.

  Args:
    source_path: path to the files with token-ids for the source language.
    target_path: path to the file with token-ids for the target language;
      it must be aligned with the source file: n-th line contains the desired
      output for n-th line from the source_path.
    max_size: maximum number of lines to read, all other will be ignored;
      if 0 or None, data files will be read completely (no limit).

  Returns:
    data_set: a list of length len(_buckets); data_set[n] contains a list of
      (source, target) pairs read from the provided data files that fit
      into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
      len(target) < _buckets[n][1]; source and target are lists of token-ids.
  """
  data_set = [[] for _ in _buckets]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      counter = 0
      while source and target and (not max_size or counter < max_size):
        counter += 1
        if counter % 100000 == 0:
          print("  reading data line %d" % counter)
          sys.stdout.flush()
        source_ids = [int(x) for x in source.split()]
        target_ids = [int(x) for x in target.split()]
        target_ids.append(data_utils.EOS_ID)
        for bucket_id, (source_size, target_size) in enumerate(_buckets):
          if len(source_ids) < source_size and len(target_ids) < target_size:
            data_set[bucket_id].append([source_ids, target_ids])
            break
        source, target = source_file.readline(), target_file.readline()
  return data_set 
Developer: ringringyi, Project: DOTA_models, Lines: 38, Source: translate.py
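The bucket-assignment rule inside read_data is identical across the variants on this page: after EOS_ID has been appended to the target, the pair goes into the first bucket whose (source_size, target_size) strictly exceeds both lengths, and pairs longer than the largest bucket are silently skipped. In the tutorial-style model the strict inequality on the target side also leaves room for the GO symbol that get_batch later prepends to the decoder input. Here is a tiny standalone illustration of the rule; the bucket sizes are the tutorial defaults and purely illustrative.

# Toy illustration of the shared bucket-selection rule (illustrative values).
_buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
EOS_ID = 2  # assumed end-of-sequence id

source_ids = [7, 12, 5, 9, 31, 4]        # 6 source token ids
target_ids = [8, 3, 19, 6] + [EOS_ID]    # 4 target token ids + EOS

for bucket_id, (source_size, target_size) in enumerate(_buckets):
    if len(source_ids) < source_size and len(target_ids) < target_size:
        print("pair fits bucket %d %s" % (bucket_id, (source_size, target_size)))
        break
else:
    print("pair is longer than the largest bucket and would be skipped")
# -> pair fits bucket 1 (10, 15)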

Example 4: read_data

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def read_data(source_path, target_path, max_size=None):
  """Read data from source and target files and put into buckets.

  Args:
    source_path: path to the files with token-ids for the source language.
    target_path: path to the file with token-ids for the target language;
      it must be aligned with the source file: n-th line contains the desired
      output for n-th line from the source_path.
    max_size: maximum number of lines to read, all other will be ignored;
      if 0 or None, data files will be read completely (no limit).

  Returns:
    data_set: a list of length len(_buckets); data_set[n] contains a list of
      (source, target) pairs read from the provided data files that fit
      into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
      len(target) < _buckets[n][1]; source and target are lists of token-ids.
  """
  data_set = [[] for _ in _buckets]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      counter = 0
      while source and target and (not max_size or counter < max_size):
        counter += 1
        if counter % 1000 == 0:
          print("  reading data line %d" % counter)
          sys.stdout.flush()
        source_ids = [int(x) for x in source.split()]
        target_ids = [int(x) for x in target.split()]
        target_ids.append(data_utils.EOS_ID)
        for bucket_id, (source_size, target_size) in enumerate(_buckets):
          if len(source_ids) < source_size and len(target_ids) < target_size:
            data_set[bucket_id].append([source_ids, target_ids])
            break
        source, target = source_file.readline(), target_file.readline()
  return data_set 
Developer: hengluchang, Project: deep-news-summarization, Lines: 38, Source: execute.py

Example 5: read_data

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def read_data(source_path, target_path, max_size=None):
  """Read data from source and target files and put into buckets.

  Args:
    source_path: path to the files with token-ids for the source language.
    target_path: path to the file with token-ids for the target language;
      it must be aligned with the source file: n-th line contains the desired
      output for n-th line from the source_path.
    max_size: maximum number of lines to read, all other will be ignored;
      if 0 or None, data files will be read completely (no limit).

  Returns:
    data_set: a list of length len(_buckets); data_set[n] contains a list of
      (source, target) pairs read from the provided data files that fit
      into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
      len(target) < _buckets[n][1]; source and target are lists of token-ids.
  """
  data_set = [[] for _ in _buckets]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      counter = 0
      while source and target and (not max_size or counter < max_size):
        counter += 1
        if counter % 10000 == 0:
          print("  reading data line %d" % counter)
          sys.stdout.flush()
        source_ids = [int(x) for x in source.split()]
        target_ids = [int(x) for x in target.split()]
        target_ids.append(data_utils.EOS_ID)
        for bucket_id, (source_size, target_size) in enumerate(_buckets):
          if len(source_ids) < source_size and len(target_ids) < target_size:
            data_set[bucket_id].append([source_ids, target_ids])
            break
        source, target = source_file.readline(), target_file.readline()
  return data_set 
Developer: dashayushman, Project: deep-trans, Lines: 38, Source: transliterate.py

Example 6: read_data

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def read_data(source_path, target_path, max_size=None):
  """Read data from source and target files and put into buckets.
  Args:
    source_path: path to the files with token-ids for the source language.
    target_path: path to the file with token-ids for the target language;
      it must be aligned with the source file: n-th line contains the desired
      output for n-th line from the source_path.
    max_size: maximum number of lines to read, all other will be ignored;
      if 0 or None, data files will be read completely (no limit).
  Returns:
    data_set: a list of length len(_buckets); data_set[n] contains a list of
      (source, target) pairs read from the provided data files that fit
      into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
      len(target) < _buckets[n][1]; source and target are lists of token-ids.
  """
  data_set = [[] for _ in _buckets]
  mycount=0
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      counter = 0
      while source and target and (not max_size or counter < max_size):
        counter += 1
        if counter % 100000 == 0:
          print("  reading data line %d" % counter)
          sys.stdout.flush()
        source_ids = [int(x) for x in source.split()]
        target_ids = [int(x) for x in target.split()]
        target_ids.append(data_utils.EOS_ID)
        for bucket_id, (source_size, target_size) in enumerate(_buckets):
          if len(source_ids) < source_size and len(target_ids) < target_size:
            data_set[bucket_id].append([source_ids, target_ids])
            mycount=mycount+1
            break
        source, target = source_file.readline(), target_file.readline()
  print(mycount)
  return data_set 
Developer: Shen-Lab, Project: DeepAffinity, Lines: 39, Source: translate.py

Example 7: read_data

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def read_data(source_path, target_path, max_size=None):
  """Read data from source and target files and put into buckets.

  Args:
    source_path: path to the files with token-ids for the source language.
    target_path: path to the file with token-ids for the target language;
      it must be aligned with the source file: n-th line contains the desired
      output for n-th line from the source_path.
    max_size: maximum number of lines to read, all other will be ignored;
      if 0 or None, data files will be read completely (no limit).

  Returns:
    data_set: a list of length len(_buckets); data_set[n] contains a list of
      (source, target) pairs read from the provided data files that fit
      into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
      len(target) < _buckets[n][1]; source and target are lists of token-ids.
  """
  data_set = [[] for _ in _buckets]
  diffs = []
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      counter = 0
      while source and target and (not max_size or counter < max_size):
        counter += 1
        if counter % 100000 == 0:
          print("  reading data line %d" % counter)
          sys.stdout.flush()
        source_ids = [int(x) for x in source.split()]
        target_ids = [int(x) for x in target.split()]
        target_ids.append(data_utils.EOS_ID)
        for bucket_id, (source_size, target_size) in enumerate(_buckets):
          if len(source_ids) < source_size and len(target_ids) < target_size:
            data_set[bucket_id].append([source_ids, target_ids])
            diffs.append(source_size - len(source_ids) + target_size - len(target_ids))
            break
        source, target = source_file.readline(), target_file.readline()
  #print("mean padding count: %f" % np.mean(diffs))
  return data_set 
Developer: muik, Project: transliteration, Lines: 41, Source: translate.py

Example 8: read_data

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def read_data(source_path, target_path, max_size=None):
  """Read data from source and target files and put into buckets.

  Args:
    source_path: path to the files with token-ids for the source language.
    target_path: path to the file with token-ids for the target language;
      it must be aligned with the source file: n-th line contains the desired
      output for n-th line from the source_path.
    max_size: maximum number of lines to read, all other will be ignored;
      if 0 or None, data files will be read completely (no limit).

  Returns:
    data_set: a list of length len(_buckets); data_set[n] contains a list of
      (source, target) pairs read from the provided data files that fit
      into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
      len(target) < _buckets[n][1]; source and target are lists of token-ids.
  """
  data_set = [[] for _ in _buckets]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      counter = 0
      while source and target and (not max_size or counter < max_size):
        counter += 1
        if counter % 10 == 0:
          print("  reading data line %d" % counter)
          sys.stdout.flush()
        source_ids = [int(x) for x in source.split()]
        target_ids = [int(x) for x in target.split()]
        target_ids.append(data_utils.EOS_ID)
        for bucket_id, (source_size, target_size) in enumerate(_buckets):
          if len(source_ids) < source_size and len(target_ids) < target_size:
            data_set[bucket_id].append([source_ids, target_ids])
            break
        source, target = source_file.readline(), target_file.readline()
  return data_set 
Developer: warmheartli, Project: ChatBotCourse, Lines: 38, Source: translate.py

Example 9: decode

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def decode():
  with tf.Session() as sess:
    # Create model and load parameters.
    model = create_model(sess, True)
    model.batch_size = 1  # We decode one sentence at a time.

    # Load vocabularies.
    en_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.from" % FLAGS.from_vocab_size)
    fr_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.to" % FLAGS.to_vocab_size)
    en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
    _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)

    # Decode from standard input.
    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()
    while sentence:
      # Get token-ids for the input sentence.
      token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab)
      # Which bucket does it belong to?
      bucket_id = len(_buckets) - 1
      for i, bucket in enumerate(_buckets):
        if bucket[0] >= len(token_ids):
          bucket_id = i
          break
      else:
        logging.warning("Sentence truncated: %s", sentence)

      # Get a 1-element batch to feed the sentence to the model.
      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          {bucket_id: [(token_ids, [])]}, bucket_id)
      # Get output logits for the sentence.
      _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                       target_weights, bucket_id, True)
      # This is a greedy decoder - outputs are just argmaxes of output_logits.
      outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
      # If there is an EOS symbol in outputs, cut them at that point.
      if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
      # Print out French sentence corresponding to outputs.
      print(" ".join([tf.compat.as_str(rev_fr_vocab[output]) for output in outputs]))
      print("> ", end="")
      sys.stdout.flush()
      sentence = sys.stdin.readline() 
Developer: ringringyi, Project: DOTA_models, Lines: 48, Source: translate.py

Example 10: decode

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def decode():
  with tf.Session() as sess:
    # Create model and load parameters.
    model = create_model(sess, True)
    model.batch_size = 1  # We decode one sentence at a time.

    # Load vocabularies.
    enc_vocab_path = os.path.join(gConfig['working_directory'],"vocab%d_enc.txt" % gConfig['enc_vocab_size'])
    dec_vocab_path = os.path.join(gConfig['working_directory'],"vocab%d_dec.txt" % gConfig['dec_vocab_size'])

    enc_vocab, _ = data_utils.initialize_vocabulary(enc_vocab_path)
    _, rev_dec_vocab = data_utils.initialize_vocabulary(dec_vocab_path)



    # Decode sentence and store it
    with open(gConfig["test_enc"], 'r') as test_enc:
        with open(gConfig["output"], 'w') as predicted_headline:
            sentence_count = 0
            for sentence in test_enc:
                # Get token-ids for the input sentence.
                token_ids = data_utils.sentence_to_token_ids(sentence, enc_vocab)
                # Which bucket does it belong to? Place the sentence in the last bucket if its token length is larger than X.
                bucket_id = min([b for b in range(len(_buckets)) if _buckets[b][0] > len(token_ids)] + [len(_buckets)-1])
                # Get a 1-element batch to feed the sentence to the model.
                encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
                # Get output logits for the sentence.
                _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                           target_weights, bucket_id, True)

                # This is a greedy decoder - outputs are just argmaxes of output_logits.
                outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]

                # If there is an EOS symbol in outputs, cut them at that point.
                if data_utils.EOS_ID in outputs:
                    outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                # Write predicted headline corresponding to article.
                predicted_headline.write(" ".join([tf.compat.as_str(rev_dec_vocab[output]) for output in outputs])+'\n')
                sentence_count += 1
                if sentence_count % 100 == 0:
                    print("predicted data line %d" % sentence_count)
                    sys.stdout.flush()

        predicted_headline.close()
    test_enc.close()

    print("Finished decoding and stored predicted results in %s!" % gConfig["output"]) 
Developer: hengluchang, Project: deep-news-summarization, Lines: 50, Source: execute.py

Example 11: decode_input

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def decode_input():
  with tf.Session() as sess:
    # Create model and load parameters.
    model = create_model(sess, True)
    model.batch_size = 1  # We decode one sentence at a time.

    # Load vocabularies.
    enc_vocab_path = os.path.join(gConfig['working_directory'],"vocab%d_enc.txt" % gConfig['enc_vocab_size'])
    dec_vocab_path = os.path.join(gConfig['working_directory'],"vocab%d_dec.txt" % gConfig['dec_vocab_size'])

    enc_vocab, _ = data_utils.initialize_vocabulary(enc_vocab_path)
    _, rev_dec_vocab = data_utils.initialize_vocabulary(dec_vocab_path)


    # Decode from standard input.
    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()

    while sentence:
      # Get token-ids for the input sentence.
      token_ids = data_utils.sentence_to_token_ids(sentence, enc_vocab)
      # Which bucket does it belong to? Place the sentence in the last bucket if its token length is larger than the bucket length.
      bucket_id = min([b for b in range(len(_buckets)) if _buckets[b][0] > len(token_ids)] + [len(_buckets)-1])
      # Get a 1-element batch to feed the sentence to the model.
      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          {bucket_id: [(token_ids, [])]}, bucket_id)
      # Get output logits for the sentence.
      _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                       target_weights, bucket_id, True)
      # This is a greedy decoder - outputs are just argmaxes of output_logits.
      outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]


      # If there is an EOS symbol in outputs, cut them at that point.
      if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
      # Print out French sentence corresponding to outputs.
      print(" ".join([tf.compat.as_str(rev_dec_vocab[output]) for output in outputs]))

      print("> ", end="")
      sys.stdout.flush()
      sentence = sys.stdin.readline() 
Developer: hengluchang, Project: deep-news-summarization, Lines: 45, Source: execute.py
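One detail worth noting: the bucket-selection idiom here and in Example 10 appends [len(_buckets)-1] inside min(), so an input longer than every bucket falls back to the last bucket instead of crashing, whereas the plain min([...]) in Example 2 raises a ValueError on an empty list in that case (Examples 9 and 13 handle it with a for/else plus a warning). A quick standalone comparison of the two idioms; the bucket sizes are illustrative only.

_buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
token_ids = list(range(60))  # longer than the largest source bucket

# Example 2's idiom: crashes when no bucket is wide enough.
try:
    bucket_id = min([b for b in range(len(_buckets))
                     if _buckets[b][0] > len(token_ids)])
except ValueError:
    print("no bucket fits; min() over an empty list raises ValueError")

# The idiom used here: fall back to the last bucket instead.
bucket_id = min([b for b in range(len(_buckets))
                 if _buckets[b][0] > len(token_ids)] + [len(_buckets) - 1])
print("falls back to bucket", bucket_id)  # -> 3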

Example 12: decode

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def decode():
  with tf.Session() as sess:
    # Create model and load parameters.
    model = create_model(sess, True)
    model.batch_size = 1  # We decode one word at a time.

    # Load vocabularies.
    en_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.en" % FLAGS.en_vocab_size)
    hn_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.hn" % FLAGS.hn_vocab_size)
    en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
    _, rev_hn_vocab = data_utils.initialize_vocabulary(hn_vocab_path)

    # Decode from standard input.
    sys.stdout.write("> ")
    sys.stdout.flush()
    word = sys.stdin.readline()
    while word:
      word = word.lower()
      char_list_new = list(word)
      word = " ".join(char_list_new)
      # Get token-ids for the input word.
      token_ids = data_utils.word_to_token_ids(tf.compat.as_bytes(word), en_vocab)
      # Which bucket does it belong to?
      bucket_id = min([b for b in xrange(len(_buckets))
                       if _buckets[b][0] > len(token_ids)])
      # Get a 1-element batch to feed the word to the model.
      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          {bucket_id: [(token_ids, [])]}, bucket_id)
      # Get output logits for the word.
      _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                       target_weights, bucket_id, True)
      # This is a greedy decoder - outputs are just argmaxes of output_logits.
      outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
      # If there is an EOS symbol in outputs, cut them at that point.
      if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
      # Print out Hindi word corresponding to outputs.
      print("".join([tf.compat.as_str(rev_hn_vocab[output]) for output in outputs]))
      print("> ", end="")
      sys.stdout.flush()
      word = sys.stdin.readline() 
Developer: dashayushman, Project: deep-trans, Lines: 45, Source: transliterate.py

Example 13: decode

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def decode():
  with tf.Session() as sess:
    # Create model and load parameters.
    model = create_model(sess, True)
    model.batch_size = 1  # We decode one sentence at a time.

    # Load vocabularies.
    en_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.input" % FLAGS.input_vocab_size)
    fr_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.output" % FLAGS.output_vocab_size)
    en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
    _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)

    # Decode from standard input.
    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()
    while sentence:
      # Get token-ids for the input sentence.
      token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab)
      # Which bucket does it belong to?
      bucket_id = len(_buckets) - 1
      for i, bucket in enumerate(_buckets):
        if bucket[0] >= len(token_ids):
          bucket_id = i
          break
      else:
        logging.warning("Sentence truncated: %s", sentence)

      # Get a 1-element batch to feed the sentence to the model.
      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          {bucket_id: [(token_ids, [])]}, bucket_id)
      # Get output logits for the sentence.
      _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                       target_weights, bucket_id, True)
      # This is a greedy decoder - outputs are just argmaxes of output_logits.
      outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
      # If there is an EOS symbol in outputs, cut them at that point.
      if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
      # Print out French sentence corresponding to outputs.
      print(" ".join([tf.compat.as_str(rev_fr_vocab[output]) for output in outputs]))
      print("> ", end="")
      sys.stdout.flush()
      sentence = sys.stdin.readline() 
Developer: warmheartli, Project: ChatBotCourse, Lines: 48, Source: translate.py

Example 14: read_mrs_data

# Required import: import data_utils [as alias]
# Or: from data_utils import EOS_ID [as alias]
def read_mrs_data(buckets, source_paths, target_paths, max_size=None,
    any_length=False, offset_target=-1):
  # Read in all files seperately.
  source_inputs = [data_utils.read_ids_file(path, max_size) 
                   for path in source_paths]
  target_inputs = [data_utils.read_ids_file(path, max_size) 
                   for path in target_paths]
  
  data_set = [[] for _ in buckets]
  data_list = []
  # Assume everything is well-aligned.
  for i in xrange(len(source_inputs[0])): # over examples
    # List of sequences of each type.
    source_ids = [source_input[i] for source_input in source_inputs]
    # Assume first target type predicts EOS.
    # Not checking pointer ranges: do that inside tf graph.
    target_ids = [target_inputs[0][i] + [data_utils.EOS_ID]]
    for j, target_input in enumerate(target_inputs[1:]):
      if offset_target > 0 and j + 1 == offset_target:
        target_ids.append([data_utils.PAD_ID] + target_input[i] 
                          + [data_utils.PAD_ID])
      else:
        target_ids.append(target_input[i] + [data_utils.PAD_ID])

    found_bucket = False
    for bucket_id, (source_size, target_size) in enumerate(buckets):
      if len(source_ids[0]) < source_size and len(target_ids[0]) < target_size:
        data_set[bucket_id].append([source_ids, target_ids])
        data_list.append([source_ids, target_ids, bucket_id])
        found_bucket = True
        break
    if any_length and not found_bucket:
      # Crop examples that are larger than the largest bucket.
      source_size, target_size = buckets[-1][0], buckets[-1][1]
      if len(source_ids[0]) >= source_size:
        source_ids = [source_id[:source_size] for source_id in source_ids]
      if len(target_ids[0]) >= target_size:
        target_ids = [target_id[:target_size] for target_id in target_ids]
      bucket_id = len(buckets) - 1
      data_set[bucket_id].append([source_ids, target_ids])
      data_list.append([source_ids, target_ids, bucket_id])
  return data_set, data_list 
Developer: janmbuys, Project: DeepDeepParser, Lines: 44, Source: parser.py


Note: The data_utils.EOS_ID attribute examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their respective authors, and copyright remains with the original authors; please consult each project's license before distributing or reusing the code. Do not reproduce this article without permission.