当前位置: 首页>>代码示例>>Python>>正文


Python tpu.TPUConfig方法代码示例

本文整理汇总了Python中tensorflow.contrib.tpu.TPUConfig方法的典型用法代码示例。如果您正苦于以下问题:Python tpu.TPUConfig方法的具体用法?Python tpu.TPUConfig怎么用?Python tpu.TPUConfig使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.contrib.tpu的用法示例。


在下文中一共展示了tpu.TPUConfig方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: make_tpu_run_config

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUConfig [as 别名]
def make_tpu_run_config(master, seed, model_dir, iterations_per_loop,
                        save_checkpoints_steps):
  return contrib_tpu.RunConfig(
      master=master,
      evaluation_master=master,
      model_dir=model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      cluster=None,
      tf_random_seed=seed,
      tpu_config=contrib_tpu.TPUConfig(iterations_per_loop=iterations_per_loop)) 
开发者ID:tensorflow,项目名称:kfac,代码行数:12,代码来源:classifier_mnist_tpu_estimator.py

示例2: main

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUConfig [as 别名]
def main(unused_argv):
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')

  tpu_cluster_resolver = (
      contrib_cluster_resolver.TPUClusterResolver(
          tpu=[FLAGS.tpu_name], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
  tpu_grpc_url = tpu_cluster_resolver.get_master()

  config = contrib_tpu.RunConfig(
      master=tpu_grpc_url,
      evaluation_master=tpu_grpc_url,
      model_dir=FLAGS.model_dir,
      tpu_config=contrib_tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_shards))

  kwargs = {}
  if FLAGS.train_batch_size:
    kwargs['batch_size'] = FLAGS.train_batch_size

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
      pipeline_config_path=FLAGS.pipeline_config_path,
      train_steps=FLAGS.num_train_steps,
      sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
      sample_1_of_n_eval_on_train_examples=(
          FLAGS.sample_1_of_n_eval_on_train_examples),
      use_tpu_estimator=True,
      use_tpu=FLAGS.use_tpu,
      num_shards=FLAGS.num_shards,
      save_final_config=FLAGS.mode == 'train',
      **kwargs)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fns = train_and_eval_dict['eval_input_fns']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  train_steps = train_and_eval_dict['train_steps']

  if FLAGS.mode == 'train':
    estimator.train(input_fn=train_input_fn, max_steps=train_steps)

  # Continuously evaluating.
  if FLAGS.mode == 'eval':
    if FLAGS.eval_training_data:
      name = 'training_data'
      input_fn = eval_on_train_input_fn
    else:
      name = 'validation_data'
      # Currently only a single eval input is allowed.
      input_fn = eval_input_fns[0]
    model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn, train_steps,
                              name, FLAGS.max_eval_retries) 
开发者ID:tensorflow,项目名称:models,代码行数:56,代码来源:model_tpu_main.py


注:本文中的tensorflow.contrib.tpu.TPUConfig方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。