This article collects typical usage examples of the Python method gin.parse_config_file. If you are unsure what gin.parse_config_file does, how to call it, or what it looks like in real code, the curated examples below should help. You can also explore further usage examples from the gin module.
The following shows 15 code examples of gin.parse_config_file, sorted by popularity by default.
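Before the examples, a minimal self-contained sketch of the typical gin workflow may be useful (the function build_model and the file example.gin below are hypothetical, not taken from any of the projects shown here): a function is decorated with @gin.configurable, a .gin file binds its parameters, and gin.parse_config_file loads those bindings before the function is called.

import gin

@gin.configurable
def build_model(num_layers=2, d_model=128):
  """Returns a toy model description built from gin-injected values."""
  return {'num_layers': num_layers, 'd_model': d_model}

# example.gin might contain:
#   build_model.num_layers = 6
#   build_model.d_model = 512

gin.parse_config_file('example.gin')  # read the bindings from the file
model = build_model()                 # now called with num_layers=6, d_model=512

As many of the examples below show, gin.bind_parameter can then override individual bindings programmatically, which is especially handy in tests.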
Example 1: __init__

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def __init__(self, data_path, mode, train_set, validation_set, test_set,
             max_way_train, max_way_test, max_support_train, max_support_test):
  self.data_path = data_path
  self.train_dataset_next_task = None
  self.validation_set_dict = {}
  self.test_set_dict = {}
  # Load the Meta-Dataset gin settings before building any episode pipelines.
  gin.parse_config_file('./meta_dataset_config.gin')

  if mode == 'train' or mode == 'train_test':
    train_episode_description = self._get_train_episode_description(max_way_train, max_support_train)
    self.train_dataset_next_task = self._init_multi_source_dataset(
        train_set, learning_spec.Split.TRAIN, train_episode_description)

    test_episode_description = self._get_test_episode_description(max_way_test, max_support_test)
    for item in validation_set:
      next_task = self.validation_dataset = self._init_single_source_dataset(
          item, learning_spec.Split.VALID, test_episode_description)
      self.validation_set_dict[item] = next_task

  if mode == 'test' or mode == 'train_test':
    test_episode_description = self._get_test_episode_description(max_way_test, max_support_test)
    for item in test_set:
      next_task = self._init_single_source_dataset(item, learning_spec.Split.TEST, test_episode_description)
      self.test_set_dict[item] = next_task
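A hypothetical way the reader in Example 1 might be constructed (the class name MetaDatasetReader, the dataset names, and the numeric limits are assumptions for illustration; they do not appear in the snippet above):

reader = MetaDatasetReader(
    data_path='/data/meta_dataset_records', mode='train_test',
    train_set=['ilsvrc_2012'], validation_set=['omniglot'], test_set=['aircraft'],
    max_way_train=40, max_way_test=50,
    max_support_train=400, max_support_test=500)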
Example 2: test_reformer_wmt_ende

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def test_reformer_wmt_ende(self):
  trax.fastmath.disable_jit()
  batch_size_per_device = 2
  steps = 1
  n_layers = 2
  d_ff = 32

  gin.parse_config_file('reformer_wmt_ende.gin')
  gin.bind_parameter('data_streams.data_dir', _TESTDATA)
  gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
  gin.bind_parameter('train.steps', steps)
  gin.bind_parameter('Reformer.n_encoder_layers', n_layers)
  gin.bind_parameter('Reformer.n_decoder_layers', n_layers)
  gin.bind_parameter('Reformer.d_ff', d_ff)

  with self.tmp_dir() as output_dir:
    _ = trainer_lib.train(output_dir=output_dir)
Example 3: test_reformer_noencdecattn_wmt_ende

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def test_reformer_noencdecattn_wmt_ende(self):
  trax.fastmath.disable_jit()
  batch_size_per_device = 1  # Ignored, but needs to be set.
  steps = 1
  n_layers = 2
  d_ff = 32

  gin.parse_config_file('reformer_noencdecattn_wmt_ende.gin')
  gin.bind_parameter('data_streams.data_dir', _TESTDATA)
  gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
  gin.bind_parameter('batcher.buckets', ([513], [1, 1]))  # batch size 1.
  gin.bind_parameter('train.steps', steps)
  gin.bind_parameter('ReformerNoEncDecAttention.n_encoder_layers', n_layers)
  gin.bind_parameter('ReformerNoEncDecAttention.n_decoder_layers', n_layers)
  gin.bind_parameter('ReformerNoEncDecAttention.d_ff', d_ff)

  with self.tmp_dir() as output_dir:
    _ = trainer_lib.train(output_dir=output_dir)
Example 4: test_run_pose_env_collect

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def test_run_pose_env_collect(self, demo_policy_cls):
  urdf_root = pose_env.get_pybullet_urdf_root()
  config_dir = 'research/pose_env/configs'
  gin_config = os.path.join(
      FLAGS.test_srcdir, config_dir, 'run_random_collect.gin')
  gin.parse_config_file(gin_config)
  tmp_dir = absltest.get_default_test_tmpdir()
  root_dir = os.path.join(tmp_dir, str(demo_policy_cls))
  gin.bind_parameter('PoseToyEnv.urdf_root', urdf_root)
  gin.bind_parameter('collect_eval_loop.root_dir', root_dir)
  gin.bind_parameter('run_meta_env.num_tasks', 2)
  gin.bind_parameter('run_meta_env.num_episodes_per_adaptation', 1)
  gin.bind_parameter('collect_eval_loop.policy_class', demo_policy_cls)
  continuous_collect_eval.collect_eval_loop()
  output_files = tf.io.gfile.glob(os.path.join(
      root_dir, 'policy_collect', '*.tfrecord'))
  self.assertLen(output_files, 2)
Example 5: setUp

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def setUp(self):
  super(EvalMetricsTest, self).setUp()
  gin.clear_config()
  gin_file = os.path.join(
      './', 'rl_reliability_metrics/evaluation', 'eval_metrics_test.gin')
  gin.parse_config_file(gin_file)

  # Fake set of training curves to test analysis.
  self.test_data_dir = os.path.join(
      './', 'rl_reliability_metrics/evaluation/test_data')
  self.run_dirs = [
      os.path.join(self.test_data_dir, 'run%d' % i, 'train') for i in range(3)
  ]
Example 6: setUp

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def setUp(self):
  super(DataLoadingTest, self).setUp()
  gin.clear_config()
  gin_file = os.path.join(
      './', 'rl_reliability_metrics/evaluation', 'eval_metrics_test.gin')
  gin.parse_config_file(gin_file)

  # Fake set of training curves to test analysis.
  test_data_dir = os.path.join(
      './', 'rl_reliability_metrics/evaluation/test_data')
  self.run_dirs = [
      os.path.join(test_data_dir, 'run%d' % i, 'train') for i in range(3)
  ]
Example 7: main

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  gin.parse_config_file(FLAGS.gin_config_path)
  runner = runner_lib.Runner()
  results = runner.run()
  logging.info('Results: %s', results)
  with open(FLAGS.output_path, 'w') as f:
    f.write(core.to_json(results))
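Example 7 relies on two absl flags that are defined elsewhere in the module. A sketch of how they might be declared and wired up to app.run (the help strings are assumptions; the flag names follow the snippet above):

from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string('gin_config_path', None, 'Path to the gin config file to parse.')
flags.DEFINE_string('output_path', None, 'Where to write the JSON results.')

if __name__ == '__main__':
  app.run(main)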
Example 8: parse_gin_defaults_and_flags

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def parse_gin_defaults_and_flags():
  """Parses all default gin files and those provided via flags."""
  # Register .gin file search paths with gin.
  for gin_file_path in FLAGS.gin_location_prefix:
    gin.add_config_file_search_path(gin_file_path)
  # Set up the default values for the configurable parameters. These values
  # will be overridden by any user-provided gin files/parameters.
  gin.parse_config_file(
      pkg_resources.resource_filename(__name__, _DEFAULT_CONFIG_FILE))
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  # TODO(noam): maybe add gin-config to mtf.get_variable so we can delete
  # this stupid VariableDtype class and stop passing it all over creation.
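The comments in Example 8 describe gin's layering behavior: bindings parsed later override bindings parsed earlier. A minimal sketch of the same pattern, with hypothetical file names and bindings standing in for FLAGS.gin_file and FLAGS.gin_param:

gin.parse_config_file('defaults.gin')        # package-level defaults parsed first
gin.parse_config_files_and_bindings(
    ['experiment.gin'],                      # user-supplied files override the defaults
    ['train.steps = 1000'])                  # inline bindings override everything parsed so far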
Example 9: testGinConfig

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def testGinConfig(self):
  gin.parse_config_file(
      test_utils.test_src_dir_path('environments/configs/suite_pybullet.gin'))
  env = suite_pybullet.load()
  self.assertIsInstance(env, py_environment.PyEnvironment)
  self.assertIsInstance(env, wrappers.TimeLimit)
Example 10: testGinConfig

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def testGinConfig(self):
  gin.parse_config_file(
      test_utils.test_src_dir_path('environments/configs/suite_gym.gin'))
  env = suite_gym.load()
  self.assertIsInstance(env, py_environment.PyEnvironment)
  self.assertIsInstance(env, wrappers.TimeLimit)
Example 11: testGinConfig

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def testGinConfig(self):
  gin.parse_config_file(
      test_utils.test_src_dir_path('environments/configs/suite_mujoco.gin'))
  env = suite_mujoco.load()
  self.assertIsInstance(env, py_environment.PyEnvironment)
  self.assertIsInstance(env, wrappers.TimeLimit)
Example 12: testGinConfig

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def testGinConfig(self):
  gin.parse_config_file(
      test_utils.test_src_dir_path('environments/configs/suite_bsuite.gin'))
  env = suite_bsuite.load()
  self.assertIsInstance(env, py_environment.PyEnvironment)
Example 13: eval

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
         split="validation"):
  """Evaluate the model on the given Mixture or Task.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to evaluate on.
      Must be pre-registered in the global `TaskRegistry` or `MixtureRegistry`.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      evaluation will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run eval
      continuously, waiting for new checkpoints. If -1, get the latest
      checkpoint from the model directory.
    summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    split: str, the mixture/task split to evaluate on.
  """
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
  vocabulary = t5.models.mesh_transformer.get_vocabulary(mixture_or_task_name)
  dataset_fn = functools.partial(
      t5.models.mesh_transformer.mesh_eval_dataset_fn,
      mixture_or_task_name=mixture_or_task_name,
  )
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
  utils.eval_model(self.estimator(vocabulary), vocabulary,
                   self._sequence_length, self.batch_size, split,
                   self._model_dir, dataset_fn, summary_dir, checkpoint_steps)
Example 14: finetune

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def finetune(self, mixture_or_task_name, finetune_steps, pretrained_model_dir,
             pretrained_checkpoint_step=-1, split="train"):
  """Finetunes a model from an existing checkpoint.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to finetune on.
      Must be pre-registered in the global `TaskRegistry` or `MixtureRegistry`.
    finetune_steps: int, the number of additional steps to train for.
    pretrained_model_dir: str, directory with pretrained model checkpoints and
      operative config.
    pretrained_checkpoint_step: int, checkpoint to initialize weights from. If
      -1 (default), use the latest checkpoint from the pretrained model
      directory.
    split: str, the mixture/task split to finetune on.
  """
  if pretrained_checkpoint_step == -1:
    checkpoint_step = _get_latest_checkpoint_from_dir(pretrained_model_dir)
  else:
    checkpoint_step = pretrained_checkpoint_step
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(pretrained_model_dir))
  model_ckpt = "model.ckpt-" + str(checkpoint_step)
  self.train(mixture_or_task_name, checkpoint_step + finetune_steps,
             init_checkpoint=os.path.join(pretrained_model_dir, model_ckpt),
             split=split)
Example 15: predict

# Required import: import gin [as alias]
# Or: from gin import parse_config_file [as alias]
def predict(self, input_file, output_file, checkpoint_steps=-1,
            beam_size=1, temperature=1.0, vocabulary=None):
  """Predicts targets from the given inputs.

  Args:
    input_file: str, path to a text file containing newline-separated input
      prompts to predict from.
    output_file: str, path prefix of the output file to write predictions to.
      Note that the checkpoint step will be appended to the given filename.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      inference will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      inference continuously, waiting for new checkpoints. If -1, get the
      latest checkpoint from the model directory.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1);
      0.0 means argmax, 1.0 means sampling according to the predicted
      distribution.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.
  """
  # TODO(sharannarang): It would be nice to have a function like
  # load_checkpoint that loads the model once and then calls decode_from_file
  # multiple times without having to restore the checkpoint weights again.
  # This would be particularly useful in a colab demo.
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
    gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
    gin.bind_parameter("Bitransformer.decode.temperature", temperature)

  if vocabulary is None:
    vocabulary = t5.data.get_default_vocabulary()
  utils.infer_model(
      self.estimator(vocabulary), vocabulary, self._sequence_length,
      self.batch_size, self._model_type, self._model_dir, checkpoint_steps,
      input_file, output_file)
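Examples 13 through 15 all reopen a previously saved operative config before overriding parameters. Once a gin config has been finalized it is locked, so re-parsing or re-binding has to happen inside gin.unlock_config(). A minimal sketch of that pattern (the path below is hypothetical):

with gin.unlock_config():
  # Re-read the operative config written during pretraining, then adjust
  # individual parameters for inference.
  gin.parse_config_file('/tmp/model_dir/operative_config.gin')
  gin.bind_parameter('Bitransformer.decode.beam_size', 4)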