当前位置: 首页>>代码示例>>Python>>正文


Python gin.parse_config_files_and_bindings方法代码示例

本文整理汇总了Python中gin.parse_config_files_and_bindings方法的典型用法代码示例。如果您想了解Python gin.parse_config_files_and_bindings方法的具体用法、调用方式或实际使用示例,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块gin的其他用法示例。


在下文中一共展示了gin.parse_config_files_and_bindings方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: t2t_train

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def t2t_train(model_name, dataset_name,
              data_dir=None, output_dir=None, config_file=None, config=None):
  """Trains a registered model on a dataset, configured through gin.

  The model class and dataset name are bound to `train_fn` first. The gin
  config file and the override string are parsed afterwards, so they can
  replace those bindings.

  Args:
    model_name: Name of a model registered in `_MODEL_REGISTRY`.
    dataset_name: The name of the dataset to train on.
    data_dir: Directory where the data is located.
    output_dir: Directory where to put the logs and checkpoints.
    config_file: The gin configuration file to use.
    config: String (in gin format) to override gin parameters.

  Raises:
    ValueError: If `model_name` is not present in `_MODEL_REGISTRY`.
  """
  if model_name not in _MODEL_REGISTRY:
    registered = "\n * ".join(_MODEL_REGISTRY.keys())
    error_msg = ("Model %s not in registry. Available models:\n * %s." %
                 (model_name, registered))
    raise ValueError(error_msg)

  # Registry entries are callables; calling one yields the model class.
  model_cls = _MODEL_REGISTRY[model_name]()
  gin.bind_parameter("train_fn.model_class", model_cls)
  gin.bind_parameter("train_fn.dataset", dataset_name)
  gin.parse_config_files_and_bindings(config_file, config)
  # TODO(lukaszkaiser): save gin config in output_dir if provided?
  train_fn(data_dir, output_dir=output_dir)
开发者ID:yyht,项目名称:BERT,代码行数:23,代码来源:t2t.py

示例2: _setup_gin

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def _setup_gin():
  """Setup gin configuration.

  Parses FLAGS.config_file together with a list of bindings. The list starts
  as a copy of FLAGS.config and then gets overrides derived from --dataset,
  --data_dir and --model. --data_dir is only turned into a binding when
  --dataset is also set.
  """
  # Imports for configurables: these modules are imported only so that their
  # gin configurables are available to the bindings parsed below.
  # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
  from tensor2tensor.trax import models as _trax_models
  from tensor2tensor.trax import optimizers as _trax_opt
  # pylint: enable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable

  # Copy the flag value. Appending to FLAGS.config directly would mutate the
  # flag itself, so calling this function twice would duplicate the overrides.
  configs = list(FLAGS.config or [])
  # Override with --dataset and --model. `%r` emits a valid Python string
  # literal, so values containing quotes or backslashes (e.g. Windows paths)
  # reach gin correctly escaped.
  if FLAGS.dataset:
    configs.append("inputs.dataset_name=%r" % FLAGS.dataset)
    if FLAGS.data_dir:
      configs.append("inputs.data_dir=%r" % FLAGS.data_dir)
  if FLAGS.model:
    configs.append("train.model=@trax.models.%s" % FLAGS.model)
  gin.parse_config_files_and_bindings(FLAGS.config_file, configs)
开发者ID:yyht,项目名称:BERT,代码行数:19,代码来源:trainer.py

示例3: _gin_parse_configs

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def _gin_parse_configs():
  """Initializes gin-controlled bindings.

  Parses FLAGS.config_file together with a list of bindings. The list starts
  as a copy of FLAGS.config and then gets overrides derived from --dataset,
  --data_dir and --model.
  """
  # Imports for configurables: these modules are imported only so that their
  # gin configurables are available to the bindings parsed below.
  # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
  from trax import models as _trax_models
  from trax import optimizers as _trax_opt
  # pylint: enable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable

  # Copy the flag value. Appending to FLAGS.config directly would mutate the
  # flag itself, so calling this function twice would duplicate the overrides.
  configs = list(FLAGS.config or [])
  # Override with --dataset and --model. `%r` emits a valid Python string
  # literal, so values containing quotes or backslashes (e.g. Windows paths)
  # reach gin correctly escaped.
  if FLAGS.dataset:
    configs.append('data_streams.dataset_name=%r' % FLAGS.dataset)
  if FLAGS.data_dir:
    configs.append('data_streams.data_dir=%r' % FLAGS.data_dir)
  if FLAGS.model:
    configs.append('train.model=@trax.models.%s' % FLAGS.model)
  gin.parse_config_files_and_bindings(FLAGS.config_file, configs)
开发者ID:google,项目名称:trax,代码行数:19,代码来源:trainer.py

示例4: main

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(argv):
  """Entry point for RL training.

  Parses gin configuration from --config_file and the --config bindings, and
  logs those bindings. It then runs `train_rl` with the batch-size and
  directory flags, and finally dumps the stack traces of all threads.
  """
  del argv
  logging.info('Starting RL training.')

  gin_configs = FLAGS.config or []
  gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs)

  logging.info('Gin config:')
  logging.info(gin_configs)

  train_rl(
      output_dir=FLAGS.output_dir,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      # An empty flag value is normalized to None.
      trajectory_dump_dir=(FLAGS.trajectory_dump_dir or None),
  )

  # TODO(afrozm): This is for debugging.
  logging.info('Dumping stack traces of all stacks.')
  faulthandler.dump_traceback(all_threads=True)

  logging.info('Training is done, should exit.')
开发者ID:google,项目名称:trax,代码行数:24,代码来源:rl_trainer.py

示例5: test_compress_image

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def test_compress_image(self):
    """Round-trips a uint8 image through the example serializer/decoder.

    Gin enables `compress_image` for both `_get_feature_encoder` and
    `_get_feature_parser`; the decoded image must match the original sample.
    Skipped unless eager execution has been enabled.
    """
    if not common.has_eager_been_enabled():
      self.skipTest("Image compression only supported in TF2.x")

    gin.parse_config_files_and_bindings([], """
    _get_feature_encoder.compress_image=True
    _get_feature_parser.compress_image=True
    """)
    image_shape = (128, 128, 3)
    spec = {"image": array_spec.ArraySpec(image_shape, np.uint8)}
    serializer = example_encoding.get_example_serializer(spec)
    decoder = example_encoding.get_example_decoder(spec)

    # Constant image: every pixel/channel value is 128.
    sample = {"image": 128 * np.ones(image_shape, dtype=np.uint8)}
    encoded = serializer(sample)

    recovered = self.evaluate(decoder(encoded))
    tf.nest.map_structure(np.testing.assert_almost_equal, sample, recovered)
开发者ID:tensorflow,项目名称:agents,代码行数:23,代码来源:example_encoding_test.py

示例6: main

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(unused_argv):
  """Parses gin configuration, resolves TPU usage and runs the schedule.

  If --use_tpu is unset (None), it is set from whether the TPU_NAME
  environment variable is non-empty.
  """
  logging.info("Gin config: %s\nGin bindings: %s",
               FLAGS.gin_config, FLAGS.gin_bindings)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

  if FLAGS.use_tpu is None:
    # Auto-detect: a non-empty TPU_NAME environment variable enables TPU use.
    FLAGS.use_tpu = bool(os.environ.get("TPU_NAME", ""))
    if FLAGS.use_tpu:
      logging.info("Found TPU %s.", os.environ["TPU_NAME"])
  run_config = _get_run_config()
  task_manager = _get_task_manager()
  options = runner_lib.get_options_dict()
  runner_lib.run_with_schedule(
      schedule=FLAGS.schedule,
      run_config=run_config,
      task_manager=task_manager,
      options=options,
      use_tpu=FLAGS.use_tpu,
      num_eval_averaging_runs=FLAGS.num_eval_averaging_runs,
      eval_every_steps=FLAGS.eval_every_steps)
  logging.info("I'm done with my work, ciao!")
开发者ID:google,项目名称:compare_gan,代码行数:24,代码来源:main.py

示例7: main

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(_):
  """Builds the Active Vision Dataset env and runs the tool chosen by --mode.

  Raises:
    ValueError: If --mode is not one of BENCHMARK_MODE, GRAPH_MODE,
      HUMAN_MODE, VIS_MODE or EVAL_MODE.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  print('********')
  print(FLAGS.mode)
  print(FLAGS.gin_config)
  print(FLAGS.gin_params)

  # Fail fast before constructing the (potentially expensive) environment;
  # previously an unknown mode silently did nothing.
  known_modes = (BENCHMARK_MODE, GRAPH_MODE, HUMAN_MODE, VIS_MODE, EVAL_MODE)
  if FLAGS.mode not in known_modes:
    raise ValueError('Unknown mode: %s' % FLAGS.mode)

  env = active_vision_dataset_env.ActiveVisionDatasetEnv(modality_types=[
      task_env.ModalityTypes.IMAGE,
      task_env.ModalityTypes.SEMANTIC_SEGMENTATION,
      task_env.ModalityTypes.OBJECT_DETECTION, task_env.ModalityTypes.DEPTH,
      task_env.ModalityTypes.DISTANCE
  ])

  if FLAGS.mode == BENCHMARK_MODE:
    benchmark(env, env.possible_targets)
  elif FLAGS.mode == GRAPH_MODE:
    for loc in env.worlds:
      env.check_scene_graph(loc, 'fridge')
  elif FLAGS.mode == HUMAN_MODE:
    human(env, env.possible_targets)
  elif FLAGS.mode == VIS_MODE:
    visualize_random_step_sequence(env)
  elif FLAGS.mode == EVAL_MODE:
    evaluate_folder(env, FLAGS.eval_folder)
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:27,代码来源:viz_active_vision_dataset_main.py

示例8: evaluate_metrics

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def evaluate_metrics():
  """Evaluates the gin-configured metrics for every (algo, task) pair in `p`.

  `p.gin_file` is parsed first. `eval_metrics.Evaluator` is then constructed
  with `metrics=gin.REQUIRED`, so that gin file must supply the metrics.
  """
  gin.parse_config_files_and_bindings([p.gin_file], [])

  for algo in p.algos:
    for task in p.tasks:
      # Summary directories of the individual training runs for this pair.
      run_dirs = eval_metrics.get_run_dirs(
          os.path.join(p.data_dir, algo, task), 'train', p.runs)

      # Output prefix for this pair; the trailing '/' makes it a directory.
      prefix = os.path.join(p.metric_values_dir, algo, task) + '/'
      evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
      evaluator.write_metric_params(prefix)
      evaluator.evaluate(run_dirs=run_dirs, outfile_prefix=prefix)
开发者ID:google-research,项目名称:rl-reliability-metrics,代码行数:18,代码来源:evaluate_metrics.py

示例9: override_gin

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def override_gin(self, bindings):
    """Parses the given gin `bindings` without loading any config files."""
    gin.parse_config_files_and_bindings(config_files=None, bindings=bindings)
开发者ID:yyht,项目名称:BERT,代码行数:4,代码来源:backend_test.py

示例10: main

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(unused_argv):
  """Applies gin configs/bindings from flags, then trains and evaluates."""
  del unused_argv  # Unused.
  gin.parse_config_files_and_bindings(FLAGS.gin_configs, FLAGS.gin_bindings)
  train_eval.train_eval_model()
开发者ID:google-research,项目名称:tensor2robot,代码行数:5,代码来源:run_t2r_trainer.py

示例11: main

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(unused_argv):
  """Applies gin configs/bindings from flags, then runs the collect/eval loop."""
  del unused_argv
  config_files, bindings = FLAGS.gin_configs, FLAGS.gin_bindings
  gin.parse_config_files_and_bindings(config_files, bindings)
  continuous_collect_eval.collect_eval_loop(root_dir=FLAGS.root_dir)
开发者ID:google-research,项目名称:tensor2robot,代码行数:6,代码来源:run_collect_eval.py

示例12: parse_gin_defaults_and_flags

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def parse_gin_defaults_and_flags():
  """Parses the packaged default gin config, then gin files/params from flags.

  Search paths from --gin_location_prefix are registered with gin first. The
  default config file (`_DEFAULT_CONFIG_FILE`, resolved relative to this
  module) is parsed before the flag-provided files and parameters, so the
  latter override the defaults.
  """
  for search_path in FLAGS.gin_location_prefix:
    gin.add_config_file_search_path(search_path)

  default_config_path = pkg_resources.resource_filename(
      __name__, _DEFAULT_CONFIG_FILE)
  gin.parse_config_file(default_config_path)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)


# TODO(noam): Consider making mtf.get_variable gin-configurable so that the
#  VariableDtype class can be removed instead of being passed around everywhere.
开发者ID:tensorflow,项目名称:mesh,代码行数:16,代码来源:utils.py

示例13: main

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(_):
  """Script entry point.

  Sets INFO log verbosity and enables TF2 behaviour. It then applies the gin
  files/params given by flags and calls `train_eval`.
  """
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_v2_behavior()
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  root_dir, iterations = FLAGS.root_dir, FLAGS.num_iterations
  train_eval(root_dir, num_iterations=iterations)
开发者ID:tensorflow,项目名称:agents,代码行数:7,代码来源:train_eval.py

示例14: main

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(_):
  """Script entry point for the RNN train/eval run.

  Enables TF2 behaviour and sets INFO log verbosity. It then applies the gin
  files/params given by flags and calls `train_eval`.
  """
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  root_dir, iterations = FLAGS.root_dir, FLAGS.num_iterations
  train_eval(root_dir, num_iterations=iterations)
开发者ID:tensorflow,项目名称:agents,代码行数:7,代码来源:train_eval_rnn.py

示例15: main

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(_):
  """Script entry point.

  Enables resource variables and sets INFO log verbosity. It then applies the
  gin files/params given by flags and calls `train_eval` on --root_dir.
  """
  tf.compat.v1.enable_resource_variables()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  root_dir = FLAGS.root_dir
  train_eval(root_dir)
开发者ID:tensorflow,项目名称:agents,代码行数:7,代码来源:train_eval.py


注:本文中的gin.parse_config_files_and_bindings方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。