本文整理匯總了Python中seq2seq_attention_decode.BSDecoder方法的典型用法代碼示例。如果您正苦於以下問題:Python seq2seq_attention_decode.BSDecoder方法的具體用法?Python seq2seq_attention_decode.BSDecoder怎麼用?Python seq2seq_attention_decode.BSDecoder使用的例子?那麼,這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊seq2seq_attention_decode的用法示例。
在下文中一共展示了seq2seq_attention_decode.BSDecoder方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: main
# Required import: import seq2seq_attention_decode [as alias]
# Or: from seq2seq_attention_decode import BSDecoder [as alias]
def main(unused_argv):
    """Build the vocab, hyperparameters and batcher, then run FLAGS.mode.

    Depending on ``FLAGS.mode`` this trains (``_Train``), evaluates
    (``_Eval``) or decodes (``seq2seq_attention_decode.BSDecoder``). Any
    other mode value falls through without running a model.

    Args:
        unused_argv: Positional command-line arguments; not read here.

    Raises:
        AssertionError: If one of the special-token checks on the vocab
            fails.
    """
    vocab = data.Vocab(FLAGS.vocab_path, 1000000)
    # Check for presence of required special tokens.
    assert vocab.CheckVocab(data.PAD_TOKEN) > 0
    assert vocab.CheckVocab(data.UNKNOWN_TOKEN) >= 0
    assert vocab.CheckVocab(data.SENTENCE_START) > 0
    assert vocab.CheckVocab(data.SENTENCE_END) > 0

    # In decode mode the batch size is taken from the beam size flag.
    batch_size = FLAGS.beam_size if FLAGS.mode == 'decode' else 4

    hps = seq2seq_attention_model.HParams(
        mode=FLAGS.mode,  # train, eval, decode
        min_lr=0.01,  # min learning rate.
        lr=0.15,  # learning rate
        batch_size=batch_size,
        enc_layers=4,
        enc_timesteps=120,
        dec_timesteps=30,
        min_input_len=2,  # discard articles/summaries < than this
        num_hidden=256,  # for rnn cell
        emb_dim=128,  # If 0, don't use embedding
        max_grad_norm=2,
        num_softmax_samples=4096)  # If 0, no sampled softmax.

    batcher = batch_reader.Batcher(
        FLAGS.data_path, vocab, hps, FLAGS.article_key,
        FLAGS.abstract_key, FLAGS.max_article_sentences,
        FLAGS.max_abstract_sentences, bucketing=FLAGS.use_bucketing,
        truncate_input=FLAGS.truncate_input)
    tf.set_random_seed(FLAGS.random_seed)

    if hps.mode == 'train':
        model = seq2seq_attention_model.Seq2SeqAttentionModel(
            hps, vocab, num_gpus=FLAGS.num_gpus)
        _Train(model, batcher)
    elif hps.mode == 'eval':
        model = seq2seq_attention_model.Seq2SeqAttentionModel(
            hps, vocab, num_gpus=FLAGS.num_gpus)
        _Eval(model, batcher, vocab=vocab)
    elif hps.mode == 'decode':
        # Only need to restore the 1st step and reuse it since
        # we keep and feed in state for each step's output.
        decode_mdl_hps = hps._replace(dec_timesteps=1)
        model = seq2seq_attention_model.Seq2SeqAttentionModel(
            decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)
        # The decoder receives the original hps, not the single-step copy
        # used to build the model.
        decoder = seq2seq_attention_decode.BSDecoder(
            model, batcher, hps, vocab)
        decoder.DecodeLoop()
示例2: main
# Required import: import seq2seq_attention_decode [as alias]
# Or: from seq2seq_attention_decode import BSDecoder [as alias]
def main(unused_argv):
    """Build the vocab, hyperparameters and batcher, then run FLAGS.mode.

    Depending on ``FLAGS.mode`` this trains (``_Train``), evaluates
    (``_Eval``) or decodes (``seq2seq_attention_decode.BSDecoder``). Any
    other mode value falls through without running a model.

    Args:
        unused_argv: Positional command-line arguments; not read here.

    Raises:
        AssertionError: If one of the special-token id checks on the vocab
            fails.
    """
    vocab = data.Vocab(FLAGS.vocab_path, 1000000)
    # Check for presence of required special tokens.
    assert vocab.WordToId(data.PAD_TOKEN) > 0
    assert vocab.WordToId(data.UNKNOWN_TOKEN) >= 0
    assert vocab.WordToId(data.SENTENCE_START) > 0
    assert vocab.WordToId(data.SENTENCE_END) > 0

    # In decode mode the batch size is taken from the beam size flag.
    batch_size = FLAGS.beam_size if FLAGS.mode == 'decode' else 4

    hps = seq2seq_attention_model.HParams(
        mode=FLAGS.mode,  # train, eval, decode
        min_lr=0.01,  # min learning rate.
        lr=0.15,  # learning rate
        batch_size=batch_size,
        enc_layers=4,
        enc_timesteps=120,
        dec_timesteps=30,
        min_input_len=2,  # discard articles/summaries < than this
        num_hidden=256,  # for rnn cell
        emb_dim=128,  # If 0, don't use embedding
        max_grad_norm=2,
        num_softmax_samples=4096)  # If 0, no sampled softmax.

    batcher = batch_reader.Batcher(
        FLAGS.data_path, vocab, hps, FLAGS.article_key,
        FLAGS.abstract_key, FLAGS.max_article_sentences,
        FLAGS.max_abstract_sentences, bucketing=FLAGS.use_bucketing,
        truncate_input=FLAGS.truncate_input)
    tf.set_random_seed(FLAGS.random_seed)

    if hps.mode == 'train':
        model = seq2seq_attention_model.Seq2SeqAttentionModel(
            hps, vocab, num_gpus=FLAGS.num_gpus)
        _Train(model, batcher)
    elif hps.mode == 'eval':
        model = seq2seq_attention_model.Seq2SeqAttentionModel(
            hps, vocab, num_gpus=FLAGS.num_gpus)
        _Eval(model, batcher, vocab=vocab)
    elif hps.mode == 'decode':
        # Only need to restore the 1st step and reuse it since
        # we keep and feed in state for each step's output.
        decode_mdl_hps = hps._replace(dec_timesteps=1)
        model = seq2seq_attention_model.Seq2SeqAttentionModel(
            decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)
        # The decoder receives the original hps, not the single-step copy
        # used to build the model.
        decoder = seq2seq_attention_decode.BSDecoder(
            model, batcher, hps, vocab)
        decoder.DecodeLoop()