当前位置: 首页>>代码示例>>Python>>正文


Python cmp.setup_train_step_kwargs方法代码示例

本文整理汇总了Python中tfcode.cmp.setup_train_step_kwargs方法的典型用法代码示例。如果您正苦于以下问题:Python cmp.setup_train_step_kwargs方法的具体用法?Python cmp.setup_train_step_kwargs怎么用?Python cmp.setup_train_step_kwargs使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tfcode.cmp的用法示例。


在下文中一共展示了cmp.setup_train_step_kwargs方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_args_for_config

# 需要导入模块: from tfcode import cmp [as 别名]
# 或者: from tfcode.cmp import setup_train_step_kwargs [as 别名]
def get_args_for_config(config_name):
  configs = config_name.split('.')
  type = configs[0]
  config_name = '.'.join(configs[1:])
  if type == 'cmp':
    args = config_cmp.get_args_for_config(config_name)
    args.setup_to_run = cmp.setup_to_run
    args.setup_train_step_kwargs = cmp.setup_train_step_kwargs

  elif type == 'bl':
    args = config_vision_baseline.get_args_for_config(config_name)
    args.setup_to_run = vision_baseline_lstm.setup_to_run
    args.setup_train_step_kwargs = vision_baseline_lstm.setup_train_step_kwargs

  else:
    logging.fatal('Unknown type: {:s}'.format(type))
  return args 
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:19,代码来源:script_nav_agent_release.py

示例2: _train

# 需要导入模块: from tfcode import cmp [as 别名]
# 或者: from tfcode.cmp import setup_train_step_kwargs [as 别名]
def _train(args):
  container_name = ""

  R = lambda: nav_env.get_multiplexer_class(args.navtask, args.solver.task)
  m = utils.Foo()
  m.tf_graph = tf.Graph()

  config = tf.ConfigProto()
  config.device_count['GPU'] = 1

  with m.tf_graph.as_default():
    with tf.device(tf.train.replica_device_setter(args.solver.ps_tasks,
                                          merge_devices=True)):
      with tf.container(container_name):
        m = args.setup_to_run(m, args, is_training=True,
                             batch_norm_is_training=True, summary_mode='train')

        train_step_kwargs = args.setup_train_step_kwargs(
            m, R(), os.path.join(args.logdir, 'train'), rng_seed=args.solver.task,
            is_chief=args.solver.task==0,
            num_steps=args.navtask.task_params.num_steps*args.navtask.task_params.num_goals, iters=1,
            train_display_interval=args.summary.display_interval,
            dagger_sample_bn_false=args.arch.dagger_sample_bn_false)

        delay_start = (args.solver.task*(args.solver.task+1))/2 * FLAGS.delay_start_iters
        logging.error('delaying start for task %d by %d steps.',
                      args.solver.task, delay_start)

        additional_args = {}
        final_loss = slim.learning.train(
            train_op=m.train_op,
            logdir=args.logdir,
            master=args.solver.master,
            is_chief=args.solver.task == 0,
            number_of_steps=args.solver.max_steps,
            train_step_fn=tf_utils.train_step_custom_online_sampling,
            train_step_kwargs=train_step_kwargs,
            global_step=m.global_step_op,
            init_op=m.init_op,
            init_fn=m.init_fn,
            sync_optimizer=m.sync_optimizer,
            saver=m.saver_op,
            startup_delay_steps=delay_start,
            summary_op=None, session_config=config, **additional_args) 
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:46,代码来源:script_nav_agent_release.py


注:本文中的tfcode.cmp.setup_train_step_kwargs方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。