本文整理汇总了Python中gin.bind_parameter方法的典型用法代码示例。如果您正苦于以下问题:Python gin.bind_parameter方法的具体用法?Python gin.bind_parameter怎么用?Python gin.bind_parameter使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块gin的用法示例。
在下文中一共展示了gin.bind_parameter方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: t2t_train
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def t2t_train(model_name, dataset_name,
              data_dir=None, output_dir=None, config_file=None, config=None):
  """Trains model `model_name` on dataset `dataset_name`.

  Args:
    model_name: The name of the model to train.
    dataset_name: The name of the dataset to train on.
    data_dir: Directory where the data is located.
    output_dir: Directory where to put the logs and checkpoints.
    config_file: the gin configuration file to use.
    config: string (in gin format) to override gin parameters.

  Raises:
    ValueError: if `model_name` is not present in the model registry.
  """
  if model_name not in _MODEL_REGISTRY:
    available_models = "\n * ".join(_MODEL_REGISTRY.keys())
    raise ValueError("Model %s not in registry. Available models:\n * %s." %
                     (model_name, available_models))
  model_class = _MODEL_REGISTRY[model_name]()
  # Bind the resolved model class and the dataset name, then parse any
  # user-supplied gin file and string overrides.
  gin.bind_parameter("train_fn.model_class", model_class)
  gin.bind_parameter("train_fn.dataset", dataset_name)
  gin.parse_config_files_and_bindings(config_file, config)
  # TODO(lukaszkaiser): save gin config in output_dir if provided?
  train_fn(data_dir, output_dir=output_dir)
示例2: test_training_loop_onlinetune
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_training_loop_onlinetune(self):
  """Smoke-tests the RL training loop against the OnlineTuneEnv environment."""
  with self.tmp_dir() as output_dir:
    model_fn = functools.partial(
        models.MLP, n_hidden_layers=0, n_output_classes=1)
    inputs_fn = functools.partial(
        trax_inputs.random_inputs,
        input_shape=(1, 1),
        input_dtype=np.float32,
        output_shape=(1, 1),
        output_dtype=np.float32,
    )
    # Configure a tiny model and a 2-step train/eval schedule so the test
    # finishes quickly.
    bindings = {
        "OnlineTuneEnv.model": model_fn,
        "OnlineTuneEnv.inputs": inputs_fn,
        "OnlineTuneEnv.train_steps": 2,
        "OnlineTuneEnv.eval_steps": 2,
        "OnlineTuneEnv.output_dir": os.path.join(output_dir, "envs"),
    }
    for parameter_name, value in bindings.items():
      gin.bind_parameter(parameter_name, value)
    train_env = self.get_wrapped_env("OnlineTuneEnv-v0", 2)
    eval_env = self.get_wrapped_env("OnlineTuneEnv-v0", 2)
    self._run_training_loop(
        env=train_env, eval_env=eval_env, output_dir=output_dir)
示例3: test_wrapped_policy_continuous
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_wrapped_policy_continuous(self, vocab_size):
  """Checks output shapes of a policy wrapped over a continuous obs space."""
  precision = 3
  n_controls = 2
  n_actions = 4
  gin.bind_parameter('BoxSpaceSerializer.precision', precision)
  observations = np.array(
      [[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0.01, 0.66]]])
  actions = np.array([[[0, 1], [2, 0], [1, 3]]])
  wrapped_policy = serialization_utils.wrap_policy(
      TestModel(extra_dim=vocab_size),  # pylint: disable=no-value-for-parameter
      observation_space=gym.spaces.Box(shape=(2,), low=-2, high=2),
      action_space=gym.spaces.MultiDiscrete([n_actions] * n_controls),
      vocab_size=vocab_size,
  )
  example = (observations, actions)
  wrapped_policy.init(shapes.signature(example))
  (act_logits, values) = wrapped_policy(example)
  # Logits carry one (n_controls, n_actions) slice per observation step;
  # values are scalar per step.
  expected_logits_shape = observations.shape[:2] + (n_controls, n_actions)
  self.assertEqual(act_logits.shape, expected_logits_shape)
  self.assertEqual(values.shape, observations.shape[:2])
示例4: test_reformer_wmt_ende
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_reformer_wmt_ende(self):
  """Runs one tiny training step of Reformer on the WMT en-de test data."""
  trax.fastmath.disable_jit()
  gin.parse_config_file('reformer_wmt_ende.gin')
  # Shrink the configured model and schedule so a single step is cheap.
  overrides = {
      'data_streams.data_dir': _TESTDATA,
      'batcher.batch_size_per_device': 2,
      'train.steps': 1,
      'Reformer.n_encoder_layers': 2,
      'Reformer.n_decoder_layers': 2,
      'Reformer.d_ff': 32,
  }
  for parameter_name, value in overrides.items():
    gin.bind_parameter(parameter_name, value)
  with self.tmp_dir() as output_dir:
    _ = trainer_lib.train(output_dir=output_dir)
示例5: test_reformer_noencdecattn_wmt_ende
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_reformer_noencdecattn_wmt_ende(self):
  """Runs one tiny training step of ReformerNoEncDecAttention on WMT en-de."""
  trax.fastmath.disable_jit()
  gin.parse_config_file('reformer_noencdecattn_wmt_ende.gin')
  overrides = {
      'data_streams.data_dir': _TESTDATA,
      # The per-device batch size is ignored here but must still be bound.
      'batcher.batch_size_per_device': 1,
      'batcher.buckets': ([513], [1, 1]),  # Effectively batch size 1.
      'train.steps': 1,
      'ReformerNoEncDecAttention.n_encoder_layers': 2,
      'ReformerNoEncDecAttention.n_decoder_layers': 2,
      'ReformerNoEncDecAttention.d_ff': 32,
  }
  for parameter_name, value in overrides.items():
    gin.bind_parameter(parameter_name, value)
  with self.tmp_dir() as output_dir:
    _ = trainer_lib.train(output_dir=output_dir)
示例6: setUp
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def setUp(self):
  """Prepares a clean log dir, short gin train/eval schedule, and inputs."""
  super(PoseEnvModelsTest, self).setUp()
  base_dir = 'tensor2robot'
  test_data = os.path.join(FLAGS.test_srcdir, base_dir,
                           'test_data/pose_env_test_data.tfrecord')
  self._train_log_dir = FLAGS.test_tmpdir
  # Start from an empty log dir so state from earlier runs cannot leak in.
  if tf.io.gfile.exists(self._train_log_dir):
    tf.io.gfile.rmtree(self._train_log_dir)
  gin.bind_parameter('train_eval_model.max_train_steps', 3)
  gin.bind_parameter('train_eval_model.eval_steps', 2)
  # Record-backed generator for training data; random generators for the
  # meta train/eval inputs.
  self._record_input_generator = (
      default_input_generator.DefaultRecordInputGenerator(
          batch_size=BATCH_SIZE, file_patterns=test_data))
  random_generator = default_input_generator.DefaultRandomInputGenerator
  self._meta_record_input_generator_train = random_generator(
      batch_size=BATCH_SIZE)
  self._meta_record_input_generator_eval = random_generator(
      batch_size=BATCH_SIZE)
示例7: test_run_pose_env_collect
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_run_pose_env_collect(self, demo_policy_cls):
  """Runs the collect/eval loop and checks the expected episode records."""
  urdf_root = pose_env.get_pybullet_urdf_root()
  gin_config = os.path.join(
      FLAGS.test_srcdir, 'research/pose_env/configs', 'run_random_collect.gin')
  gin.parse_config_file(gin_config)
  # One root dir per policy class keeps parameterized runs separate.
  root_dir = os.path.join(
      absltest.get_default_test_tmpdir(), str(demo_policy_cls))
  gin.bind_parameter('PoseToyEnv.urdf_root', urdf_root)
  gin.bind_parameter('collect_eval_loop.root_dir', root_dir)
  gin.bind_parameter('run_meta_env.num_tasks', 2)
  gin.bind_parameter('run_meta_env.num_episodes_per_adaptation', 1)
  gin.bind_parameter('collect_eval_loop.policy_class', demo_policy_cls)
  continuous_collect_eval.collect_eval_loop()
  # Two tasks at one episode each -> two tfrecord files expected.
  pattern = os.path.join(root_dir, 'policy_collect', '*.tfrecord')
  self.assertLen(tf.io.gfile.glob(pattern), 2)
示例8: _runSingleTrainingStep
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
  """Trains an SSGAN estimator for one step with the given configuration."""
  # Bindings may already be locked by an earlier parse; unlock to override.
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 128,
  }
  model_dir = self._get_empty_model_dir()
  tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir, tpu_config=tpu_config)
  gan = SSGAN(
      dataset=datasets.get_dataset("cifar10"),
      parameters=parameters,
      model_dir=model_dir,
      g_optimizer_fn=tf.train.AdamOptimizer,
      g_lr=0.0002,
      rotated_batch_size=4)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
示例9: _runSingleTrainingStep
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
  """Trains a ModularGAN estimator for one step with the given configuration."""
  # Bindings may already be locked by an earlier parse; unlock to override.
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  gan = ModularGAN(
      dataset=datasets.get_dataset("cifar10"),
      parameters={"architecture": architecture, "lambda": 1, "z_dim": 128},
      model_dir=self.model_dir,
      # BigGAN architectures require label conditioning.
      conditional="biggan" in architecture)
  estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
示例10: testSingleTrainingStepDiscItersWithEma
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
  """Trains one step with EMA enabled and checks EMA vars in the checkpoint.

  Args:
    disc_iters: Number of discriminator iterations per generator step
      (parameterized).
  """
  parameters = {
      "architecture": c.DUMMY_ARCH,
      "lambda": 1,
      "z_dim": 128,
      # Bug fix: this key was misspelled "dics_iters", so the parameterized
      # disc_iters value was silently ignored by the model.
      "disc_iters": disc_iters,
  }
  gin.bind_parameter("ModularGAN.g_use_ema", True)
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=self.model_dir)
  estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
  # The generator's moving-average variables must appear in the checkpoint.
  checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
  ema_vars = sorted([v[0] for v in tf.train.list_variables(checkpoint_path)
                     if v[0].endswith("ExponentialMovingAverage")])
  tf.logging.info("ema_vars=%s", ema_vars)
  expected_ema_vars = sorted([
      "generator/fc_noise/kernel/ExponentialMovingAverage",
      "generator/fc_noise/bias/ExponentialMovingAverage",
  ])
  self.assertAllEqual(ema_vars, expected_ema_vars)
示例11: testUnlabledDatasetRaisesError
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def testUnlabledDatasetRaisesError(self):
  """A conditional GAN built on an unlabeled dataset must raise ValueError."""
  with gin.unlock_config():
    gin.bind_parameter("loss.fn", loss_lib.hinge)
  parameters = {
      "architecture": c.RESNET_CIFAR_ARCH,
      "lambda": 1,
      "z_dim": 120,
  }
  # celeb_a carries no labels, so conditioning has nothing to condition on.
  dataset = datasets.get_dataset("celeb_a")
  model_dir = self._get_empty_model_dir()
  with self.assertRaises(ValueError):
    gan = ModularGAN(
        dataset=dataset,
        parameters=parameters,
        conditional=True,
        model_dir=model_dir)
    del gan
示例12: testBatchNormTwoCoresCustom
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def testBatchNormTwoCoresCustom(self):
  """Compares parallel vs. sequential cross-replica batch norm on 2 shards."""

  def computation(x):
    # First batch norm runs with the default (parallel) moment computation.
    parallel_bn = arch_ops.batch_norm(x, is_training=True, name="custom_bn")
    # Switch moment computation to the sequential path for the second op.
    gin.bind_parameter("cross_replica_moments.parallel", False)
    sequential_bn = arch_ops.batch_norm(
        x, is_training=True, name="custom_bn_seq")
    return parallel_bn, sequential_bn

  with tf.Graph().as_default():
    inputs = tf.constant(self._inputs)
    parallel_out, sequential_out = tf.contrib.tpu.batch_parallel(
        computation, [inputs], num_shards=2)
    with self.session() as sess:
      sess.run(tf.contrib.tpu.initialize_system())
      sess.run(tf.global_variables_initializer())
      parallel_out, sequential_out = sess.run(
          [parallel_out, sequential_out])
      logging.info("custom_bn: %s", parallel_out)
      logging.info("custom_bn_seq: %s", sequential_out)
      # Both variants must agree with the reference outputs.
      self.assertAllClose(parallel_out, self._expected_outputs)
      self.assertAllClose(sequential_out, self._expected_outputs)
示例13: testInitializersRandomNormal
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def testInitializersRandomNormal(self):
  """All trainable variables use random-normal (or default) initializers."""
  gin.bind_parameter("weights.initializer", consts.NORMAL_INIT)
  valid_initializers = [
      "kernel/Initializer/random_normal",
      "bias/Initializer/Const",
      "kernel/Initializer/random_normal",
      "bias/Initializer/Const",
      "beta/Initializer/zeros",
      "gamma/Initializer/ones",
  ]
  valid_op_names = "/({}):0$".format("|".join(valid_initializers))
  with tf.Graph().as_default():
    noise = tf.zeros((2, 128))
    generator = resnet5.Generator(image_shape=(128, 128, 3))
    fake_image = generator(noise, y=None, is_training=True)
    resnet5.Discriminator()(fake_image, y=None, is_training=True)
    # Each variable's initializer op name must match one of the valid forms.
    for variable in tf.trainable_variables():
      init_op_name = variable.initializer.inputs[1].name
      self.assertRegex(init_op_name, valid_op_names)
示例14: testInitializersTruncatedNormal
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def testInitializersTruncatedNormal(self):
  """All trainable variables use truncated-normal (or default) initializers."""
  gin.bind_parameter("weights.initializer", consts.TRUNCATED_INIT)
  valid_initializers = [
      "kernel/Initializer/truncated_normal",
      "bias/Initializer/Const",
      "kernel/Initializer/truncated_normal",
      "bias/Initializer/Const",
      "beta/Initializer/zeros",
      "gamma/Initializer/ones",
  ]
  valid_op_names = "/({}):0$".format("|".join(valid_initializers))
  with tf.Graph().as_default():
    noise = tf.zeros((2, 128))
    generator = resnet5.Generator(image_shape=(128, 128, 3))
    fake_image = generator(noise, y=None, is_training=True)
    resnet5.Discriminator()(fake_image, y=None, is_training=True)
    # Each variable's initializer op name must match one of the valid forms.
    for variable in tf.trainable_variables():
      init_op_name = variable.initializer.inputs[1].name
      self.assertRegex(init_op_name, valid_op_names)
示例15: test_serialized_model_continuous
# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_serialized_model_continuous(self):
  """Checks SerializedModel output shapes for a Box observation space."""
  vocab_size = 32
  precision = 3
  gin.bind_parameter('BoxSpaceSerializer.precision', precision)
  observations = np.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]])
  actions = np.array([[0, 1, 0]])
  mask = np.array([[1, 1, 1, 0]])
  obs_serializer = space_serializer.create(
      gym.spaces.Box(shape=(2,), low=-2, high=2), vocab_size=vocab_size
  )
  act_serializer = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size
  )
  serialized_model = serialization_utils.SerializedModel(
      TestModel(extra_dim=vocab_size),  # pylint: disable=no-value-for-parameter
      observation_serializer=obs_serializer,
      action_serializer=act_serializer,
      significance_decay=0.9,
  )
  example = (observations, actions, observations, mask)
  serialized_model.init(shapes.signature(example))
  (obs_logits, obs_repr, weights) = serialized_model(example)
  # Logits add a vocab dimension over the symbol representation.
  self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size,))
  # Each observation dimension serializes into `precision` symbols.
  self.assertEqual(
      obs_repr.shape,
      (1, observations.shape[1], observations.shape[2] * precision)
  )
  self.assertEqual(obs_repr.shape, weights.shape)