当前位置: 首页>>代码示例>>Python>>正文


Python gin.bind_parameter方法代码示例

本文整理汇总了Python中gin.bind_parameter方法的典型用法代码示例。如果您正苦于以下问题：Python gin.bind_parameter方法的具体用法？Python gin.bind_parameter怎么用？Python gin.bind_parameter使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块gin的用法示例。


在下文中一共展示了gin.bind_parameter方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: t2t_train

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def t2t_train(model_name, dataset_name,
              data_dir=None, output_dir=None, config_file=None, config=None):
  """Main function to train the given model on the given dataset.

  Looks up ``model_name`` in ``_MODEL_REGISTRY`` and calls the registered
  entry to obtain the value bound as ``train_fn.model_class``. It also binds
  ``dataset_name`` to ``train_fn.dataset``, then applies the gin config
  file(s) and binding overrides. Finally it calls ``train_fn``.

  Args:
    model_name: The name of the model to train; must be a key of
      ``_MODEL_REGISTRY``.
    dataset_name: The name of the dataset to train on.
    data_dir: Directory where the data is located.
    output_dir: Directory where to put the logs and checkpoints.
    config_file: The gin configuration file to use. Either a single path
      string, a list/tuple of paths, or None.
    config: String (in gin format) to override gin parameters.

  Raises:
    ValueError: If ``model_name`` is not in ``_MODEL_REGISTRY``.
  """
  if model_name not in _MODEL_REGISTRY:
    raise ValueError("Model %s not in registry. Available models:\n * %s." %
                     (model_name, "\n * ".join(_MODEL_REGISTRY.keys())))
  # The registry entry is called with no arguments; whatever it returns is
  # what train_fn receives as `model_class`.
  model_class = _MODEL_REGISTRY[model_name]()
  gin.bind_parameter("train_fn.model_class", model_class)
  gin.bind_parameter("train_fn.dataset", dataset_name)
  # gin iterates over `config_files`. A bare path string would be iterated
  # character by character, so wrap a single path in a list.
  if isinstance(config_file, str):
    config_file = [config_file]
  # Parsed after the explicit bindings above, so (since later gin bindings
  # overwrite earlier ones) the config file / overrides can replace them.
  gin.parse_config_files_and_bindings(config_file, config)
  # TODO(lukaszkaiser): save gin config in output_dir if provided?
  train_fn(data_dir, output_dir=output_dir)
开发者ID:yyht,项目名称:BERT,代码行数:23,代码来源:t2t.py

示例2: test_training_loop_onlinetune

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_training_loop_onlinetune(self):
    """Runs the PPO training loop on OnlineTuneEnv with a tiny model.

    The model and its random inputs/targets, the step counts and the env
    output directory are all bound through gin. Training then runs with one
    wrapped train env and one wrapped eval env.
    """
    with self.tmp_dir() as output_dir:
      # MLP with no hidden layers and a single output class.
      tiny_model = functools.partial(
          models.MLP, n_hidden_layers=0, n_output_classes=1)
      # Random float32 inputs and targets, both of shape (1, 1).
      random_data = functools.partial(
          trax_inputs.random_inputs,
          input_shape=(1, 1),
          input_dtype=np.float32,
          output_shape=(1, 1),
          output_dtype=np.float32,
      )
      gin.bind_parameter("OnlineTuneEnv.model", tiny_model)
      gin.bind_parameter("OnlineTuneEnv.inputs", random_data)
      gin.bind_parameter("OnlineTuneEnv.train_steps", 2)
      gin.bind_parameter("OnlineTuneEnv.eval_steps", 2)
      envs_dir = os.path.join(output_dir, "envs")
      gin.bind_parameter("OnlineTuneEnv.output_dir", envs_dir)

      train_env = self.get_wrapped_env("OnlineTuneEnv-v0", 2)
      eval_env = self.get_wrapped_env("OnlineTuneEnv-v0", 2)
      self._run_training_loop(
          env=train_env, eval_env=eval_env, output_dir=output_dir)
开发者ID:yyht,项目名称:BERT,代码行数:25,代码来源:ppo_training_loop_test.py

示例3: test_wrapped_policy_continuous

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_wrapped_policy_continuous(self, vocab_size):
    """Checks output shapes of a policy wrapped for a continuous obs space.

    Observations live in a 2-D Box space serialized with precision 3. Actions
    come from a MultiDiscrete space with 2 controls of 4 actions each.
    """
    precision, n_controls, n_actions = 3, 2, 4
    gin.bind_parameter('BoxSpaceSerializer.precision', precision)

    observations = np.array(
        [[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0.01, 0.66]]])
    actions = np.array([[[0, 1], [2, 0], [1, 3]]])

    model = TestModel(extra_dim=vocab_size)  # pylint: disable=no-value-for-parameter
    wrapped_policy = serialization_utils.wrap_policy(
        model,
        observation_space=gym.spaces.Box(shape=(2,), low=-2, high=2),
        action_space=gym.spaces.MultiDiscrete([n_actions] * n_controls),
        vocab_size=vocab_size,
    )

    batch = (observations, actions)
    wrapped_policy.init(shapes.signature(batch))
    act_logits, values = wrapped_policy(batch)
    # Outputs are expected per (batch, timestep) of the observations.
    batch_and_time = observations.shape[:2]
    self.assertEqual(act_logits.shape,
                     batch_and_time + (n_controls, n_actions))
    self.assertEqual(values.shape, batch_and_time)
开发者ID:google,项目名称:trax,代码行数:23,代码来源:serialization_utils_test.py

示例4: test_reformer_wmt_ende

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_reformer_wmt_ende(self):
    """Smoke-trains a tiny Reformer for one step from reformer_wmt_ende.gin.

    JIT is disabled and the gin file is parsed. Its settings are then
    overridden to point at the test data and to shrink the model.
    """
    trax.fastmath.disable_jit()
    gin.parse_config_file('reformer_wmt_ende.gin')

    n_layers = 2
    overrides = (
        ('data_streams.data_dir', _TESTDATA),
        ('batcher.batch_size_per_device', 2),
        ('train.steps', 1),
        ('Reformer.n_encoder_layers', n_layers),
        ('Reformer.n_decoder_layers', n_layers),
        ('Reformer.d_ff', 32),
    )
    for binding_key, binding_value in overrides:
      gin.bind_parameter(binding_key, binding_value)

    with self.tmp_dir() as output_dir:
      trainer_lib.train(output_dir=output_dir)
开发者ID:google,项目名称:trax,代码行数:21,代码来源:reformer_e2e_test.py

示例5: test_reformer_noencdecattn_wmt_ende

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_reformer_noencdecattn_wmt_ende(self):
    """Smoke-trains a tiny ReformerNoEncDecAttention for one step.

    JIT is disabled and reformer_noencdecattn_wmt_ende.gin is parsed. Its
    settings are then overridden to use the test data, a single bucket with
    batch size 1, and a very small model.
    """
    trax.fastmath.disable_jit()
    gin.parse_config_file('reformer_noencdecattn_wmt_ende.gin')

    n_layers = 2
    overrides = (
        ('data_streams.data_dir', _TESTDATA),
        # Ignored, but needs to be set.
        ('batcher.batch_size_per_device', 1),
        # A single bucket giving batch size 1.
        ('batcher.buckets', ([513], [1, 1])),
        ('train.steps', 1),
        ('ReformerNoEncDecAttention.n_encoder_layers', n_layers),
        ('ReformerNoEncDecAttention.n_decoder_layers', n_layers),
        ('ReformerNoEncDecAttention.d_ff', 32),
    )
    for binding_key, binding_value in overrides:
      gin.bind_parameter(binding_key, binding_value)

    with self.tmp_dir() as output_dir:
      trainer_lib.train(output_dir=output_dir)
开发者ID:google,项目名称:trax,代码行数:22,代码来源:reformer_e2e_test.py

示例6: setUp

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def setUp(self):
    """Resets the training log dir and binds small train/eval step counts.

    Also builds a record-based input generator over the pose_env test data,
    plus two random input generators (train and eval), all using BATCH_SIZE.
    """
    super(PoseEnvModelsTest, self).setUp()
    test_data = os.path.join(
        FLAGS.test_srcdir, 'tensor2robot',
        'test_data/pose_env_test_data.tfrecord')

    self._train_log_dir = FLAGS.test_tmpdir
    # Start from an empty log directory.
    if tf.io.gfile.exists(self._train_log_dir):
      tf.io.gfile.rmtree(self._train_log_dir)
    gin.bind_parameter('train_eval_model.max_train_steps', 3)
    gin.bind_parameter('train_eval_model.eval_steps', 2)

    self._record_input_generator = (
        default_input_generator.DefaultRecordInputGenerator(
            batch_size=BATCH_SIZE, file_patterns=test_data))
    random_generator_cls = default_input_generator.DefaultRandomInputGenerator
    self._meta_record_input_generator_train = random_generator_cls(
        batch_size=BATCH_SIZE)
    self._meta_record_input_generator_eval = random_generator_cls(
        batch_size=BATCH_SIZE)
开发者ID:google-research,项目名称:tensor2robot,代码行数:24,代码来源:pose_env_models_test.py

示例7: test_run_pose_env_collect

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_run_pose_env_collect(self, demo_policy_cls):
    """Runs the collect/eval loop for a demo policy and checks its output.

    Expects exactly two ``*.tfrecord`` files under
    ``<root_dir>/policy_collect``.
    """
    urdf_root = pose_env.get_pybullet_urdf_root()

    gin.parse_config_file(os.path.join(
        FLAGS.test_srcdir, 'research/pose_env/configs',
        'run_random_collect.gin'))
    # Root dir is keyed by the policy class (presumably so runs for
    # different policy classes do not share output).
    root_dir = os.path.join(
        absltest.get_default_test_tmpdir(), str(demo_policy_cls))
    bindings = [
        ('PoseToyEnv.urdf_root', urdf_root),
        ('collect_eval_loop.root_dir', root_dir),
        ('run_meta_env.num_tasks', 2),
        ('run_meta_env.num_episodes_per_adaptation', 1),
        ('collect_eval_loop.policy_class', demo_policy_cls),
    ]
    for binding_key, binding_value in bindings:
      gin.bind_parameter(binding_key, binding_value)

    continuous_collect_eval.collect_eval_loop()

    pattern = os.path.join(root_dir, 'policy_collect', '*.tfrecord')
    self.assertLen(tf.io.gfile.glob(pattern), 2)
开发者ID:google-research,项目名称:tensor2robot,代码行数:22,代码来源:continuous_collect_eval_test.py

示例8: _runSingleTrainingStep

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    """Trains an SSGAN on CIFAR-10 for a single step.

    Args:
      architecture: Architecture name placed in the GAN ``parameters`` dict.
      loss_fn: Value bound to the gin parameter ``loss.fn``.
      penalty_fn: Value bound to the gin parameter ``penalty.fn``.
    """
    parameters = {"architecture": architecture, "lambda": 1, "z_dim": 128}
    # Rebinding is done inside unlock_config in case the config is locked.
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)

    model_dir = self._get_empty_model_dir()
    tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir, tpu_config=tpu_config)
    gan = SSGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters=parameters,
        model_dir=model_dir,
        g_optimizer_fn=tf.train.AdamOptimizer,
        g_lr=0.0002,
        rotated_batch_size=4)
    estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)
开发者ID:google,项目名称:compare_gan,代码行数:25,代码来源:ssgan_test.py

示例9: _runSingleTrainingStep

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    """Trains a ModularGAN on CIFAR-10 for a single step.

    The GAN is made conditional when ``architecture`` contains "biggan".

    Args:
      architecture: Architecture name placed in the GAN ``parameters`` dict.
      loss_fn: Value bound to the gin parameter ``loss.fn``.
      penalty_fn: Value bound to the gin parameter ``penalty.fn``.
    """
    parameters = {"architecture": architecture, "lambda": 1, "z_dim": 128}
    # Rebinding is done inside unlock_config in case the config is locked.
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)

    is_conditional = "biggan" in architecture
    gan = ModularGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters=parameters,
        model_dir=self.model_dir,
        conditional=is_conditional)
    estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)
开发者ID:google,项目名称:compare_gan,代码行数:19,代码来源:modular_gan_test.py

示例10: testSingleTrainingStepDiscItersWithEma

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
    """Trains one step with generator EMA and checks the EMA checkpoint vars.

    After one training step with ``ModularGAN.g_use_ema`` bound to True, the
    latest checkpoint must contain exactly the expected
    ``ExponentialMovingAverage`` variables for the generator's fc_noise layer.

    Args:
      disc_iters: Number of discriminator iterations, passed to the GAN via
        ``parameters["disc_iters"]``.
    """
    parameters = {
        "architecture": c.DUMMY_ARCH,
        "lambda": 1,
        "z_dim": 128,
        # Previously misspelled "dics_iters". Assuming the GAN reads
        # "disc_iters" (TODO confirm), the parameterized value was silently
        # ignored and every case ran with the same default.
        "disc_iters": disc_iters,
    }
    gin.bind_parameter("ModularGAN.g_use_ema", True)
    dataset = datasets.get_dataset("cifar10")
    gan = ModularGAN(
        dataset=dataset,
        parameters=parameters,
        model_dir=self.model_dir)
    estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)
    # Check for moving average variables in checkpoint.
    checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    ema_vars = sorted([v[0] for v in tf.train.list_variables(checkpoint_path)
                       if v[0].endswith("ExponentialMovingAverage")])
    tf.logging.info("ema_vars=%s", ema_vars)
    expected_ema_vars = sorted([
        "generator/fc_noise/kernel/ExponentialMovingAverage",
        "generator/fc_noise/bias/ExponentialMovingAverage",
    ])
    self.assertAllEqual(ema_vars, expected_ema_vars)
开发者ID:google,项目名称:compare_gan,代码行数:27,代码来源:modular_gan_test.py

示例11: testUnlabledDatasetRaisesError

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def testUnlabledDatasetRaisesError(self):
    """A conditional ModularGAN built on an unlabeled dataset raises ValueError."""
    parameters = {
        "architecture": c.RESNET_CIFAR_ARCH,
        "lambda": 1,
        "z_dim": 120,
    }
    with gin.unlock_config():
      gin.bind_parameter("loss.fn", loss_lib.hinge)

    # celeb_a is used as a dataset without labels.
    unlabeled_dataset = datasets.get_dataset("celeb_a")
    model_dir = self._get_empty_model_dir()
    with self.assertRaises(ValueError):
      # Constructing the GAN is expected to raise.
      ModularGAN(
          dataset=unlabeled_dataset,
          parameters=parameters,
          conditional=True,
          model_dir=model_dir)
开发者ID:google,项目名称:compare_gan,代码行数:20,代码来源:modular_gan_conditional_test.py

示例12: testBatchNormTwoCoresCustom

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def testBatchNormTwoCoresCustom(self):
    """Runs two batch-norm variants over 2 shards and compares to expectations.

    The first batch norm uses the current ``cross_replica_moments``
    configuration. The second is built after binding
    ``cross_replica_moments.parallel`` to False. Both must match
    ``self._expected_outputs``.
    """

    def _two_batch_norms(inputs):
      first_bn = arch_ops.batch_norm(inputs, is_training=True, name="custom_bn")
      # Applies to batch norms built from here on (the binding is not undone).
      gin.bind_parameter("cross_replica_moments.parallel", False)
      seq_bn = arch_ops.batch_norm(inputs, is_training=True,
                                   name="custom_bn_seq")
      return first_bn, seq_bn

    with tf.Graph().as_default():
      inputs_tensor = tf.constant(self._inputs)
      first_op, seq_op = tf.contrib.tpu.batch_parallel(
          _two_batch_norms, [inputs_tensor], num_shards=2)

      with self.session() as sess:
        sess.run(tf.contrib.tpu.initialize_system())
        sess.run(tf.global_variables_initializer())
        first_value, seq_value = sess.run([first_op, seq_op])
        logging.info("custom_bn: %s", first_value)
        logging.info("custom_bn_seq: %s", seq_value)
        for actual in (first_value, seq_value):
          self.assertAllClose(actual, self._expected_outputs)
开发者ID:google,项目名称:compare_gan,代码行数:24,代码来源:arch_ops_tpu_test.py

示例13: testInitializersRandomNormal

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def testInitializersRandomNormal(self):
    """With NORMAL_INIT, every trainable variable uses an allowed initializer.

    Builds a resnet5 generator and discriminator. For each trainable variable,
    the name of the second input of its initializer op must end in one of the
    allowed ``.../Initializer/...:0`` names (random_normal kernels, constant
    biases, zeros beta, ones gamma).
    """
    gin.bind_parameter("weights.initializer", consts.NORMAL_INIT)
    allowed_initializers = [
        "kernel/Initializer/random_normal",
        "bias/Initializer/Const",
    ] * 2 + [
        "beta/Initializer/zeros",
        "gamma/Initializer/ones",
    ]
    valid_op_names = "/({}):0$".format("|".join(allowed_initializers))
    with tf.Graph().as_default():
      noise = tf.zeros((2, 128))
      generator = resnet5.Generator(image_shape=(128, 128, 3))
      fake_images = generator(noise, y=None, is_training=True)
      resnet5.Discriminator()(fake_images, y=None, is_training=True)
      for variable in tf.trainable_variables():
        init_input_name = variable.initializer.inputs[1].name
        self.assertRegex(init_input_name, valid_op_names)
开发者ID:google,项目名称:compare_gan,代码行数:21,代码来源:resnet_init_test.py

示例14: testInitializersTruncatedNormal

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def testInitializersTruncatedNormal(self):
    """With TRUNCATED_INIT, every trainable variable uses an allowed initializer.

    Builds a resnet5 generator and discriminator. For each trainable variable,
    the name of the second input of its initializer op must end in one of the
    allowed ``.../Initializer/...:0`` names (truncated_normal kernels,
    constant biases, zeros beta, ones gamma).
    """
    gin.bind_parameter("weights.initializer", consts.TRUNCATED_INIT)
    allowed_initializers = [
        "kernel/Initializer/truncated_normal",
        "bias/Initializer/Const",
    ] * 2 + [
        "beta/Initializer/zeros",
        "gamma/Initializer/ones",
    ]
    valid_op_names = "/({}):0$".format("|".join(allowed_initializers))
    with tf.Graph().as_default():
      noise = tf.zeros((2, 128))
      generator = resnet5.Generator(image_shape=(128, 128, 3))
      fake_images = generator(noise, y=None, is_training=True)
      resnet5.Discriminator()(fake_images, y=None, is_training=True)
      for variable in tf.trainable_variables():
        init_input_name = variable.initializer.inputs[1].name
        self.assertRegex(init_input_name, valid_op_names)
开发者ID:google,项目名称:compare_gan,代码行数:21,代码来源:resnet_init_test.py

示例15: test_serialized_model_continuous

# 需要导入模块: import gin [as 别名]
# 或者: from gin import bind_parameter [as 别名]
def test_serialized_model_continuous(self):
    """Checks output shapes of SerializedModel on a continuous Box obs space.

    Observations are 2-D Box values serialized with precision 3 into a
    vocabulary of size 32. Actions come from Discrete(2).
    """
    precision = 3
    gin.bind_parameter('BoxSpaceSerializer.precision', precision)

    vocab_size = 32
    obs = np.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]])
    act = np.array([[0, 1, 0]])
    mask = np.array([[1, 1, 1, 0]])

    def _make_serializer(space):
      return space_serializer.create(space, vocab_size=vocab_size)

    obs_serializer = _make_serializer(
        gym.spaces.Box(shape=(2,), low=-2, high=2))
    act_serializer = _make_serializer(gym.spaces.Discrete(2))
    serialized_model = serialization_utils.SerializedModel(
        TestModel(extra_dim=vocab_size),  # pylint: disable=no-value-for-parameter
        observation_serializer=obs_serializer,
        action_serializer=act_serializer,
        significance_decay=0.9,
    )

    batch = (obs, act, obs, mask)
    serialized_model.init(shapes.signature(batch))
    obs_logits, obs_repr, weights = serialized_model(batch)
    # The test expects obs.shape[2] * precision tokens per timestep.
    expected_repr_shape = (1, obs.shape[1], obs.shape[2] * precision)
    self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size,))
    self.assertEqual(obs_repr.shape, expected_repr_shape)
    self.assertEqual(obs_repr.shape, weights.shape)
开发者ID:google,项目名称:trax,代码行数:32,代码来源:serialization_utils_test.py


注：本文中的gin.bind_parameter方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。