

Python data.STOP_DECODING Attribute Code Examples

This article collects typical usage examples of the data.STOP_DECODING attribute in Python. If you are wondering what data.STOP_DECODING does, how to use it, or what working examples of it look like, the curated examples below may help. You can also explore further usage examples from the data module, where this attribute is defined.


The following presents 15 code examples of the data.STOP_DECODING attribute, sorted by popularity by default.
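Before the individual examples, a quick orientation. In these summarization and paraphrase-generation codebases, data.STOP_DECODING is a special vocabulary token that the decoder is trained to emit when the output is complete; every example below looks for it in the decoded word list and truncates there. The following is a minimal sketch of that recurring pattern. It assumes STOP_DECODING is the literal string '[STOP]', as in the pointer-generator family of repositories, and the helper name truncate_at_stop is illustrative rather than taken from any of the projects cited below.

# Minimal sketch of the recurring [STOP] truncation idiom (assumptions noted above).
STOP_DECODING = '[STOP]'    # special token marking the end of decoding
START_DECODING = '[START]'  # special token fed to the decoder at step 0

def truncate_at_stop(decoded_words):
    """Return decoded_words up to, but excluding, the first [STOP] token."""
    try:
        fst_stop_idx = decoded_words.index(STOP_DECODING)
        return decoded_words[:fst_stop_idx]
    except ValueError:
        # no [STOP] was generated, e.g. the decoder hit its maximum number of steps
        return decoded_words

print(truncate_at_stop(['the', 'cat', 'sat', '[STOP]', 'the']))  # ['the', 'cat', 'sat']

The try/except form appears almost verbatim in the examples below: if the decoder reaches its maximum number of steps without emitting [STOP], list.index raises ValueError and the decoded words are kept unchanged.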

Example 1: decode

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def decode(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
    # t0 = time.time()
    batch = self._batcher.next_batch()  # 1 example repeated across batch

    original_article = batch.original_articles[0]  # string
    original_abstract = batch.original_abstracts[0]  # string

    # input data
    article_withunks = data.show_art_oovs(original_article, self._vocab) # string
    abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

    # Run beam search to get best Hypothesis
    best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

    # Extract the output ids from the hypothesis and convert back to words
    output_ids = [int(t) for t in best_hyp.tokens[1:]]
    decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

    # Remove the [STOP] token from decoded_words, if necessary
    try:
      fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
      decoded_words = decoded_words[:fst_stop_idx]
    except ValueError:
      decoded_words = decoded_words
    decoded_output = ' '.join(decoded_words) # single string

    # tf.logging.info('ARTICLE:  %s', article)
    #  tf.logging.info('GENERATED SUMMARY: %s', decoded_output)

    sys.stdout.write(decoded_output) 
Developer: IBM, Project: MAX-Text-Summarizer, Lines: 33, Source: decode.py

Example 2: compute_reward

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def compute_reward(batch, decode_batch, vocab, mode, use_cuda):
    target_sents = batch.original_abstracts  # list of string
    decode_batch = decode_batch.cpu().numpy()  # B x S x L
    output_ids = decode_batch[:, :, 1:]
    all_rewards = torch.zeros((config.batch_size, config.sample_size)) # B x S
    if use_cuda: all_rewards = all_rewards.cuda()
    for i in range(config.batch_size):
        for j in range(config.sample_size):
            words = data.outputids2words(list(output_ids[i,j,:]), vocab,
                                       (batch.art_oovs[i] if config.pointer_gen else None))
            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = words.index(data.STOP_DECODING)
                words = words[:fst_stop_idx]
            except ValueError:
                words = words
            decode_sent = ' '.join(words)
            all_rewards[i, j] = rouge_2(target_sents[i], decode_sent)
    batch_avg_reward = torch.mean(all_rewards, dim=1, keepdim=True)  # B x 1
    
    ones = torch.ones((config.batch_size, config.sample_size))
    if use_cuda: ones = ones.cuda()
    if mode == 'MLE':
        return ones, torch.zeros(1)
    else:
        batch_avg_reward = batch_avg_reward * ones   # B x S
        if torch.equal(all_rewards, batch_avg_reward):
            all_rewards = all_rewards
        else:
            all_rewards = all_rewards - batch_avg_reward
        
        for i in range(config.batch_size):
            for j in range(config.sample_size):
                if all_rewards[i,j] < 0:
                    all_rewards[i, j] = 0
        return all_rewards, batch_avg_reward.mean() 
Developer: wyu-du, Project: Reinforce-Paraphrase-Generation, Lines: 38, Source: train_util.py

Example 3: __init__

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def __init__(self, article, abstract_sentences, vocab):
    # Get ids of special tokens
    start_decoding = vocab.word2id(data.START_DECODING)
    stop_decoding = vocab.word2id(data.STOP_DECODING)

    # Process the article
    article_words = article.split()
    if len(article_words) > config.max_enc_steps:
      article_words = article_words[:config.max_enc_steps]
    self.enc_len = len(article_words) # store the length after truncation but before padding
    self.enc_input = [vocab.word2id(w) for w in article_words] # list of word ids; OOVs are represented by the id for UNK token
    
    # Process the abstract
    abstract = ' '.join(abstract_sentences)
    abstract_words = abstract.split() # list of strings
    abs_ids = [vocab.word2id(w) for w in abstract_words] # list of word ids; OOVs are represented by the id for UNK token
    
    # Get the decoder input sequence and target sequence
    self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, config.max_dec_steps, start_decoding, stop_decoding)
    self.dec_len = len(self.dec_input)

    # If using pointer-generator mode, we need to store some extra info
    if config.pointer_gen:
      # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOVs words themselves
      self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)

      # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
      abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs)

      # Overwrite decoder target sequence so it uses the temp article OOV ids
      # NOTE: dec_input does not contain article OOV ids!!!!
      _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, config.max_dec_steps, start_decoding, stop_decoding)

    # Store the original strings
    self.original_article = article
    self.original_abstract = abstract
    self.original_abstract_sents = abstract_sentences 
Developer: wyu-du, Project: Reinforce-Paraphrase-Generation, Lines: 39, Source: batcher.py
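Examples 3, 4, 9, and 11 call self.get_dec_inp_targ_seqs to turn the abstract's word ids, together with the start and stop ids, into the decoder input and target sequences; the helper itself is not reproduced on this page (Example 5 calls a sentence-level variant with a different signature). Below is a minimal sketch, assuming the pointer-generator convention: prepend the start id to the decoder input, append the stop id to the target, and truncate both to max_len. Treat it as context for reading the examples, not as the exact code of the cited projects.

# Sketch of the helper used by the __init__ examples (pointer-generator convention assumed).
def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
    inp = [start_id] + sequence[:]   # decoder input begins with the start id
    target = sequence[:]
    if len(inp) > max_len:
        inp = inp[:max_len]
        target = target[:max_len]    # truncated: no stop id appended
    else:
        target.append(stop_id)       # fits: target ends with the stop id
    assert len(inp) == len(target)
    return inp, target

The asymmetry is deliberate under this convention: a reference that had to be truncated does not get a stop id, so the model is not trained to end the output at an arbitrary cut-off point.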

Example 4: __init__

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def __init__(self, article, abstract_sentences, vocab):
    # Get ids of special tokens
    start_decoding = vocab.word2id(data.START_DECODING)
    stop_decoding = vocab.word2id(data.STOP_DECODING)

    # Process the article
    article_words = article.split()
    if len(article_words) > config.max_enc_steps:
      article_words = article_words[:config.max_enc_steps]
    self.enc_len = len(article_words) # store the length after truncation but before padding
    self.enc_input = [vocab.word2id(w) for w in article_words] # list of word ids; OOVs are represented by the id for UNK token

    # Process the abstract
    abstract = ' '.join(abstract_sentences) # string
    abstract_words = abstract.split() # list of strings
    abs_ids = [vocab.word2id(w) for w in abstract_words] # list of word ids; OOVs are represented by the id for UNK token

    # Get the decoder input sequence and target sequence
    self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, config.max_dec_steps, start_decoding, stop_decoding)
    self.dec_len = len(self.dec_input)

    # If using pointer-generator mode, we need to store some extra info
    if config.pointer_gen:
      # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOVs words themselves
      self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)

      # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
      abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs)

      # Overwrite decoder target sequence so it uses the temp article OOV ids
      _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, config.max_dec_steps, start_decoding, stop_decoding)

    # Store the original strings
    self.original_article = article
    self.original_abstract = abstract
    self.original_abstract_sents = abstract_sentences 
Developer: atulkum, Project: pointer_summarizer, Lines: 38, Source: batcher.py

Example 5: __init__

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def __init__(self, review, label, vocab, hps):

    start_decoding = vocab.word2id(data.START_DECODING)
    stop_decoding = vocab.word2id(data.STOP_DECODING)
    review_sentenc_orig = []


    self.hps = hps
    self.label = label

    #abstract_sentences = [x.strip() for x in abstract_sentences]
    article_sens = sent_tokenize(review)

    article_words = []
    for i in range(len(article_sens)):
        if i >= hps.max_enc_sen_num:
            article_words = article_words[:hps.max_enc_sen_num]
            review_sentenc_orig = review_sentenc_orig[:hps.max_enc_sen_num]
            break
        article_sen = article_sens[i]
        article_sen_words = article_sen.split()
        if len(article_sen_words) > hps.max_enc_seq_len:
            article_sen_words = article_sen_words[:hps.max_enc_seq_len]
        article_words.append(article_sen_words)
        review_sentenc_orig.append(article_sens[i])


    # Process the abstract
    #abstract = ' '.join(abstract_sentences)  # string
    # abstract_words = abstract.split() # list of strings
    abs_ids = [[vocab.word2id(w) for w in sen] for sen in
               article_words]  # list of word ids; OOVs are represented by the id for UNK token

    # Get the decoder input sequence and target sequence
    self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, hps.max_enc_sen_num, hps.max_enc_seq_len,
                                                              start_decoding,
                                                             stop_decoding)  # max_sen_num,max_len, start_doc_id, end_doc_id,start_id, stop_id
    self.dec_len = len(self.dec_input)
    self.dec_sen_len = [len(sentence) for sentence in self.target]

    self.original_reivew = review_sentenc_orig 
Developer: loretoparisi, Project: docker, Lines: 43, Source: batcher_discriminator.py

Example 6: process_one_article

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def process_one_article(self, original_article_sents, original_abstract_sents, \
                          original_selected_ids, output_ids, oovs, attn_dists_norescale, \
                          attn_dists, p_gens, log_probs, sent_probs, counter):
    # Remove the [STOP] token from decoded_words, if necessary
    decoded_words = data.outputids2words(output_ids, self._vocab, oovs)
    try:
      fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
      decoded_words = decoded_words[:fst_stop_idx]
    except ValueError:
      decoded_words = decoded_words
    decoded_output = ' '.join(decoded_words) # single string
    decoded_sents = data.words2sents(decoded_words)

    if FLAGS.single_pass:
      verbose = False if FLAGS.mode == 'eval' else True
      self.write_for_rouge(original_abstract_sents, decoded_sents, counter, verbose) # write ref summary and decoded summary to file, to eval with pyrouge later
      if FLAGS.decode_method == 'beam' and FLAGS.save_vis:
        sent_probs_per_word = []
        for sent_id, sent in enumerate(original_article_sents):
          sent_len = len(sent.split(' '))
          for _ in range(sent_len):
            if sent_id < FLAGS.max_art_len:
              sent_probs_per_word.append(sent_probs[sent_id])
            else:
              sent_probs_per_word.append(0)
        original_article = ' '.join(original_article_sents)
        original_abstract = ' '.join(original_abstract_sents)
        article_withunks = data.show_art_oovs(original_article, self._vocab) # string
        abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, oovs)
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, attn_dists_norescale, \
                               attn_dists, p_gens, log_probs, sent_probs_per_word, counter, verbose)
      if FLAGS.save_pkl:
        self.save_result(original_article_sents, original_abstract_sents, \
                         original_selected_ids, decoded_sents, counter, verbose) 
Developer: HsuWanTing, Project: unified-summarization, Lines: 36, Source: evaluate.py

Example 7: process_one_article

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def process_one_article(self, original_article_sents, original_abstract_sents, \
                          original_selected_ids, output_ids, oovs, \
                          attn_dists, p_gens, log_probs, counter):
    # Remove the [STOP] token from decoded_words, if necessary
    decoded_words = data.outputids2words(output_ids, self._vocab, oovs)
    try:
      fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
      decoded_words = decoded_words[:fst_stop_idx]
    except ValueError:
      decoded_words = decoded_words
    decoded_output = ' '.join(decoded_words) # single string
    decoded_sents = data.words2sents(decoded_words)

    if FLAGS.single_pass:
      verbose = False if FLAGS.mode == 'eval' else True
      self.write_for_rouge(original_abstract_sents, decoded_sents, counter, verbose) # write ref summary and decoded summary to file, to eval with pyrouge later
      if FLAGS.decode_method == 'beam' and FLAGS.save_vis:
        original_article = ' '.join(original_article_sents)
        original_abstract = ' '.join(original_abstract_sents)
        article_withunks = data.show_art_oovs(original_article, self._vocab) # string
        abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, oovs)
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, \
                               attn_dists, p_gens, log_probs, counter, verbose)
      if FLAGS.save_pkl:
        self.save_result(original_article_sents, original_abstract_sents, \
                         original_selected_ids, decoded_sents, counter, verbose) 
Developer: HsuWanTing, Project: unified-summarization, Lines: 28, Source: decode.py

Example 8: decode

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def decode(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
    t0 = time.time()
    counter = FLAGS.decode_after
    while True:
      tf.reset_default_graph()
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        return

      original_article = batch.original_articles[0]  # string
      original_abstract = batch.original_abstracts[0]  # string
      original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings
      if len(original_abstract_sents) == 0:
        print("NOOOOO!!!!, An empty abstract :(")
        continue

      article_withunks = data.show_art_oovs(original_article, self._vocab) # string
      abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

      # Run beam search to get best Hypothesis
      if FLAGS.ac_training:
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch, self._dqn, self._dqn_sess, self._dqn_graph)
      else:
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
      # Extract the output ids from the hypothesis and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

      # Remove the [STOP] token from decoded_words, if necessary
      try:
        fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
        decoded_words = decoded_words
      decoded_output = ' '.join(decoded_words) # single string

      if FLAGS.single_pass:
        self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
        counter += 1 # this is how many examples we've decoded
      else:
        print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

        # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
          _ = util.load_ckpt(self._saver, self._sess, FLAGS.decode_from)
          t0 = time.time() 
Developer: yaserkl, Project: TransferRL, Lines: 57, Source: decode.py

Example 9: __init__

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def __init__(self, article, abstract_sentences, vocab, hps):
    """Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self.

    Args:
      article: source text; a string. each token is separated by a single space.
      abstract_sentences: list of strings, one per abstract sentence. In each sentence, each token is separated by a single space.
      vocab: Vocabulary object
      hps: hyperparameters
    """
    self.hps = hps

    # Get ids of special tokens
    start_decoding = vocab.word2id(data.START_DECODING)
    stop_decoding = vocab.word2id(data.STOP_DECODING)

    # Process the article
    article_words = article.split()
    if len(article_words) > hps.max_enc_steps:
      article_words = article_words[:hps.max_enc_steps]
    self.enc_len = len(article_words) # store the length after truncation but before padding
    self.enc_input = [vocab.word2id(w) for w in article_words] # list of word ids; OOVs are represented by the id for UNK token

    # Process the abstract
    abstract = ' '.join(abstract_sentences) # string
    abstract_words = abstract.split() # list of strings
    abs_ids = [vocab.word2id(w) for w in abstract_words] # list of word ids; OOVs are represented by the id for UNK token

    # Get the decoder input sequence and target sequence
    self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, hps.max_dec_steps, start_decoding, stop_decoding)
    self.dec_len = len(self.dec_input)

    # If using pointer-generator mode, we need to store some extra info
    if hps.pointer_gen:
      # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOVs words themselves
      self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)

      # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
      abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs)

      # Overwrite decoder target sequence so it uses the temp article OOV ids
      _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, hps.max_dec_steps, start_decoding, stop_decoding)

    # Store the original strings
    self.original_article = article
    self.original_abstract = abstract
    self.original_abstract_sents = abstract_sentences 
Developer: yaserkl, Project: TransferRL, Lines: 48, Source: batcher.py

Example 10: decode

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def decode(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
    t0 = time.time()
    counter = FLAGS.decode_after
    while True:
      tf.reset_default_graph()
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        return

      original_article = batch.original_articles[0]  # string
      original_abstract = batch.original_abstracts[0]  # string
      original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

      article_withunks = data.show_art_oovs(original_article, self._vocab) # string
      abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

      # Run beam search to get best Hypothesis
      if FLAGS.ac_training:
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch, self._dqn, self._dqn_sess, self._dqn_graph)
      else:
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
      # Extract the output ids from the hypothesis and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

      # Remove the [STOP] token from decoded_words, if necessary
      try:
        fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
        decoded_words = decoded_words
      decoded_output = ' '.join(decoded_words) # single string

      if FLAGS.single_pass:
        self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
        counter += 1 # this is how many examples we've decoded
      else:
        print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

        # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
          _ = util.load_ckpt(self._saver, self._sess, FLAGS.decode_from)
          t0 = time.time() 
Developer: yaserkl, Project: RLSeq2Seq, Lines: 54, Source: decode.py

Example 11: __init__

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def __init__(self, article, abstract_sentences, vocab, hps):
        """Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self.

    Args:
      article: source text; a string. each token is separated by a single space.
      abstract_sentences: list of strings, one per abstract sentence. In each sentence, each token is separated by a single space.
      vocab: Vocabulary object
      hps: hyperparameters
    """
        self.hps = hps

        # Get ids of special tokens
        start_decoding = vocab.word2id(data.START_DECODING)
        stop_decoding = vocab.word2id(data.STOP_DECODING)

        # Process the article
        article_words = article.split()
        if len(article_words) > hps.max_enc_steps:
            article_words = article_words[:hps.max_enc_steps]
        self.enc_len = len(article_words)  # store the length after truncation but before padding
        self.enc_input = [vocab.word2id(w) for w in
                          article_words]  # list of word ids; OOVs are represented by the id for UNK token

        # Process the abstract
        abstract = ' '.join(abstract_sentences)  # string
        abstract_words = abstract.split()  # list of strings
        abs_ids = [vocab.word2id(w) for w in
                   abstract_words]  # list of word ids; OOVs are represented by the id for UNK token

        # Get the decoder input sequence and target sequence
        self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, hps.max_dec_steps, start_decoding,
                                                                 stop_decoding)
        self.dec_len = len(self.dec_input)

        # If using pointer-generator mode, we need to store some extra info
        if hps.pointer_gen:
            # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOVs words themselves
            self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)

            # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
            abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs)

            # Overwrite decoder target sequence so it uses the temp article OOV ids
            _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, hps.max_dec_steps, start_decoding,
                                                        stop_decoding)

        # Store the original strings
        self.original_article = article
        self.original_abstract = abstract
        self.original_abstract_sents = abstract_sentences 
Developer: IBM, Project: MAX-Text-Summarizer, Lines: 52, Source: batcher.py

Example 12: decode

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def decode(self):
        start = time.time()
        counter = 0
        bleu_scores = []
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self.vocab,
                                                 (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words

            original_abstracts = batch.original_abstracts_sents[0]
            reference = original_abstracts[0].strip().split()
            bleu = nltk.translate.bleu_score.sentence_bleu([reference], decoded_words, weights = (0.5, 0.5))
            bleu_scores.append(bleu)

            write_for_rouge(original_abstracts, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec'%(counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()
        
        print('Average BLEU score:', np.mean(bleu_scores))
        '''
        # uncomment this if you successfully install `pyrouge`
        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        ''' 
Developer: wyu-du, Project: Reinforce-Paraphrase-Generation, Lines: 45, Source: decode.py

Example 13: decode

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def decode(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
    t0 = time.time()
    counter = 0
    all_decoded = {} # a dictionary keeping the decoded files to be written for visualization
    while True:
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        if FLAGS.single_pass:
          self.write_all_for_attnvis(all_decoded)
        return


      original_article = batch.original_articles[0]  # string
      original_abstract = batch.original_abstracts[0]  # string
      original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings
      article_id = batch.article_ids[0] #string

      article_withunks = data.show_art_oovs(original_article, self._vocab) # string
      abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

      # Run beam search to get best Hypothesis
#       import pdb; pdb.set_trace()
      best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

      # Extract the output ids from the hypothesis and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

      # Remove the [STOP] token from decoded_words, if necessary
      try:
        fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
        decoded_words = decoded_words
      decoded_output = ' '.join(decoded_words) # single string

      if FLAGS.single_pass:
        self.write_for_rouge(original_abstract_sents, decoded_words, article_id) # write ref summary and decoded summary to file, to eval with pyrouge later
        print_results(article_withunks, abstract_withunks, decoded_output, article_id) # log output to screen
        all_decoded[article_id] = self.prepare_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens, best_hyp.attn_dists_sec)
        counter += 1 # this is how many examples we've decoded
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens, best_hyp.attn_dists_sec) # write info to .json file for visualization tool        
      else:
        print_results(article_withunks, abstract_withunks, decoded_output, article_id) # log output to screen
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens, best_hyp.attn_dists_sec) # write info to .json file for visualization tool

        # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
          _ = util.load_ckpt(self._saver, self._sess)
          t0 = time.time() 
Developer: armancohan, Project: long-summarization, Lines: 60, Source: decode.py

Example 14: generator_whole_negative_example

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def generator_whole_negative_example(self):

        counter = 0
        step = 0

        t0 = time.time()
        batches = self.batches

        while step < 1000:
            
            batch = batches[step]
            step += 1

            decode_result = self._model.run_eval_given_step(self._sess, batch)

            for i in range(FLAGS.batch_size):
                decoded_words_all = []
                original_review = batch.original_review_output[i]  # string

                for j in range(FLAGS.max_dec_sen_num):

                    output_ids = [int(t) for t in decode_result['generated'][i][j]][1:]
                    decoded_words = data.outputids2words(output_ids, self._vocab, None)
                    # Remove the [STOP] token from decoded_words, if necessary
                    try:
                        fst_stop_idx = decoded_words.index(data.STOP_DECODING)  # index of the (first) [STOP] symbol
                        decoded_words = decoded_words[:fst_stop_idx]
                    except ValueError:
                        decoded_words = decoded_words

                    if len(decoded_words)<2:
                        continue

                    if len(decoded_words_all)>0:
                        new_set1 =set(decoded_words_all[len(decoded_words_all)-1].split())
                        new_set2= set(decoded_words)
                        if len(new_set1 & new_set2) > 0.5 * len(new_set2):
                            continue
                    if decoded_words[-1] !='.' and decoded_words[-1] !='!' and decoded_words[-1] !='?':
                        decoded_words.append('.')
                    decoded_output = ' '.join(decoded_words).strip()  # single string
                    decoded_words_all.append(decoded_output)

                decoded_words_all = ' '.join(decoded_words_all).strip()
                try:
                    fst_stop_idx = decoded_words_all.index(
                        data.STOP_DECODING_DOCUMENT)  # index of the (first) [STOP] symbol
                    decoded_words_all = decoded_words_all[:fst_stop_idx]
                except ValueError:
                    decoded_words_all = decoded_words_all
                decoded_words_all = decoded_words_all.replace("[UNK] ", "")
                decoded_words_all = decoded_words_all.replace("[UNK]", "")
                decoded_words_all, _ = re.subn(r"(! ){2,}", "", decoded_words_all)
                decoded_words_all, _ = re.subn(r"(\. ){2,}", "", decoded_words_all)

                self.write_negtive_to_json(original_review, decoded_words_all, counter, self.train_sample_whole_positive_dir, self.train_sample_whole_negative_dir)

                counter += 1  # this is how many examples we've decoded 
Developer: loretoparisi, Project: docker, Lines: 60, Source: generated_sample.py

Example 15: generator_test_negative_example

# Required module: import data [as alias]
# Or: from data import STOP_DECODING [as alias]
def generator_test_negative_example(self):

        counter = 0
        step = 0

        t0 = time.time()
        batches = self.test_batches

        while step < 100:
            batch = batches[step]
            step += 1  # fetch the batch before advancing, so the first batch is not skipped and the index stays in range

            decode_result =self._model.run_eval_given_step(self._sess, batch)

            for i in range(FLAGS.batch_size):
                decoded_words_all = []
                original_review = batch.original_review_output[i]  # string

                for j in range(FLAGS.max_dec_sen_num):


                    output_ids = [int(t) for t in decode_result['generated'][i][j]][1:]
                    decoded_words = data.outputids2words(output_ids, self._vocab, None)
                    # Remove the [STOP] token from decoded_words, if necessary
                    try:
                        fst_stop_idx = decoded_words.index(data.STOP_DECODING)  # index of the (first) [STOP] symbol
                        decoded_words = decoded_words[:fst_stop_idx]
                    except ValueError:
                        decoded_words = decoded_words

                    if len(decoded_words)<2:
                        continue

                    if len(decoded_words_all)>0:
                        new_set1 =set(decoded_words_all[len(decoded_words_all)-1].split())
                        new_set2= set(decoded_words)
                        if len(new_set1 & new_set2) > 0.5 * len(new_set2):
                            continue
                    if decoded_words[-1] !='.' and decoded_words[-1] !='!' and decoded_words[-1] !='?':
                        decoded_words.append('.')
                    decoded_output = ' '.join(decoded_words).strip()  # single string
                    decoded_words_all.append(decoded_output)

                decoded_words_all = ' '.join(decoded_words_all).strip()
                try:
                    fst_stop_idx = decoded_words_all.index(
                        data.STOP_DECODING_DOCUMENT)  # index of the (first) [STOP] symbol
                    decoded_words_all = decoded_words_all[:fst_stop_idx]
                except ValueError:
                    decoded_words_all = decoded_words_all
                decoded_words_all = decoded_words_all.replace("[UNK] ", "")
                decoded_words_all = decoded_words_all.replace("[UNK]", "")
                decoded_words_all, _ = re.subn(r"(! ){2,}", "", decoded_words_all)
                decoded_words_all, _ = re.subn(r"(\. ){2,}", "", decoded_words_all)
                self.write_negtive_to_json(original_review, decoded_words_all, counter, self.test_sample_whole_positive_dir,self.test_sample_whole_negative_dir)

                counter += 1  # this is how many examples we've decoded 
Developer: loretoparisi, Project: docker, Lines: 59, Source: generated_sample.py


Note: The data.STOP_DECODING attribute examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are drawn from open-source projects contributed by their respective authors; copyright of the source code belongs to the original authors, and redistribution and use should follow each project's license. Do not reproduce without permission.