本文整理汇总了Python中gin.parse_config_files_and_bindings方法的典型用法代码示例。如果您正苦于以下问题：Python gin.parse_config_files_and_bindings方法的具体用法？Python gin.parse_config_files_and_bindings怎么用？Python gin.parse_config_files_and_bindings使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块gin的用法示例。
在下文中一共展示了gin.parse_config_files_and_bindings方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: t2t_train
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def t2t_train(model_name, dataset_name,
              data_dir=None, output_dir=None, config_file=None, config=None):
  """Trains a registered model on a dataset, with parameters supplied via gin.

  The registry entry for `model_name` is called and its result is bound to
  `train_fn.model_class`; `dataset_name` is bound to `train_fn.dataset`. The
  gin config file and override string are parsed after those bindings, and
  then `train_fn` is invoked.

  Args:
    model_name: Key of the model in `_MODEL_REGISTRY`.
    dataset_name: Name of the dataset to train on.
    data_dir: Directory where the data is located.
    output_dir: Directory for logs and checkpoints.
    config_file: Gin configuration file to parse.
    config: String in gin format with parameter overrides.

  Raises:
    ValueError: If `model_name` is not present in `_MODEL_REGISTRY`.
  """
  if model_name not in _MODEL_REGISTRY:
    known_models = "\n * ".join(_MODEL_REGISTRY.keys())
    raise ValueError("Model %s not in registry. Available models:\n * %s." %
                     (model_name, known_models))
  # The registry value is called; its result is what gets bound below.
  model_class = _MODEL_REGISTRY[model_name]()
  preset_bindings = (
      ("train_fn.model_class", model_class),
      ("train_fn.dataset", dataset_name),
  )
  for binding_key, binding_value in preset_bindings:
    gin.bind_parameter(binding_key, binding_value)
  gin.parse_config_files_and_bindings(config_file, config)
  # TODO(lukaszkaiser): save gin config in output_dir if provided?
  train_fn(data_dir, output_dir=output_dir)
示例2: _setup_gin
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def _setup_gin():
  """Sets up gin configuration from command-line flags.

  Imports the tensor2tensor trax model and optimizer modules (for their gin
  configurables), then parses `FLAGS.config_file` together with the
  `FLAGS.config` bindings, extended by bindings derived from `--dataset`,
  `--data_dir` and `--model`.
  """
  # Imports for configurables
  # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
  from tensor2tensor.trax import models as _trax_models
  from tensor2tensor.trax import optimizers as _trax_opt
  # pylint: enable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable

  # Copy the flag value: appending to FLAGS.config directly would mutate the
  # flag itself, so a second call would see (and duplicate) these overrides.
  configs = list(FLAGS.config or [])
  # Override with --dataset, --data_dir and --model. %r emits a correctly
  # quoted and escaped Python string literal (safe for values containing
  # quotes or backslashes, e.g. Windows paths), which gin parses back as-is.
  if FLAGS.dataset:
    configs.append("inputs.dataset_name=%r" % FLAGS.dataset)
  if FLAGS.data_dir:
    configs.append("inputs.data_dir=%r" % FLAGS.data_dir)
  if FLAGS.model:
    configs.append("train.model=@trax.models.%s" % FLAGS.model)
  gin.parse_config_files_and_bindings(FLAGS.config_file, configs)
示例3: _gin_parse_configs
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def _gin_parse_configs():
  """Initializes gin-controlled bindings from command-line flags.

  Imports the trax model and optimizer modules (for their gin configurables),
  then parses `FLAGS.config_file` together with the `FLAGS.config` bindings,
  extended by bindings derived from `--dataset`, `--data_dir` and `--model`.
  """
  # Imports for configurables
  # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
  from trax import models as _trax_models
  from trax import optimizers as _trax_opt
  # pylint: enable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable

  # Copy the flag value: appending to FLAGS.config directly would mutate the
  # flag itself, so a second call would see (and duplicate) these overrides.
  configs = list(FLAGS.config or [])
  # Override with --dataset, --data_dir and --model. %r emits a correctly
  # quoted and escaped Python string literal (safe for values containing
  # quotes or backslashes, e.g. Windows paths), which gin parses back as-is.
  if FLAGS.dataset:
    configs.append('data_streams.dataset_name=%r' % FLAGS.dataset)
  if FLAGS.data_dir:
    configs.append('data_streams.data_dir=%r' % FLAGS.data_dir)
  if FLAGS.model:
    configs.append('train.model=@trax.models.%s' % FLAGS.model)
  gin.parse_config_files_and_bindings(FLAGS.config_file, configs)
示例4: main
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(argv):
  """Parses gin configuration from flags and runs RL training.

  After training returns, stack traces of all threads are dumped via
  `faulthandler` (marked in the code as a debugging aid).

  Args:
    argv: Command-line arguments; unused.
  """
  del argv
  logging.info('Starting RL training.')
  gin_configs = FLAGS.config or []
  gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs)
  logging.info('Gin config: %s', gin_configs)
  train_rl(
      output_dir=FLAGS.output_dir,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      # An empty flag value is normalized to None.
      trajectory_dump_dir=(FLAGS.trajectory_dump_dir or None),
  )
  # TODO(afrozm): This is for debugging.
  logging.info('Dumping stack traces of all stacks.')
  faulthandler.dump_traceback(all_threads=True)
  logging.info('Training is done, should exit.')
示例5: test_compress_image
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def test_compress_image(self):
  """Checks that an image round-trips through encode/decode with compression.

  Skipped unless eager execution has been enabled. Image compression is
  switched on for both the feature encoder and the feature parser via gin,
  then a constant uint8 image is serialized, decoded and compared.
  """
  if not common.has_eager_been_enabled():
    self.skipTest("Image compression only supported in TF2.x")
  # Turn on compression on both the encode side and the parse side.
  gin.parse_config_files_and_bindings([], """
_get_feature_encoder.compress_image=True
_get_feature_parser.compress_image=True
""")
  image_spec = {"image": array_spec.ArraySpec((128, 128, 3), np.uint8)}
  serialize = example_encoding.get_example_serializer(image_spec)
  decode = example_encoding.get_example_decoder(image_spec)
  original = {"image": np.full((128, 128, 3), 128, dtype=np.uint8)}
  recovered = self.evaluate(decode(serialize(original)))
  tf.nest.map_structure(np.testing.assert_almost_equal, original, recovered)
示例6: main
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(unused_argv):
  """Parses gin configuration from flags and runs the configured schedule.

  If `FLAGS.use_tpu` is unset (None), TPU usage is inferred from whether the
  `TPU_NAME` environment variable is non-empty.

  Args:
    unused_argv: Command-line arguments; unused.
  """
  logging.info("Gin config: %s\nGin bindings: %s",
               FLAGS.gin_config, FLAGS.gin_bindings)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  # Read with .get(): --use_tpu may be set explicitly while TPU_NAME is not,
  # and indexing os.environ would then raise KeyError.
  tpu_name = os.environ.get("TPU_NAME", "")
  if FLAGS.use_tpu is None:
    FLAGS.use_tpu = bool(tpu_name)
  if FLAGS.use_tpu:
    if tpu_name:
      logging.info("Found TPU %s.", tpu_name)
    else:
      logging.warning("use_tpu is set but TPU_NAME is not set.")
  run_config = _get_run_config()
  task_manager = _get_task_manager()
  options = runner_lib.get_options_dict()
  runner_lib.run_with_schedule(
      schedule=FLAGS.schedule,
      run_config=run_config,
      task_manager=task_manager,
      options=options,
      use_tpu=FLAGS.use_tpu,
      num_eval_averaging_runs=FLAGS.num_eval_averaging_runs,
      eval_every_steps=FLAGS.eval_every_steps)
  logging.info("I'm done with my work, ciao!")
示例7: main
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(_):
  """Parses gin configuration and runs the tool selected by FLAGS.mode.

  Builds an ActiveVisionDatasetEnv with image, semantic-segmentation,
  object-detection, depth and distance modalities, then dispatches on the
  mode constant.

  Raises:
    ValueError: If FLAGS.mode matches none of the supported mode constants.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  print('********')
  print(FLAGS.mode)
  print(FLAGS.gin_config)
  print(FLAGS.gin_params)
  env = active_vision_dataset_env.ActiveVisionDatasetEnv(modality_types=[
      task_env.ModalityTypes.IMAGE,
      task_env.ModalityTypes.SEMANTIC_SEGMENTATION,
      task_env.ModalityTypes.OBJECT_DETECTION,
      task_env.ModalityTypes.DEPTH,
      task_env.ModalityTypes.DISTANCE,
  ])
  if FLAGS.mode == BENCHMARK_MODE:
    benchmark(env, env.possible_targets)
  elif FLAGS.mode == GRAPH_MODE:
    for loc in env.worlds:
      env.check_scene_graph(loc, 'fridge')
  elif FLAGS.mode == HUMAN_MODE:
    human(env, env.possible_targets)
  elif FLAGS.mode == VIS_MODE:
    visualize_random_step_sequence(env)
  elif FLAGS.mode == EVAL_MODE:
    evaluate_folder(env, FLAGS.eval_folder)
  else:
    # Fail loudly rather than exiting successfully having done nothing.
    raise ValueError('Unknown mode: %r' % (FLAGS.mode,))
示例8: evaluate_metrics
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def evaluate_metrics():
  """Evaluates the gin-configured metrics for every (algorithm, task) pair.

  Parses `p.gin_file`, then for each algorithm in `p.algos` and task in
  `p.tasks` collects the 'train' run directories under
  `<p.data_dir>/<algo>/<task>` and writes metric parameters and values under
  the prefix `<p.metric_values_dir>/<algo>/<task>/`.
  """
  gin.parse_config_files_and_bindings([p.gin_file], [])
  algo_task_pairs = ((algo, task) for algo in p.algos for task in p.tasks)
  for algo, task in algo_task_pairs:
    run_dirs = eval_metrics.get_run_dirs(
        os.path.join(p.data_dir, algo, task), 'train', p.runs)
    outfile_prefix = os.path.join(p.metric_values_dir, algo, task) + '/'
    # The metric list is left to gin (gin.REQUIRED), i.e. the parsed config.
    evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
    evaluator.write_metric_params(outfile_prefix)
    evaluator.evaluate(run_dirs=run_dirs, outfile_prefix=outfile_prefix)
示例9: override_gin
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def override_gin(self, bindings):
  """Parses extra gin `bindings` on top of the current configuration.

  No config files are read; only the given bindings are applied.
  """
  gin.parse_config_files_and_bindings(config_files=None, bindings=bindings)
示例10: main
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(unused_argv):
  """Applies gin configs/bindings from flags, then trains and evaluates."""
  del unused_argv  # Unused.
  config_files, bindings = FLAGS.gin_configs, FLAGS.gin_bindings
  gin.parse_config_files_and_bindings(config_files, bindings)
  train_eval.train_eval_model()
示例11: main
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(unused_argv):
  """Applies gin configs/bindings from flags, then runs the collect/eval loop."""
  del unused_argv  # Unused.
  gin.parse_config_files_and_bindings(
      config_files=FLAGS.gin_configs, bindings=FLAGS.gin_bindings)
  continuous_collect_eval.collect_eval_loop(root_dir=FLAGS.root_dir)
示例12: parse_gin_defaults_and_flags
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def parse_gin_defaults_and_flags():
  """Parses the package's default gin file, then the files/params from flags.

  Every prefix in `FLAGS.gin_location_prefix` is registered as a gin config
  file search path first. The default config is parsed before the
  flag-provided files and parameters so that user values override defaults.
  """
  for search_prefix in FLAGS.gin_location_prefix:
    gin.add_config_file_search_path(search_prefix)
  # Defaults first; anything parsed afterwards overrides them.
  default_config_path = pkg_resources.resource_filename(
      __name__, _DEFAULT_CONFIG_FILE)
  gin.parse_config_file(default_config_path)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
# TODO(noam): maybe add gin-config to mtf.get_variable so we can delete
# this stupid VariableDtype class and stop passing it all over creation.
示例13: main
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(_):
  """Enables TF2 behavior, parses gin files/params, then runs train_eval."""
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_v2_behavior()
  config_files, bindings = FLAGS.gin_file, FLAGS.gin_param
  gin.parse_config_files_and_bindings(config_files, bindings)
  train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations)
示例14: main
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(_):
  """Enables TF2 behavior, parses gin files/params, then runs train_eval."""
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(
      config_files=FLAGS.gin_file, bindings=FLAGS.gin_param)
  train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations)
示例15: main
# 需要导入模块: import gin [as 别名]
# 或者: from gin import parse_config_files_and_bindings [as 别名]
def main(_):
  """Enables resource variables, parses gin files/params, runs train_eval."""
  tf.compat.v1.enable_resource_variables()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(
      config_files=FLAGS.gin_file, bindings=FLAGS.gin_param)
  train_eval(FLAGS.root_dir)