本文整理汇总了Python中tfcode.tf_utils.train_step_custom_online_sampling方法的典型用法代码示例。如果您正苦于以下问题:Python tf_utils.train_step_custom_online_sampling方法的具体用法?Python tf_utils.train_step_custom_online_sampling怎么用?Python tf_utils.train_step_custom_online_sampling使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类tfcode.tf_utils
的用法示例。
在下文中一共展示了tf_utils.train_step_custom_online_sampling方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _train
# 需要导入模块: from tfcode import tf_utils [as 别名]
# 或者: from tfcode.tf_utils import train_step_custom_online_sampling [as 别名]
def _train(args):
container_name = ""
R = lambda: nav_env.get_multiplexer_class(args.navtask, args.solver.task)
m = utils.Foo()
m.tf_graph = tf.Graph()
config = tf.ConfigProto()
config.device_count['GPU'] = 1
with m.tf_graph.as_default():
with tf.device(tf.train.replica_device_setter(args.solver.ps_tasks,
merge_devices=True)):
with tf.container(container_name):
m = args.setup_to_run(m, args, is_training=True,
batch_norm_is_training=True, summary_mode='train')
train_step_kwargs = args.setup_train_step_kwargs(
m, R(), os.path.join(args.logdir, 'train'), rng_seed=args.solver.task,
is_chief=args.solver.task==0,
num_steps=args.navtask.task_params.num_steps*args.navtask.task_params.num_goals, iters=1,
train_display_interval=args.summary.display_interval,
dagger_sample_bn_false=args.arch.dagger_sample_bn_false)
delay_start = (args.solver.task*(args.solver.task+1))/2 * FLAGS.delay_start_iters
logging.error('delaying start for task %d by %d steps.',
args.solver.task, delay_start)
additional_args = {}
final_loss = slim.learning.train(
train_op=m.train_op,
logdir=args.logdir,
master=args.solver.master,
is_chief=args.solver.task == 0,
number_of_steps=args.solver.max_steps,
train_step_fn=tf_utils.train_step_custom_online_sampling,
train_step_kwargs=train_step_kwargs,
global_step=m.global_step_op,
init_op=m.init_op,
init_fn=m.init_fn,
sync_optimizer=m.sync_optimizer,
saver=m.saver_op,
startup_delay_steps=delay_start,
summary_op=None, session_config=config, **additional_args)