本文整理汇总了 Python 中 wmt_utils.EOS_ID 属性的典型用法代码示例。如果您正苦于以下问题:Python wmt_utils.EOS_ID 属性的具体用法?Python wmt_utils.EOS_ID 怎么用?Python wmt_utils.EOS_ID 使用的例子?那么恭喜您,这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在模块 wmt_utils 的用法示例。
在下文中一共展示了wmt_utils.EOS_ID属性的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: read_data
# 需要导入模块: import wmt_utils [as 别名]
# 或者: from wmt_utils import EOS_ID [as 别名]
def read_data(source_path, target_path, buckets, max_size=None, print_out=True):
    """Read aligned source/target token-id files and sort the pairs into buckets.

    Args:
      source_path: path to the file with token-ids for the source language.
      target_path: path to the file with token-ids for the target language;
        it must be aligned with the source file: the n-th line contains the
        desired output for the n-th line of source_path.
      buckets: sequence of bucket sizes; a pair is stored in the first bucket
        whose size is >= both its source length and its target length.
      max_size: maximum number of lines to read; all others are ignored.
        If 0 or None, the files are read completely (no limit).
        If set to 1, nothing is read (empty lists of the right form).
      print_out: whether to print a progress line every 100000 lines.

    Returns:
      data_set: a list of length len(buckets); data_set[n] holds the
        [source_ids, target_ids] pairs placed in the n-th bucket, i.e. pairs
        with source_len <= buckets[n] and target_len <= buckets[n] that did
        not fit an earlier bucket. Pairs that fit no bucket are dropped.
    """
    data_set = [[] for _ in buckets]
    # max_size == 1 is the documented "give me empty buckets" request, so the
    # files are not even opened in that case.
    if max_size == 1:
        return data_set

    counter = 0
    with tf.gfile.GFile(source_path, mode="r") as source_file, \
            tf.gfile.GFile(target_path, mode="r") as target_file:
        source, target = source_file.readline(), target_file.readline()
        # Stop at the end of the shorter file or once max_size lines were used.
        while source and target and (not max_size or counter < max_size):
            counter += 1
            if counter % 100000 == 0 and print_out:
                # Call form of print: valid on both Python 2 and Python 3.
                print(" reading data line %d" % counter)
                sys.stdout.flush()
            source_ids = [int(x) for x in source.split()]
            target_ids = [int(x) for x in target.split()]
            # zero_split is defined elsewhere in the file; it is used here as
            # returning an (ids, length) pair, and the target call asks it to
            # append the EOS id.
            source_ids, source_len = zero_split(source_ids)
            target_ids, target_len = zero_split(target_ids, append=wmt.EOS_ID)
            for bucket_id, size in enumerate(buckets):
                if source_len <= size and target_len <= size:
                    data_set[bucket_id].append([source_ids, target_ids])
                    break  # Only the first bucket that fits receives the pair.
            source, target = source_file.readline(), target_file.readline()
    return data_set
示例2: score_beams
# 需要导入模块: import wmt_utils [as 别名]
# 或者: from wmt_utils import EOS_ID [as 别名]
def score_beams(beams, target, inp, history, p,
                print_out=False, test_mode=False):
    """Pick the beam that best matches `target` and return it with its score.

    Args:
      beams: list of candidate output sequences (lists of token ids).
      target: the reference sequence (list of token ids).
      inp: only forwarded to score_beams_prog for the "progsynth" problem.
      history: previously produced sequences; a non-perfect beam whose
        positive-valued tokens already appear in it is penalized by 20.0.
      p: problem name; "progsynth" delegates to score_beams_prog and "wmt"
        truncates target and beams at the first wmt.EOS_ID.
      print_out: only forwarded to score_beams_prog.
      test_mode: if True (and p is not "progsynth"), only the first beam is
        examined and scored 10.0 when its prefix equals the target, else 0.0.

    Returns:
      A (best_beam, best_score) pair. With an empty `beams` list in the
      default mode this is (None, -1000.0). The score is the fraction of
      target positions matched by the beam, minus the history penalty; an
      empty target yields a match fraction of 0.0 instead of dividing by zero.
    """
    if p == "progsynth":
        return score_beams_prog(beams, target, inp, history, print_out, test_mode)
    if test_mode:
        is_exact = str(beams[0][:len(target)]) == str(target)
        return beams[0], (10.0 if is_exact else 0.0)

    # A set gives constant-time membership tests for the history check below.
    seen = set(str(h) for h in history)
    eos_id = wmt.EOS_ID if p == "wmt" else None

    tgt = target
    # Compare against None explicitly so a zero-valued id would still truncate.
    if eos_id is not None and eos_id in tgt:
        tgt = tgt[:tgt.index(eos_id)]
    # Guard the denominator: an empty target must not raise ZeroDivisionError.
    denom = float(max(len(tgt), 1))

    best, best_score = None, -1000.0
    for beam in beams:
        if eos_id is not None and eos_id in beam:
            beam = beam[:beam.index(eos_id)]
        # zip stops at the shorter sequence, so only shared positions count.
        matches = sum(1 for t, b in zip(tgt, beam) if t == b)
        score = matches / denom
        # Perfect matches are never penalized for having been seen before.
        if score < 1.0 and str([b for b in beam if b > 0]) in seen:
            score -= 20.0
        if score > best_score:
            best, best_score = beam, score
    return best, best_score
示例3: read_data
# 需要导入模块: import wmt_utils [as 别名]
# 或者: from wmt_utils import EOS_ID [as 别名]
def read_data(source_path, target_path, buckets, max_size=None, print_out=True):
    """Load parallel token-id files and group the sentence pairs by bucket.

    Args:
      source_path: path to the file with token-ids for the source language.
      target_path: path to the file with token-ids for the target language,
        line-aligned with source_path (line n is the output for line n).
      buckets: the bucket sizes to sort pairs into.
      max_size: maximum number of lines to read; 0 or None means no limit,
        and 1 means no data is read at all (empty lists of the right form).
      print_out: whether to print a status line every 100000 lines.

    Returns:
      A list with one entry per bucket; entry n is the list of
      [source_ids, target_ids] pairs assigned to bucket n, which is the first
      bucket whose size is >= both lengths. Pairs fitting no bucket are
      left out.
    """
    bucketed = [[] for _ in buckets]
    if max_size == 1:
        # Special case: the caller only wants the empty bucket structure.
        return bucketed

    lines_read = 0
    with tf.gfile.GFile(source_path, mode="r") as src_file, \
            tf.gfile.GFile(target_path, mode="r") as tgt_file:
        while True:
            src_line = src_file.readline()
            tgt_line = tgt_file.readline()
            if not src_line or not tgt_line:
                break  # One of the files is exhausted.
            if max_size and lines_read >= max_size:
                break  # Line limit reached.

            lines_read += 1
            if print_out and lines_read % 100000 == 0:
                print(" reading data line %d" % lines_read)
                sys.stdout.flush()

            src_ids = [int(tok) for tok in src_line.split()]
            tgt_ids = [int(tok) for tok in tgt_line.split()]
            # zero_split lives elsewhere in the file; it is unpacked here as
            # an (ids, length) pair.
            src_ids, src_len = zero_split(src_ids)
            tgt_ids, tgt_len = zero_split(tgt_ids, append=wmt.EOS_ID)

            # Index of the first bucket large enough for both sides, if any.
            fit = next(
                (idx for idx, size in enumerate(buckets)
                 if src_len <= size and tgt_len <= size),
                None)
            if fit is not None:
                bucketed[fit].append([src_ids, tgt_ids])
    return bucketed
示例4: linearize
# 需要导入模块: import wmt_utils [as 别名]
# 或者: from wmt_utils import EOS_ID [as 别名]
def linearize(output, rev_fr_vocab, simple_tokenizer=None, eos_id=wmt.EOS_ID):
    """Turn a sequence of output token ids into a single string.

    Args:
      output: sequence of token ids; it is cut just before the first eos_id
        if one is present.
      rev_fr_vocab: id-to-token lookup, indexed by token id.
      simple_tokenizer: if truthy (or if FLAGS.simple_tokenizer is truthy),
        tokens are joined with single spaces and ids >= len(rev_fr_vocab)
        are rendered as "UNK"; otherwise wmt.basic_detokenizer is used.
      eos_id: the end-of-sequence id to cut at.

    Returns:
      The string built from the (possibly truncated) output.
    """
    # Drop everything from the first EOS symbol onward (WMT).
    if eos_id in output:
        cut = output.index(eos_id)
        output = output[:cut]

    use_simple = simple_tokenizer or FLAGS.simple_tokenizer
    if not use_simple:
        return wmt.basic_detokenizer([rev_fr_vocab[tok] for tok in output])

    # Simple mode: ids beyond the vocabulary become "UNK".
    vocab_size = len(rev_fr_vocab)
    words = [rev_fr_vocab[tok] if tok < vocab_size else "UNK" for tok in output]
    return " ".join(words)