本文整理汇总了 Python 中 seq2seq_attention_model.HParams 的典型用法代码示例。如果您正苦于以下问题：Python seq2seq_attention_model.HParams 的具体用法是什么？该怎么用？有哪些使用示例？那么这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解 HParams 所在模块 seq2seq_attention_model 的其他用法示例。
在下文中一共展示了 seq2seq_attention_model.HParams 的 2 个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。
示例1: main
# 需要导入模块: import seq2seq_attention_model [as 别名]
# 或者: from seq2seq_attention_model import HParams [as 别名]
def main(unused_argv):
  """Build hyper-parameters, then train, evaluate or decode per FLAGS.mode.

  Args:
    unused_argv: Leftover command-line arguments; ignored.

  Raises:
    ValueError: If FLAGS.mode is not 'train', 'eval' or 'decode', or if a
      required special token fails the vocabulary check.
  """
  # Reject an unknown mode up front; previously it fell through every branch
  # silently after the vocab and Batcher had already been built.
  if FLAGS.mode not in ('train', 'eval', 'decode'):
    raise ValueError(
        'Unknown mode %r; expected train, eval or decode.' % FLAGS.mode)

  vocab = data.Vocab(FLAGS.vocab_path, 1000000)

  # Check for presence of required special tokens. Explicit raises replace
  # `assert`, which is stripped when Python runs with -O. The thresholds are
  # the original ones: UNKNOWN_TOKEN may have id 0, the others must be > 0.
  # NOTE(review): CheckVocab presumably returns None for an absent word --
  # confirm in data.py. The `is None` test turns that into a clear error
  # instead of a TypeError from `None > 0` on Python 3.
  required_tokens = (
      (data.PAD_TOKEN, False),
      (data.UNKNOWN_TOKEN, True),  # True: id 0 is accepted for this token.
      (data.SENTENCE_START, False),
      (data.SENTENCE_END, False),
  )
  for token, zero_ok in required_tokens:
    token_id = vocab.CheckVocab(token)
    ok = token_id is not None and (token_id >= 0 if zero_ok else token_id > 0)
    if not ok:
      raise ValueError('Vocab check failed for special token %r (id: %r) in %s'
                       % (token, token_id, FLAGS.vocab_path))

  batch_size = 4
  if FLAGS.mode == 'decode':
    # Decode mode sizes the batch by the beam width (presumably one row per
    # beam hypothesis -- confirm in BSDecoder).
    batch_size = FLAGS.beam_size
  hps = seq2seq_attention_model.HParams(
      mode=FLAGS.mode,  # train, eval, decode
      min_lr=0.01,  # min learning rate.
      lr=0.15,  # learning rate
      batch_size=batch_size,
      enc_layers=4,
      enc_timesteps=120,
      dec_timesteps=30,
      min_input_len=2,  # discard articles/summaries shorter than this
      num_hidden=256,  # for rnn cell
      emb_dim=128,  # If 0, don't use embedding
      max_grad_norm=2,
      num_softmax_samples=4096)  # If 0, no sampled softmax.

  batcher = batch_reader.Batcher(
      FLAGS.data_path, vocab, hps, FLAGS.article_key,
      FLAGS.abstract_key, FLAGS.max_article_sentences,
      FLAGS.max_abstract_sentences, bucketing=FLAGS.use_bucketing,
      truncate_input=FLAGS.truncate_input)
  tf.set_random_seed(FLAGS.random_seed)

  if hps.mode == 'train':
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        hps, vocab, num_gpus=FLAGS.num_gpus)
    _Train(model, batcher)
  elif hps.mode == 'eval':
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        hps, vocab, num_gpus=FLAGS.num_gpus)
    _Eval(model, batcher, vocab=vocab)
  else:  # 'decode' -- the only remaining mode after the check above.
    # Only need to restore the 1st step and reuse it since
    # we keep and feed in state for each step's output.
    decode_mdl_hps = hps._replace(dec_timesteps=1)
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)
    # The decoder receives the unmodified hps (dec_timesteps=30), as before.
    decoder = seq2seq_attention_decode.BSDecoder(model, batcher, hps, vocab)
    decoder.DecodeLoop()
示例2: main
# 需要导入模块: import seq2seq_attention_model [as 别名]
# 或者: from seq2seq_attention_model import HParams [as 别名]
def main(unused_argv):
  """Build hyper-parameters, then train, evaluate or decode per FLAGS.mode.

  Args:
    unused_argv: Leftover command-line arguments; ignored.

  Raises:
    ValueError: If FLAGS.mode is not 'train', 'eval' or 'decode', or if a
      required special token fails the vocabulary check.
  """
  # Reject an unknown mode up front; previously it fell through every branch
  # silently after the vocab and Batcher had already been built.
  if FLAGS.mode not in ('train', 'eval', 'decode'):
    raise ValueError(
        'Unknown mode %r; expected train, eval or decode.' % FLAGS.mode)

  vocab = data.Vocab(FLAGS.vocab_path, 1000000)

  # Check for presence of required special tokens. Explicit raises replace
  # `assert`, which is stripped when Python runs with -O. The thresholds are
  # the original ones: UNKNOWN_TOKEN may have id 0, the others must be > 0.
  # NOTE(review): if WordToId maps unknown words to the UNK id rather than
  # failing, this only checks id ranges, not true presence. Confirm in
  # data.py and switch to a lookup that reports absence if one exists.
  required_tokens = (
      (data.PAD_TOKEN, False),
      (data.UNKNOWN_TOKEN, True),  # True: id 0 is accepted for this token.
      (data.SENTENCE_START, False),
      (data.SENTENCE_END, False),
  )
  for token, zero_ok in required_tokens:
    token_id = vocab.WordToId(token)
    ok = token_id is not None and (token_id >= 0 if zero_ok else token_id > 0)
    if not ok:
      raise ValueError('Vocab check failed for special token %r (id: %r) in %s'
                       % (token, token_id, FLAGS.vocab_path))

  batch_size = 4
  if FLAGS.mode == 'decode':
    # Decode mode sizes the batch by the beam width (presumably one row per
    # beam hypothesis -- confirm in BSDecoder).
    batch_size = FLAGS.beam_size
  hps = seq2seq_attention_model.HParams(
      mode=FLAGS.mode,  # train, eval, decode
      min_lr=0.01,  # min learning rate.
      lr=0.15,  # learning rate
      batch_size=batch_size,
      enc_layers=4,
      enc_timesteps=120,
      dec_timesteps=30,
      min_input_len=2,  # discard articles/summaries shorter than this
      num_hidden=256,  # for rnn cell
      emb_dim=128,  # If 0, don't use embedding
      max_grad_norm=2,
      num_softmax_samples=4096)  # If 0, no sampled softmax.

  batcher = batch_reader.Batcher(
      FLAGS.data_path, vocab, hps, FLAGS.article_key,
      FLAGS.abstract_key, FLAGS.max_article_sentences,
      FLAGS.max_abstract_sentences, bucketing=FLAGS.use_bucketing,
      truncate_input=FLAGS.truncate_input)
  tf.set_random_seed(FLAGS.random_seed)

  if hps.mode == 'train':
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        hps, vocab, num_gpus=FLAGS.num_gpus)
    _Train(model, batcher)
  elif hps.mode == 'eval':
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        hps, vocab, num_gpus=FLAGS.num_gpus)
    _Eval(model, batcher, vocab=vocab)
  else:  # 'decode' -- the only remaining mode after the check above.
    # Only need to restore the 1st step and reuse it since
    # we keep and feed in state for each step's output.
    decode_mdl_hps = hps._replace(dec_timesteps=1)
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)
    # The decoder receives the unmodified hps (dec_timesteps=30), as before.
    decoder = seq2seq_attention_decode.BSDecoder(model, batcher, hps, vocab)
    decoder.DecodeLoop()