本文整理匯總了Python中skip_thoughts.configuration.training_config方法的典型用法代碼示例。如果您正苦於以下問題：Python configuration.training_config方法的具體用法？Python configuration.training_config怎麼用？Python configuration.training_config使用的例子？那麼，這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類skip_thoughts.configuration的用法示例。
在下文中一共展示了configuration.training_config方法的2個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: main
# 需要導入模塊: from skip_thoughts import configuration [as 別名]
# 或者: from skip_thoughts.configuration import training_config [as 別名]
def main(unused_argv):
    """Validate flags, build the skip-thoughts training graph, and train it.

    Args:
      unused_argv: Ignored. Presumably the leftover command-line arguments
        passed in by the application launcher -- TODO confirm against caller.

    Raises:
      ValueError: If ``FLAGS.input_file_pattern`` or ``FLAGS.train_dir`` is
        empty/unset.
    """
    # Fail fast on missing mandatory flags, before any graph work is done.
    if not FLAGS.input_file_pattern:
        raise ValueError("--input_file_pattern is required.")
    if not FLAGS.train_dir:
        raise ValueError("--train_dir is required.")

    model_cfg = configuration.model_config(
        input_file_pattern=FLAGS.input_file_pattern)
    train_cfg = configuration.training_config()

    tf.logging.info("Building training graph.")
    graph = tf.Graph()
    with graph.as_default():
        skip_model = skip_thoughts_model.SkipThoughtsModel(
            model_cfg, mode="train")
        skip_model.build()

        # Learning-rate schedule is computed first, then handed to Adam.
        adam = tf.train.AdamOptimizer(
            _setup_learning_rate(train_cfg, skip_model.global_step))

        train_op = tf.contrib.slim.learning.create_train_op(
            total_loss=skip_model.total_loss,
            optimizer=adam,
            global_step=skip_model.global_step,
            clip_gradient_norm=train_cfg.clip_gradient_norm)

        # NOTE(review): kept after create_train_op on purpose -- a Saver built
        # with no var_list presumably only captures variables that already
        # exist (e.g. optimizer slots); do not move it earlier.
        checkpoint_saver = tf.train.Saver()

    # The training loop runs outside the `with`; the graph is passed explicitly.
    tf.contrib.slim.learning.train(
        train_op=train_op,
        logdir=FLAGS.train_dir,
        graph=graph,
        global_step=skip_model.global_step,
        number_of_steps=train_cfg.number_of_steps,
        save_summaries_secs=train_cfg.save_summaries_secs,
        saver=checkpoint_saver,
        save_interval_secs=train_cfg.save_model_secs)
示例2: main
# 需要導入模塊: from skip_thoughts import configuration [as 別名]
# 或者: from skip_thoughts.configuration import training_config [as 別名]
def main(unused_argv):
    """Entry point: check required flags, build the training graph, train.

    Args:
      unused_argv: Not used. Presumably the remaining command-line arguments
        supplied by the launcher -- TODO confirm against caller.

    Raises:
      ValueError: When ``FLAGS.input_file_pattern`` or ``FLAGS.train_dir`` is
        not provided.
    """
    if not FLAGS.input_file_pattern:
        raise ValueError("--input_file_pattern is required.")
    if not FLAGS.train_dir:
        raise ValueError("--train_dir is required.")

    cfg_model = configuration.model_config(
        input_file_pattern=FLAGS.input_file_pattern)
    cfg_train = configuration.training_config()

    tf.logging.info("Building training graph.")
    # All model/optimizer ops are created inside this graph; the graph object
    # itself is handed to the slim training loop below.
    with tf.Graph().as_default() as train_graph:
        net = skip_thoughts_model.SkipThoughtsModel(cfg_model, mode="train")
        net.build()

        lr = _setup_learning_rate(cfg_train, net.global_step)
        adam_opt = tf.train.AdamOptimizer(lr)

        # Gradient clipping is configured via training_config.
        op_train = tf.contrib.slim.learning.create_train_op(
            total_loss=net.total_loss,
            optimizer=adam_opt,
            global_step=net.global_step,
            clip_gradient_norm=cfg_train.clip_gradient_norm)

        # NOTE(review): created last so it presumably sees every variable made
        # above, including optimizer state -- keep this ordering.
        ckpt_saver = tf.train.Saver()

    tf.contrib.slim.learning.train(
        train_op=op_train,
        logdir=FLAGS.train_dir,
        graph=train_graph,
        global_step=net.global_step,
        number_of_steps=cfg_train.number_of_steps,
        save_summaries_secs=cfg_train.save_summaries_secs,
        saver=ckpt_saver,
        save_interval_secs=cfg_train.save_model_secs)