当前位置: 首页>>代码示例>>Python>>正文


Python gin.parse_config_file方法代码示例

本文整理汇总了 Python 中 gin.parse_config_file 方法的典型用法代码示例。如果您想了解 gin.parse_config_file 方法的具体用法、调用方式或实际使用示例,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 gin 的其他用法示例。


在下文中一共展示了gin.parse_config_file方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def __init__(self, data_path, mode, train_set, validation_set, test_set, max_way_train, max_way_test, max_support_train, max_support_test,
             config_path='./meta_dataset_config.gin'):
        """Parse the gin config and build the episode pipelines for `mode`.

        Args:
          data_path: stored unchanged on ``self.data_path`` (presumably the
            root of the dataset records -- confirm against callers).
          mode: 'train', 'test' or 'train_test'.
            - 'train' builds the multi-source training pipeline and one
              validation pipeline per entry of ``validation_set``.
            - 'test' builds one test pipeline per entry of ``test_set``.
            - 'train_test' does both.
            - Any other value builds no pipelines; only the gin file is parsed.
          train_set: dataset names combined into the training pipeline.
          validation_set: dataset names, each given its own validation pipeline.
          test_set: dataset names, each given its own test pipeline.
          max_way_train, max_support_train: limits forwarded to
            ``_get_train_episode_description``.
          max_way_test, max_support_test: limits forwarded to
            ``_get_test_episode_description`` (shared by validation and test).
          config_path: gin file parsed before any pipeline is built. Defaults
            to './meta_dataset_config.gin', the previously hard-coded path.
        """
        self.data_path = data_path
        self.train_dataset_next_task = None
        self.validation_set_dict = {}
        self.test_set_dict = {}
        # Parsed before any pipeline is built; presumably the _init_* helpers
        # read gin-bound parameters.
        gin.parse_config_file(config_path)

        if mode in ('train', 'train_test'):
            train_episode_description = self._get_train_episode_description(max_way_train, max_support_train)
            self.train_dataset_next_task = self._init_multi_source_dataset(train_set, learning_spec.Split.TRAIN,
                                                                           train_episode_description)

            test_episode_description = self._get_test_episode_description(max_way_test, max_support_test)
            for item in validation_set:
                # NOTE: the chained assignment keeps only the LAST validation
                # pipeline on self.validation_dataset. It is retained for
                # backward compatibility; prefer self.validation_set_dict.
                next_task = self.validation_dataset = self._init_single_source_dataset(item, learning_spec.Split.VALID,
                                                                                       test_episode_description)
                self.validation_set_dict[item] = next_task

        if mode in ('test', 'train_test'):
            test_episode_description = self._get_test_episode_description(max_way_test, max_support_test)
            for item in test_set:
                next_task = self._init_single_source_dataset(item, learning_spec.Split.TEST, test_episode_description)
                self.test_set_dict[item] = next_task
开发者ID:cambridge-mlg,项目名称:cnaps,代码行数:26,代码来源:meta_dataset_reader.py

示例2: test_reformer_wmt_ende

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def test_reformer_wmt_ende(self):
    """Trains a shrunken Reformer from 'reformer_wmt_ende.gin' for one step.

    JIT is disabled. After the gin file is parsed, the data directory, batch
    size, step count and model size are overridden, and trainer_lib.train
    runs into a temporary output directory.
    """
    trax.fastmath.disable_jit()

    gin.parse_config_file('reformer_wmt_ende.gin')

    layer_count = 2
    # (gin binding, value) pairs, applied in this exact order.
    overrides = (
        ('data_streams.data_dir', _TESTDATA),
        ('batcher.batch_size_per_device', 2),
        ('train.steps', 1),
        ('Reformer.n_encoder_layers', layer_count),
        ('Reformer.n_decoder_layers', layer_count),
        ('Reformer.d_ff', 32),
    )
    for binding_key, binding_value in overrides:
        gin.bind_parameter(binding_key, binding_value)

    with self.tmp_dir() as output_dir:
        trainer_lib.train(output_dir=output_dir)
开发者ID:google,项目名称:trax,代码行数:21,代码来源:reformer_e2e_test.py

示例3: test_reformer_noencdecattn_wmt_ende

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def test_reformer_noencdecattn_wmt_ende(self):
    """Trains a shrunken ReformerNoEncDecAttention for a single step.

    JIT is disabled. The model comes from
    'reformer_noencdecattn_wmt_ende.gin', with data, batching, step count and
    model size overridden through gin bindings before training into a
    temporary directory.
    """
    trax.fastmath.disable_jit()

    gin.parse_config_file('reformer_noencdecattn_wmt_ende.gin')

    layer_count = 2
    overrides = (
        ('data_streams.data_dir', _TESTDATA),
        # Ignored (the buckets below fix the batch size), but must be bound.
        ('batcher.batch_size_per_device', 1),
        # Single bucket boundary 513 with batch size 1 on both sides.
        ('batcher.buckets', ([513], [1, 1])),
        ('train.steps', 1),
        ('ReformerNoEncDecAttention.n_encoder_layers', layer_count),
        ('ReformerNoEncDecAttention.n_decoder_layers', layer_count),
        ('ReformerNoEncDecAttention.d_ff', 32),
    )
    for binding_key, binding_value in overrides:
        gin.bind_parameter(binding_key, binding_value)

    with self.tmp_dir() as output_dir:
        trainer_lib.train(output_dir=output_dir)
开发者ID:google,项目名称:trax,代码行数:22,代码来源:reformer_e2e_test.py

示例4: test_run_pose_env_collect

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def test_run_pose_env_collect(self, demo_policy_cls):
    """Runs collect_eval_loop with `demo_policy_cls` and checks its output.

    Loads run_random_collect.gin, points the loop at a per-policy temporary
    directory, and asserts that exactly two '*.tfrecord' files appear under
    '<root_dir>/policy_collect'.
    """
    urdf_root = pose_env.get_pybullet_urdf_root()

    gin.parse_config_file(os.path.join(
        FLAGS.test_srcdir, 'research/pose_env/configs', 'run_random_collect.gin'))

    # One output directory per policy class so parameterized runs don't mix.
    root_dir = os.path.join(
        absltest.get_default_test_tmpdir(), str(demo_policy_cls))

    bindings = [
        ('PoseToyEnv.urdf_root', urdf_root),
        ('collect_eval_loop.root_dir', root_dir),
        ('run_meta_env.num_tasks', 2),
        ('run_meta_env.num_episodes_per_adaptation', 1),
        ('collect_eval_loop.policy_class', demo_policy_cls),
    ]
    for binding_key, binding_value in bindings:
        gin.bind_parameter(binding_key, binding_value)

    continuous_collect_eval.collect_eval_loop()

    record_pattern = os.path.join(root_dir, 'policy_collect', '*.tfrecord')
    self.assertLen(tf.io.gfile.glob(record_pattern), 2)
开发者ID:google-research,项目名称:tensor2robot,代码行数:22,代码来源:continuous_collect_eval_test.py

示例5: setUp

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def setUp(self):
    """Resets gin, loads the eval-metrics test config, and lists fake runs."""
    super(EvalMetricsTest, self).setUp()

    # Clear first so bindings parsed by earlier tests cannot leak in.
    gin.clear_config()
    gin.parse_config_file(os.path.join(
        './', 'rl_reliability_metrics/evaluation', 'eval_metrics_test.gin'))

    # Fake training curves used as fixtures by the analysis tests.
    self.test_data_dir = os.path.join(
        './', 'rl_reliability_metrics/evaluation/test_data')
    run_dirs = []
    for run_index in range(3):
        run_dirs.append(
            os.path.join(self.test_data_dir, f'run{run_index}', 'train'))
    self.run_dirs = run_dirs
开发者ID:google-research,项目名称:rl-reliability-metrics,代码行数:19,代码来源:eval_metrics_test.py

示例6: setUp

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def setUp(self):
    """Resets gin, loads the shared eval-metrics test config, lists fake runs."""
    super(DataLoadingTest, self).setUp()

    # Clear first so bindings parsed by earlier tests cannot leak in.
    gin.clear_config()
    gin.parse_config_file(os.path.join(
        './', 'rl_reliability_metrics/evaluation', 'eval_metrics_test.gin'))

    # Fake training curves used as fixtures by the loading tests.
    fixture_root = os.path.join(
        './', 'rl_reliability_metrics/evaluation/test_data')
    run_dirs = []
    for run_index in range(3):
        run_dirs.append(os.path.join(fixture_root, f'run{run_index}', 'train'))
    self.run_dirs = run_dirs
开发者ID:google-research,项目名称:rl-reliability-metrics,代码行数:19,代码来源:data_loading_test.py

示例7: main

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def main(argv):
  """Runs the experiment configured by FLAGS.gin_config_path and saves results.

  The gin file is parsed, a runner_lib.Runner is built and run, the results
  are logged, and their JSON form (core.to_json) is written to
  FLAGS.output_path.

  Args:
    argv: command-line arguments; only argv[0] (the program name) is allowed.

  Raises:
    app.UsageError: if extra positional arguments were passed.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  gin.parse_config_file(FLAGS.gin_config_path)
  runner = runner_lib.Runner()

  results = runner.run()
  logging.info('Results: %s', results)

  # Explicit encoding so the written JSON does not depend on the platform's
  # default locale encoding.
  with open(FLAGS.output_path, 'w', encoding='utf-8') as f:
    f.write(core.to_json(results))
开发者ID:google,项目名称:ml-fairness-gym,代码行数:14,代码来源:runner.py

示例8: parse_gin_defaults_and_flags

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def parse_gin_defaults_and_flags():
  """Parses the packaged default gin file, then user files/bindings from flags.

  Every prefix in FLAGS.gin_location_prefix is registered as a gin search
  path first. Next, the default config bundled with this package is parsed.
  Finally, FLAGS.gin_file / FLAGS.gin_param are parsed last, so user-supplied
  values override the defaults.
  """
  search_paths = FLAGS.gin_location_prefix
  for search_path in search_paths:
    gin.add_config_file_search_path(search_path)

  default_config = pkg_resources.resource_filename(
      __name__, _DEFAULT_CONFIG_FILE)
  gin.parse_config_file(default_config)

  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)


# TODO(noam): maybe add gin-config to mtf.get_variable so we can delete
#  this stupid VariableDtype class and stop passing it all over creation. 
开发者ID:tensorflow,项目名称:mesh,代码行数:16,代码来源:utils.py

示例9: testGinConfig

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def testGinConfig(self):
    """After parsing suite_pybullet.gin, load() yields a TimeLimit PyEnvironment."""
    config_path = test_utils.test_src_dir_path(
        'environments/configs/suite_pybullet.gin')
    gin.parse_config_file(config_path)

    loaded_env = suite_pybullet.load()
    for expected_type in (py_environment.PyEnvironment, wrappers.TimeLimit):
        self.assertIsInstance(loaded_env, expected_type)
开发者ID:tensorflow,项目名称:agents,代码行数:9,代码来源:suite_pybullet_test.py

示例10: testGinConfig

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def testGinConfig(self):
    """After parsing suite_gym.gin, load() yields a TimeLimit PyEnvironment."""
    config_path = test_utils.test_src_dir_path(
        'environments/configs/suite_gym.gin')
    gin.parse_config_file(config_path)

    loaded_env = suite_gym.load()
    for expected_type in (py_environment.PyEnvironment, wrappers.TimeLimit):
        self.assertIsInstance(loaded_env, expected_type)
开发者ID:tensorflow,项目名称:agents,代码行数:9,代码来源:suite_gym_test.py

示例11: testGinConfig

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def testGinConfig(self):
    """After parsing suite_mujoco.gin, load() yields a TimeLimit PyEnvironment."""
    config_path = test_utils.test_src_dir_path(
        'environments/configs/suite_mujoco.gin')
    gin.parse_config_file(config_path)

    loaded_env = suite_mujoco.load()
    for expected_type in (py_environment.PyEnvironment, wrappers.TimeLimit):
        self.assertIsInstance(loaded_env, expected_type)
开发者ID:tensorflow,项目名称:agents,代码行数:9,代码来源:suite_mujoco_test.py

示例12: testGinConfig

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def testGinConfig(self):
    """After parsing suite_bsuite.gin, load() yields a PyEnvironment."""
    config_path = test_utils.test_src_dir_path(
        'environments/configs/suite_bsuite.gin')
    gin.parse_config_file(config_path)

    loaded_env = suite_bsuite.load()
    self.assertIsInstance(loaded_env, py_environment.PyEnvironment)
开发者ID:tensorflow,项目名称:agents,代码行数:8,代码来源:suite_bsuite_test.py

示例13: eval

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
         split="validation"):
    """Runs evaluation of this model on a registered Mixture or Task.

    Args:
      mixture_or_task_name: str, name of the Mixture or Task to evaluate on;
        it must already be registered in `TaskRegistry` or `MixtureRegistry`.
      checkpoint_steps: int, list of ints, or None.
        - Int or list of ints: evaluate the checkpoints in `model_dir` whose
          global steps are closest to those values.
        - None: evaluate continuously as new checkpoints appear.
        - -1: evaluate only the latest checkpoint in the model directory.
      summary_dir: str, where TensorBoard event summaries are written; when
        None, model_dir/eval_{split} is used.
      split: str, which split of the mixture/task to evaluate.
    """
    # -1 is shorthand for "the newest checkpoint in model_dir".
    steps = (_get_latest_checkpoint_from_dir(self._model_dir)
             if checkpoint_steps == -1 else checkpoint_steps)

    mesh_transformer = t5.models.mesh_transformer
    vocabulary = mesh_transformer.get_vocabulary(mixture_or_task_name)
    dataset_fn = functools.partial(
        mesh_transformer.mesh_eval_dataset_fn,
        mixture_or_task_name=mixture_or_task_name)

    # Re-apply the operative config saved in model_dir; unlock_config()
    # permits parsing even while gin's config is locked.
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))

    estimator = self.estimator(vocabulary)
    utils.eval_model(estimator, vocabulary, self._sequence_length,
                     self.batch_size, split, self._model_dir, dataset_fn,
                     summary_dir, steps)
开发者ID:google-research,项目名称:text-to-text-transfer-transformer,代码行数:31,代码来源:mtf_model.py

示例14: finetune

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def finetune(self, mixture_or_task_name, finetune_steps, pretrained_model_dir,
             pretrained_checkpoint_step=-1, split="train"):
    """Continues training from a checkpoint of a previously trained model.

    Args:
      mixture_or_task_name: str, name of the Mixture or Task to train on; it
        must already be registered in `TaskRegistry` or `MixtureRegistry`.
      finetune_steps: int, how many steps to train beyond the starting
        checkpoint.
      pretrained_model_dir: str, directory holding the pretrained checkpoints
        and their operative gin config.
      pretrained_checkpoint_step: int, step of the checkpoint used to
        initialize weights; -1 (the default) picks the latest one in
        `pretrained_model_dir`.
      split: str, which split of the mixture/task to train on.
    """
    checkpoint_step = pretrained_checkpoint_step
    if checkpoint_step == -1:
      checkpoint_step = _get_latest_checkpoint_from_dir(pretrained_model_dir)

    # Load the operative gin config that was saved with the pretrained model.
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(pretrained_model_dir))

    init_checkpoint = os.path.join(
        pretrained_model_dir, f"model.ckpt-{checkpoint_step}")
    # The target is an absolute global step: start step + extra steps.
    self.train(mixture_or_task_name, checkpoint_step + finetune_steps,
               init_checkpoint=init_checkpoint, split=split)
开发者ID:google-research,项目名称:text-to-text-transfer-transformer,代码行数:29,代码来源:mtf_model.py

示例15: predict

# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_file [as 别名]
def predict(self, input_file, output_file, checkpoint_steps=-1,
            beam_size=1, temperature=1.0, vocabulary=None):
    """Decodes targets for every input line of a text file.

    Args:
      input_file: str, path to a text file with one input prompt per line.
      output_file: str, path prefix for the predictions file; the checkpoint
        step is appended to this name.
      checkpoint_steps: int, list of ints, or None.
        - Int or list of ints: run inference with the checkpoints in
          `model_dir` whose global steps are closest to those values.
        - None: run inference continuously as new checkpoints appear.
        - -1 (the default): use only the latest checkpoint in the model
          directory.
      beam_size: int >= 1, number of beams for beam search.
      temperature: float in [0, 1]. 0.0 means argmax; 1.0 samples from the
        predicted distribution. It must be 0 when beam_size > 1.
      vocabulary: vocabularies.Vocabulary used for tokenization; None selects
        the default vocabulary (t5.data.get_default_vocabulary()).
    """
    # TODO(sharannarang): A load_checkpoint-style helper that restores the
    # model once and lets decode_from_file run repeatedly without reloading
    # the weights would be nice -- especially for the colab demo.

    if checkpoint_steps == -1:
      checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    # Restore the training-time gin config, then override the decoding knobs.
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      decode_overrides = (
          ("Bitransformer.decode.beam_size", beam_size),
          ("Bitransformer.decode.temperature", temperature),
      )
      for binding_key, binding_value in decode_overrides:
        gin.bind_parameter(binding_key, binding_value)

    vocab = (vocabulary if vocabulary is not None
             else t5.data.get_default_vocabulary())
    utils.infer_model(
        self.estimator(vocab), vocab, self._sequence_length,
        self.batch_size, self._model_type, self._model_dir, checkpoint_steps,
        input_file, output_file)
开发者ID:google-research,项目名称:text-to-text-transfer-transformer,代码行数:42,代码来源:mtf_model.py


注:本文中的gin.parse_config_file方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。