本文整理汇总了Python中data.outputids2words方法的典型用法代码示例。如果您正苦于以下问题：Python data.outputids2words方法的具体用法？Python data.outputids2words怎么用？Python data.outputids2words使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块data的用法示例。
在下文中一共展示了data.outputids2words方法的11个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: decode
# 需要导入模块: import data [as 别名]
# 或者: from data import outputids2words [as 别名]
def decode(self):
    """Decode one batch with beam search and write the summary to stdout.

    Draws a single batch from ``self._batcher`` (one example repeated across
    the batch), runs beam search on it, maps the best hypothesis back to
    words, truncates them at the first [STOP] symbol and writes the
    space-joined words to ``sys.stdout`` (no trailing newline is added).
    Unlike the looping decoders, this does not iterate over the dataset or
    reload checkpoints.

    Returns:
        None. Nothing is written if the batcher returns no batch.
    """
    batch = self._batcher.next_batch()  # 1 example repeated across batch
    if batch is None:
        # NOTE(review): other decoders treat a None batch as "dataset
        # exhausted"; there is nothing to decode, so stop instead of crashing.
        return

    # Run beam search to get the best Hypothesis.
    best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

    # Skip the first hypothesis token, then map ids back to words. In-article
    # OOV words are only supplied when the pointer-generator is enabled.
    output_ids = [int(t) for t in best_hyp.tokens[1:]]
    art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
    decoded_words = data.outputids2words(output_ids, self._vocab, art_oovs)

    # Truncate at the first [STOP] symbol, if one was produced.
    try:
        decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
    except ValueError:
        pass  # no [STOP]: keep the full sequence

    sys.stdout.write(' '.join(decoded_words))
示例2: compute_reward
# 需要导入模块: import data [as 别名]
# 或者: from data import outputids2words [as 别名]
def compute_reward(batch, decode_batch, vocab, mode, use_cuda):
    """Compute per-sample ROUGE-2 rewards with a per-example mean baseline.

    Args:
        batch: batch object providing ``original_abstracts`` (one reference
            string per example) and ``art_oovs`` (per-example in-article OOV
            lists, used only when ``config.pointer_gen`` is set).
        decode_batch: tensor of sampled token ids indexed as
            ``[example, sample, step]`` (B x S x L); step 0 is dropped before
            the ids are converted back to words.
        vocab: vocabulary passed to ``data.outputids2words``.
        mode: ``'MLE'`` returns uniform weights without computing rewards;
            any other value returns baseline-adjusted rewards.
        use_cuda: if true, the returned weight tensors are moved to the GPU.

    Returns:
        ``(weights, baseline)``:
          * MLE: a B x S tensor of ones and ``torch.zeros(1)`` (the latter
            stays on the CPU, as before).
          * Otherwise: B x S rewards minus their per-example mean, clamped at
            0 (or the raw rewards, clamped at 0, when every reward already
            equals its example mean), plus the scalar mean baseline.
    """
    # Size the tensors from the decoded samples themselves rather than the
    # global config, so a smaller final batch does not index out of range.
    n_examples, n_samples = decode_batch.shape[:2]

    ones = torch.ones((n_examples, n_samples))
    if use_cuda:
        ones = ones.cuda()
    if mode == 'MLE':
        # MLE training ignores the rewards, so skip the ROUGE computation.
        return ones, torch.zeros(1)

    target_sents = batch.original_abstracts  # list of string
    output_ids = decode_batch.cpu().numpy()[:, :, 1:]  # B x S x (L-1)

    all_rewards = torch.zeros((n_examples, n_samples))  # B x S
    if use_cuda:
        all_rewards = all_rewards.cuda()
    for i in range(n_examples):
        art_oovs = batch.art_oovs[i] if config.pointer_gen else None
        for j in range(n_samples):
            words = data.outputids2words(output_ids[i, j].tolist(), vocab, art_oovs)
            # Truncate at the first [STOP] symbol, if one was produced.
            try:
                words = words[:words.index(data.STOP_DECODING)]
            except ValueError:
                pass
            all_rewards[i, j] = rouge_2(target_sents[i], ' '.join(words))

    # Per-example mean reward, broadcast to B x S.
    batch_avg_reward = torch.mean(all_rewards, dim=1, keepdim=True) * ones
    # When every sample already equals its example's mean, subtracting would
    # zero out all rewards, so the raw rewards are kept in that case.
    if not torch.equal(all_rewards, batch_avg_reward):
        all_rewards = all_rewards - batch_avg_reward
    all_rewards = all_rewards.clamp(min=0)  # only above-baseline samples are rewarded
    return all_rewards, batch_avg_reward.mean()
示例3: process_one_article
# 需要导入模块: import data [as 别名]
# 或者: from data import outputids2words [as 别名]
def process_one_article(self, original_article_sents, original_abstract_sents,
                        original_selected_ids, output_ids, oovs, attn_dists_norescale,
                        attn_dists, p_gens, log_probs, sent_probs, counter):
    """Turn one decoded id sequence into text and write its single-pass outputs.

    ``output_ids`` are mapped to words, truncated at the first [STOP] symbol
    and split into sentences with ``data.words2sents``. Outputs are written
    only when FLAGS.single_pass is set:
      * always, ``write_for_rouge`` with the reference and decoded sentences;
      * with beam decoding and FLAGS.save_vis, ``write_for_attnvis``;
      * with FLAGS.save_pkl, ``save_result``.

    Args:
        original_article_sents: article sentences (strings).
        original_abstract_sents: reference summary sentences (strings).
        original_selected_ids: forwarded unchanged to ``save_result``.
        output_ids: decoded token ids.
        oovs: in-article OOV words passed to ``outputids2words`` and
            ``show_abs_oovs``.
        attn_dists_norescale, attn_dists, p_gens, log_probs: forwarded to
            ``write_for_attnvis``.
        sent_probs: one value per article sentence; sentences at index
            >= FLAGS.max_art_len get 0 instead.
        counter: example index passed to every writer.
    """
    decoded_words = data.outputids2words(output_ids, self._vocab, oovs)
    # Remove the [STOP] token (and anything after it), if present.
    try:
        decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
    except ValueError:
        pass
    decoded_sents = data.words2sents(decoded_words)

    # Bound unconditionally so no writer below can ever see it unassigned.
    verbose = FLAGS.mode != 'eval'

    if not FLAGS.single_pass:
        return

    # Write ref summary and decoded summary to file, to eval with pyrouge later.
    self.write_for_rouge(original_abstract_sents, decoded_sents, counter, verbose)

    if FLAGS.decode_method == 'beam' and FLAGS.save_vis:
        # Expand the per-sentence values to one entry per article word
        # (words = space-separated tokens of each sentence).
        sent_probs_per_word = []
        for sent_id, sent in enumerate(original_article_sents):
            prob = sent_probs[sent_id] if sent_id < FLAGS.max_art_len else 0
            sent_probs_per_word.extend([prob] * len(sent.split(' ')))
        article_withunks = data.show_art_oovs(' '.join(original_article_sents), self._vocab)
        abstract_withunks = data.show_abs_oovs(' '.join(original_abstract_sents), self._vocab, oovs)
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, attn_dists_norescale,
                               attn_dists, p_gens, log_probs, sent_probs_per_word, counter, verbose)

    if FLAGS.save_pkl:
        self.save_result(original_article_sents, original_abstract_sents,
                         original_selected_ids, decoded_sents, counter, verbose)
示例4: process_one_article
# 需要导入模块: import data [as 别名]
# 或者: from data import outputids2words [as 别名]
def process_one_article(self, original_article_sents, original_abstract_sents,
                        original_selected_ids, output_ids, oovs,
                        attn_dists, p_gens, log_probs, counter):
    """Turn one decoded id sequence into text and write its single-pass outputs.

    ``output_ids`` are mapped to words, truncated at the first [STOP] symbol
    and split into sentences with ``data.words2sents``. Outputs are written
    only when FLAGS.single_pass is set:
      * always, ``write_for_rouge`` with the reference and decoded sentences;
      * with beam decoding and FLAGS.save_vis, ``write_for_attnvis``;
      * with FLAGS.save_pkl, ``save_result``.

    Args:
        original_article_sents: article sentences (strings).
        original_abstract_sents: reference summary sentences (strings).
        original_selected_ids: forwarded unchanged to ``save_result``.
        output_ids: decoded token ids.
        oovs: in-article OOV words passed to ``outputids2words`` and
            ``show_abs_oovs``.
        attn_dists, p_gens, log_probs: forwarded to ``write_for_attnvis``.
        counter: example index passed to every writer.
    """
    decoded_words = data.outputids2words(output_ids, self._vocab, oovs)
    # Remove the [STOP] token (and anything after it), if present.
    try:
        decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
    except ValueError:
        pass
    decoded_sents = data.words2sents(decoded_words)

    # Bound unconditionally so no writer below can ever see it unassigned.
    verbose = FLAGS.mode != 'eval'

    if not FLAGS.single_pass:
        return

    # Write ref summary and decoded summary to file, to eval with pyrouge later.
    self.write_for_rouge(original_abstract_sents, decoded_sents, counter, verbose)

    if FLAGS.decode_method == 'beam' and FLAGS.save_vis:
        article_withunks = data.show_art_oovs(' '.join(original_article_sents), self._vocab)
        abstract_withunks = data.show_abs_oovs(' '.join(original_abstract_sents), self._vocab, oovs)
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words,
                               attn_dists, p_gens, log_probs, counter, verbose)

    if FLAGS.save_pkl:
        self.save_result(original_article_sents, original_abstract_sents,
                         original_selected_ids, decoded_sents, counter, verbose)
示例5: decode
# 需要导入模块: import data [as 别名]
# 或者: from data import outputids2words [as 别名]
def decode(self):
    """Run the beam-search decoder over the batcher's examples.

    With FLAGS.single_pass the loop ends when the batcher returns None: each
    reference/decoded pair is written for ROUGE (numbered from
    FLAGS.decode_after), then ROUGE is evaluated and logged before returning.
    Examples with an empty reference abstract are skipped.
    Without single_pass the loop never ends: results are printed, written for
    the attention visualiser, and the checkpoint given by FLAGS.decode_from is
    reloaded every SECS_UNTIL_NEW_CKPT seconds.
    """
    last_reload = time.time()
    example_idx = FLAGS.decode_after  # numbering of the ROUGE output files
    while True:
        # NOTE(review): resets the default TF graph for every example; kept as-is.
        tf.reset_default_graph()
        batch = self._batcher.next_batch()  # one example repeated across the batch

        if batch is None:
            # Only single_pass decoding is allowed to run out of data.
            assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
            tf.logging.info("Decoder has finished reading dataset for single_pass.")
            tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
            rouge_log(rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir), self._decode_dir)
            return

        article_text = batch.original_articles[0]
        abstract_text = batch.original_abstracts[0]
        abstract_sents = batch.original_abstracts_sents[0]
        if not abstract_sents:
            # Nothing to score against; skip without advancing the file counter.
            print("NOOOOO!!!!, An empty abstract :(")
            continue

        art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
        article_withunks = data.show_art_oovs(article_text, self._vocab)
        abstract_withunks = data.show_abs_oovs(abstract_text, self._vocab, art_oovs)

        # The actor-critic variant also hands the DQN and its session/graph to beam search.
        dqn_args = (self._dqn, self._dqn_sess, self._dqn_graph) if FLAGS.ac_training else ()
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch, *dqn_args)

        # Drop the first hypothesis token, then map ids (incl. copied OOVs) back to words.
        token_ids = [int(tok) for tok in best_hyp.tokens[1:]]
        words = data.outputids2words(token_ids, self._vocab, art_oovs)
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]
        summary = ' '.join(words)

        if FLAGS.single_pass:
            self.write_for_rouge(abstract_sents, words, example_idx)
            example_idx += 1
            continue

        print_results(article_withunks, abstract_withunks, summary)
        self.write_for_attnvis(article_withunks, abstract_withunks, words, best_hyp.attn_dists, best_hyp.p_gens)
        elapsed = time.time() - last_reload
        if elapsed > SECS_UNTIL_NEW_CKPT:
            tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', elapsed)
            util.load_ckpt(self._saver, self._sess, FLAGS.decode_from)
            last_reload = time.time()
示例6: decode
# 需要导入模块: import data [as 别名]
# 或者: from data import outputids2words [as 别名]
def decode(self):
    """Run the beam-search decoder over the batcher's examples.

    With FLAGS.single_pass the loop ends when the batcher returns None: each
    reference/decoded pair is written for ROUGE (numbered from
    FLAGS.decode_after), then ROUGE is evaluated and logged before returning.
    Without single_pass the loop never ends: results are printed, written for
    the attention visualiser, and the checkpoint given by FLAGS.decode_from is
    reloaded every SECS_UNTIL_NEW_CKPT seconds.
    """
    reload_start = time.time()
    file_no = FLAGS.decode_after  # numbering of the ROUGE output files
    while True:
        # NOTE(review): resets the default TF graph for every example; kept as-is.
        tf.reset_default_graph()
        batch = self._batcher.next_batch()  # one example repeated across the batch

        if batch is None:
            # Only single_pass decoding is allowed to run out of data.
            assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
            tf.logging.info("Decoder has finished reading dataset for single_pass.")
            tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
            scores = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
            rouge_log(scores, self._decode_dir)
            return

        oov_words = batch.art_oovs[0] if FLAGS.pointer_gen else None
        ref_sents = batch.original_abstracts_sents[0]
        shown_article = data.show_art_oovs(batch.original_articles[0], self._vocab)
        shown_abstract = data.show_abs_oovs(batch.original_abstracts[0], self._vocab, oov_words)

        # The actor-critic variant also hands the DQN and its session/graph to beam search.
        search_args = [self._sess, self._model, self._vocab, batch]
        if FLAGS.ac_training:
            search_args += [self._dqn, self._dqn_sess, self._dqn_graph]
        hyp = beam_search.run_beam_search(*search_args)

        # Drop the first hypothesis token, then map ids (incl. copied OOVs) back to words.
        words = data.outputids2words([int(tok) for tok in hyp.tokens[1:]], self._vocab, oov_words)
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]

        if FLAGS.single_pass:
            self.write_for_rouge(ref_sents, words, file_no)
            file_no += 1
        else:
            print_results(shown_article, shown_abstract, ' '.join(words))
            self.write_for_attnvis(shown_article, shown_abstract, words, hyp.attn_dists, hyp.p_gens)
            waited = time.time() - reload_start
            if waited > SECS_UNTIL_NEW_CKPT:
                tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', waited)
                util.load_ckpt(self._saver, self._sess, FLAGS.decode_from)
                reload_start = time.time()
示例7: decode
# 需要导入模块: import data [as 别名]
# 或者: from data import outputids2words [as 别名]
def decode(self):
    """Decode the whole dataset, report average BLEU and write ROUGE files.

    Runs ``self.beam_search`` on each batch until ``self.batcher`` returns
    None. For every example the best hypothesis is mapped back to words,
    truncated at the first [STOP] symbol, scored with sentence BLEU
    (weights 0.5/0.5 over 1- and 2-grams) against the reference summary, and
    written with ``write_for_rouge``. Timing is printed every 1000 examples
    and the mean BLEU score at the end.
    """
    start = time.time()
    counter = 0
    bleu_scores = []
    batch = self.batcher.next_batch()
    while batch is not None:
        # Run beam search to get the best Hypothesis.
        best_summary = self.beam_search(batch)

        # Drop the first hypothesis token, then map ids (incl. copied OOVs) back to words.
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        art_oovs = batch.art_oovs[0] if config.pointer_gen else None
        decoded_words = data.outputids2words(output_ids, self.vocab, art_oovs)
        try:
            decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
        except ValueError:
            pass  # no [STOP]: keep the full sequence

        original_abstracts = batch.original_abstracts_sents[0]
        # Score against every reference sentence: the hypothesis is the whole
        # decoded summary, so comparing it with only the first sentence
        # under-counts matches for multi-sentence abstracts. For a
        # one-sentence abstract this is identical to the previous reference.
        reference = ' '.join(original_abstracts).split()
        bleu = nltk.translate.bleu_score.sentence_bleu([reference], decoded_words, weights=(0.5, 0.5))
        bleu_scores.append(bleu)

        write_for_rouge(original_abstracts, decoded_words, counter,
                        self._rouge_ref_dir, self._rouge_dec_dir)
        counter += 1
        if counter % 1000 == 0:
            print('%d example in %d sec'%(counter, time.time() - start))
            start = time.time()
        batch = self.batcher.next_batch()

    # np.mean([]) would print nan with a RuntimeWarning on an empty dataset.
    if bleu_scores:
        print('Average BLEU score:', np.mean(bleu_scores))

    # To also run ROUGE here (requires a working `pyrouge` install), add:
    #     print("Decoder has finished reading dataset for single_pass.")
    #     print("Now starting ROUGE eval...")
    #     results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
    #     rouge_log(results_dict, self._decode_dir)
示例8: decode
# 需要导入模块: import data [as 别名]
# 或者: from data import outputids2words [as 别名]
def decode(self):
    """Decode examples keyed by article id, writing ROUGE and visualisation data.

    Every example is printed with ``print_results`` and written with
    ``write_for_attnvis``. With FLAGS.single_pass each example is also written
    for ROUGE (file named by its article id) and its ``prepare_for_attnvis``
    result is collected; once the batcher returns None, ROUGE is evaluated and
    logged, the collected data is written with ``write_all_for_attnvis``, and
    the method returns. Without single_pass the loop never ends and the
    checkpoint is reloaded every SECS_UNTIL_NEW_CKPT seconds.
    """
    t0 = time.time()
    all_decoded = {}  # article_id -> prepare_for_attnvis(...) result, for the final dump
    while True:
        batch = self._batcher.next_batch()  # 1 example repeated across batch
        if batch is None:  # finished decoding dataset in single_pass mode
            assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
            tf.logging.info("Decoder has finished reading dataset for single_pass.")
            tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
            results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
            rouge_log(results_dict, self._decode_dir)
            # Kept even though the assert implies it: asserts vanish under `python -O`.
            if FLAGS.single_pass:
                self.write_all_for_attnvis(all_decoded)
            return

        original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings
        article_id = batch.article_ids[0]  # string
        art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
        article_withunks = data.show_art_oovs(batch.original_articles[0], self._vocab)
        abstract_withunks = data.show_abs_oovs(batch.original_abstracts[0], self._vocab, art_oovs)

        # Run beam search to get the best Hypothesis.
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

        # Drop the first hypothesis token, then map ids (incl. copied OOVs) back to words.
        output_ids = [int(t) for t in best_hyp.tokens[1:]]
        decoded_words = data.outputids2words(output_ids, self._vocab, art_oovs)
        try:
            decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
        except ValueError:
            pass  # no [STOP]: keep the full sequence
        decoded_output = ' '.join(decoded_words)

        attn_args = (article_withunks, abstract_withunks, decoded_words,
                     best_hyp.attn_dists, best_hyp.p_gens, best_hyp.attn_dists_sec)
        print_results(article_withunks, abstract_withunks, decoded_output, article_id)  # log output to screen
        if FLAGS.single_pass:
            # Write ref summary and decoded summary to file, to eval with pyrouge later.
            self.write_for_rouge(original_abstract_sents, decoded_words, article_id)
            all_decoded[article_id] = self.prepare_for_attnvis(*attn_args)
        self.write_for_attnvis(*attn_args)  # write info to .json file for visualization tool

        if not FLAGS.single_pass:
            # Periodically reload the latest checkpoint while decoding indefinitely.
            t1 = time.time()
            if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
                util.load_ckpt(self._saver, self._sess)
                t0 = time.time()
示例9: generator_whole_negative_example
# 需要导入模块: import data [as 别名]
# 或者: from data import outputids2words [as 别名]
def generator_whole_negative_example(self, max_batches=1000):
    """Decode training batches and save each generated review as a negative example.

    The first ``max_batches`` entries of ``self.batches`` are decoded. For
    each of the FLAGS.batch_size reviews per batch, the model output is read
    as FLAGS.max_dec_sen_num sentences. A sentence is dropped if it has fewer
    than two tokens, or if more than half of its distinct words appeared in
    the previously kept sentence. Kept sentences get a terminal '.' unless
    they already end in '.', '!' or '?'. The joined text is cut at the first
    ``data.STOP_DECODING_DOCUMENT``, then [UNK] tokens and runs of two or more
    "! " / ". " are deleted. The result and the original review are passed to
    ``write_negtive_to_json`` with the train sample directories.

    Args:
        max_batches: maximum number of batches to decode (default 1000, the
            previous hard-coded value). Fewer are used if ``self.batches`` is
            shorter, instead of raising IndexError.
    """
    counter = 0  # index of the written example
    for batch in self.batches[:max_batches]:
        decode_result = self._model.run_eval_given_step(self._sess, batch)
        for i in range(FLAGS.batch_size):
            original_review = batch.original_review_output[i]  # string
            kept_sents = []
            for j in range(FLAGS.max_dec_sen_num):
                # Indexing suggests generated[review][sentence][token] — confirm.
                output_ids = [int(t) for t in decode_result['generated'][i][j]][1:]
                decoded_words = data.outputids2words(output_ids, self._vocab, None)
                # Truncate at the first [STOP] symbol, if one was produced.
                try:
                    decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
                except ValueError:
                    pass
                if len(decoded_words) < 2:
                    continue
                # Skip near-repeats of the previously kept sentence.
                if kept_sents:
                    prev_words = set(kept_sents[-1].split())
                    cur_words = set(decoded_words)
                    if len(prev_words & cur_words) > 0.5 * len(cur_words):
                        continue
                if decoded_words[-1] not in ('.', '!', '?'):
                    decoded_words.append('.')
                kept_sents.append(' '.join(decoded_words).strip())

            review = ' '.join(kept_sents).strip()
            # Cut at the first document-level stop marker (substring search).
            try:
                review = review[:review.index(data.STOP_DECODING_DOCUMENT)]
            except ValueError:
                pass
            review = review.replace("[UNK] ", "").replace("[UNK]", "")
            # Runs of two or more "! " / ". " are deleted outright.
            review = re.sub(r"(! ){2,}", "", review)
            review = re.sub(r"(\. ){2,}", "", review)
            self.write_negtive_to_json(original_review, review, counter, self.train_sample_whole_positive_dir, self.train_sample_whole_negative_dir)
            counter += 1
示例10: generator_test_negative_example
# 需要导入模块: import data [as 别名]
# 或者: from data import outputids2words [as 别名]
def generator_test_negative_example(self, max_batches=100):
    """Decode test batches and save each generated review as a negative example.

    The first ``max_batches`` entries of ``self.test_batches`` are decoded,
    starting at index 0. Previously ``step`` was incremented before indexing,
    so batch 0 was skipped and ``test_batches[100]`` was read. For each of the
    FLAGS.batch_size reviews per batch, the model output is read as
    FLAGS.max_dec_sen_num sentences. A sentence is dropped if it has fewer
    than two tokens, or if more than half of its distinct words appeared in
    the previously kept sentence. Kept sentences get a terminal '.' unless
    they already end in '.', '!' or '?'. The joined text is cut at the first
    ``data.STOP_DECODING_DOCUMENT``, then [UNK] tokens and runs of two or more
    "! " / ". " are deleted. The result and the original review are passed to
    ``write_negtive_to_json`` with the test sample directories.

    Args:
        max_batches: maximum number of batches to decode (default 100, the
            previous hard-coded count). Fewer are used if ``self.test_batches``
            is shorter, instead of raising IndexError.
    """
    counter = 0  # index of the written example
    for batch in self.test_batches[:max_batches]:
        decode_result = self._model.run_eval_given_step(self._sess, batch)
        for i in range(FLAGS.batch_size):
            original_review = batch.original_review_output[i]  # string
            kept_sents = []
            for j in range(FLAGS.max_dec_sen_num):
                # Indexing suggests generated[review][sentence][token] — confirm.
                output_ids = [int(t) for t in decode_result['generated'][i][j]][1:]
                decoded_words = data.outputids2words(output_ids, self._vocab, None)
                # Truncate at the first [STOP] symbol, if one was produced.
                try:
                    decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
                except ValueError:
                    pass
                if len(decoded_words) < 2:
                    continue
                # Skip near-repeats of the previously kept sentence.
                if kept_sents:
                    prev_words = set(kept_sents[-1].split())
                    cur_words = set(decoded_words)
                    if len(prev_words & cur_words) > 0.5 * len(cur_words):
                        continue
                if decoded_words[-1] not in ('.', '!', '?'):
                    decoded_words.append('.')
                kept_sents.append(' '.join(decoded_words).strip())

            review = ' '.join(kept_sents).strip()
            # Cut at the first document-level stop marker (substring search).
            try:
                review = review[:review.index(data.STOP_DECODING_DOCUMENT)]
            except ValueError:
                pass
            review = review.replace("[UNK] ", "").replace("[UNK]", "")
            # Runs of two or more "! " / ". " are deleted outright.
            review = re.sub(r"(! ){2,}", "", review)
            review = re.sub(r"(\. ){2,}", "", review)
            self.write_negtive_to_json(original_review, review, counter, self.test_sample_whole_positive_dir, self.test_sample_whole_negative_dir)
            counter += 1
示例11: decode
# 需要导入模块: import data [as 别名]
# 或者: from data import outputids2words [as 别名]
def decode(self):
    """Run the beam-search decoder over the batcher's examples.

    With FLAGS.single_pass the loop ends when the batcher returns None: each
    reference/decoded pair is written for ROUGE (numbered from 0), then ROUGE
    is evaluated and logged before returning. Without single_pass the loop
    never ends: results are printed, written for the attention visualiser,
    and the checkpoint is reloaded every SECS_UNTIL_NEW_CKPT seconds.
    """
    ckpt_loaded_at = time.time()
    n_written = 0  # numbering of the ROUGE output files
    while True:
        batch = self._batcher.next_batch()  # one example repeated across the batch
        if batch is None:
            # Only single_pass decoding is allowed to run out of data.
            assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
            tf.logging.info("Decoder has finished reading dataset for single_pass.")
            tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
            rouge_log(rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir), self._decode_dir)
            return

        ref_sents = batch.original_abstracts_sents[0]
        copy_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
        article_withunks = data.show_art_oovs(batch.original_articles[0], self._vocab)
        abstract_withunks = data.show_abs_oovs(batch.original_abstracts[0], self._vocab, copy_oovs)

        hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

        # Drop the first hypothesis token, then map ids (incl. copied OOVs) back to words.
        ids = [int(tok) for tok in hyp.tokens[1:]]
        words = data.outputids2words(ids, self._vocab, copy_oovs)
        # Keep only what precedes the first [STOP] symbol, if any.
        stop_at = words.index(data.STOP_DECODING) if data.STOP_DECODING in words else len(words)
        words = words[:stop_at]

        if FLAGS.single_pass:
            self.write_for_rouge(ref_sents, words, n_written)
            n_written += 1
            continue

        print_results(article_withunks, abstract_withunks, ' '.join(words))
        self.write_for_attnvis(article_withunks, abstract_withunks, words, hyp.attn_dists, hyp.p_gens)
        since_load = time.time() - ckpt_loaded_at
        if since_load > SECS_UNTIL_NEW_CKPT:
            tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', since_load)
            util.load_ckpt(self._saver, self._sess)
            ckpt_loaded_at = time.time()