

Python gin.unlock_config Method Code Examples

This article collects typical usage examples of the gin.unlock_config method in Python. If you are looking for concrete answers on how gin.unlock_config works and how to call it in practice, the curated code examples below should help. You can also explore further usage examples from the gin package that this method belongs to.


Twelve code examples of the gin.unlock_config method are shown below, sorted by popularity by default.
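Before the project-level examples, here is a minimal, self-contained sketch (not taken from any of the projects below) of the pattern they all share: gin.unlock_config() is a context manager that temporarily lifts the lock on the Gin configuration, so that gin.bind_parameter or gin.parse_config_file can still modify bindings after the config has been finalized. The configurable function my_model and its learning_rate parameter are hypothetical names used purely for illustration.

import gin

@gin.configurable
def my_model(learning_rate=0.001):
  return learning_rate

# Bind an initial value and lock the configuration.
gin.parse_config("my_model.learning_rate = 0.01")
gin.finalize()

# After finalize(), rebinding requires temporarily unlocking the config.
with gin.unlock_config():
  gin.bind_parameter("my_model.learning_rate", 0.0005)

print(my_model())  # -> 0.0005

The examples that follow use the same shape: open the unlock_config() context, then call gin.bind_parameter or gin.parse_config_file inside it.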

Example 1: _runSingleTrainingStep

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    parameters = {
        "architecture": architecture,
        "lambda": 1,
        "z_dim": 128,
    }
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)
    model_dir = self._get_empty_model_dir()
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
    dataset = datasets.get_dataset("cifar10")
    gan = SSGAN(
        dataset=dataset,
        parameters=parameters,
        model_dir=model_dir,
        g_optimizer_fn=tf.train.AdamOptimizer,
        g_lr=0.0002,
        rotated_batch_size=4)
    estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1) 
Developer ID: google, Project: compare_gan, Lines of code: 25, Source file: ssgan_test.py

Example 2: _runSingleTrainingStep

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    parameters = {
        "architecture": architecture,
        "lambda": 1,
        "z_dim": 128,
    }
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)
    dataset = datasets.get_dataset("cifar10")
    gan = ModularGAN(
        dataset=dataset,
        parameters=parameters,
        model_dir=self.model_dir,
        conditional="biggan" in architecture)
    estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1) 
Developer ID: google, Project: compare_gan, Lines of code: 19, Source file: modular_gan_test.py

Example 3: _runSingleTrainingStep

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn,
                             labeled_dataset):
    parameters = {
        "architecture": architecture,
        "lambda": 1,
        "z_dim": 120,
    }
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)
    model_dir = self._get_empty_model_dir()
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
    dataset = datasets.get_dataset("cifar10")
    gan = ModularGAN(
        dataset=dataset,
        parameters=parameters,
        conditional=True,
        model_dir=model_dir)
    estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1) 
Developer ID: google, Project: compare_gan, Lines of code: 24, Source file: modular_gan_conditional_test.py

Example 4: testUnlabledDatasetRaisesError

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def testUnlabledDatasetRaisesError(self):
    parameters = {
        "architecture": c.RESNET_CIFAR_ARCH,
        "lambda": 1,
        "z_dim": 120,
    }
    with gin.unlock_config():
      gin.bind_parameter("loss.fn", loss_lib.hinge)
    # Use dataset without labels.
    dataset = datasets.get_dataset("celeb_a")
    model_dir = self._get_empty_model_dir()
    with self.assertRaises(ValueError):
      gan = ModularGAN(
          dataset=dataset,
          parameters=parameters,
          conditional=True,
          model_dir=model_dir)
      del gan 
Developer ID: google, Project: compare_gan, Lines of code: 20, Source file: modular_gan_conditional_test.py

Example 5: estimator

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def estimator(self, vocabulary, init_checkpoint=None, disable_tpu=False):

    if not self._tpu or disable_tpu:
      with gin.unlock_config():
        gin.bind_parameter("utils.get_variable_dtype.slice_dtype", "float32")
        gin.bind_parameter(
            "utils.get_variable_dtype.activation_dtype", "float32")

    return utils.get_estimator(
        model_type=self._model_type,
        vocabulary=vocabulary,
        layout_rules=self._layout_rules,
        mesh_shape=mtf.Shape([]) if disable_tpu else self._mesh_shape,
        mesh_devices=self._mesh_devices,
        model_dir=self._model_dir,
        batch_size=self.batch_size,
        sequence_length=self._sequence_length,
        autostack=self._autostack,
        learning_rate_schedule=self._learning_rate_schedule,
        keep_checkpoint_max=self._keep_checkpoint_max,
        save_checkpoints_steps=self._save_checkpoints_steps,
        optimizer=self._optimizer,
        predict_fn=self._predict_fn,
        variable_filter=self._variable_filter,
        ensemble_inputs=self._ensemble_inputs,
        use_tpu=None if disable_tpu else self._tpu,
        tpu_job_name=self._tpu_job_name,
        iterations_per_loop=self._iterations_per_loop,
        cluster=self._cluster,
        init_checkpoint=init_checkpoint) 
Developer ID: google-research, Project: text-to-text-transfer-transformer, Lines of code: 32, Source file: mtf_model.py

Example 6: eval

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
           split="validation"):
    """Evaluate the model on the given Mixture or Task.

    Args:
      mixture_or_task_name: str, the name of the Mixture or Task to evaluate on.
        Must be pre-registered in the global `TaskRegistry` or
        `MixtureRegistry`.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        evaluation will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run eval
        continuously waiting for new checkpoints. If -1, get the latest
        checkpoint from the model directory.
      summary_dir: str, path to write TensorBoard events file summaries for
        eval. If None, use model_dir/eval_{split}.
      split: str, the mixture/task split to evaluate on.
    """
    if checkpoint_steps == -1:
      checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
    vocabulary = t5.models.mesh_transformer.get_vocabulary(mixture_or_task_name)
    dataset_fn = functools.partial(
        t5.models.mesh_transformer.mesh_eval_dataset_fn,
        mixture_or_task_name=mixture_or_task_name,
    )
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
    utils.eval_model(self.estimator(vocabulary), vocabulary,
                     self._sequence_length, self.batch_size, split,
                     self._model_dir, dataset_fn, summary_dir, checkpoint_steps) 
Developer ID: google-research, Project: text-to-text-transfer-transformer, Lines of code: 31, Source file: mtf_model.py

Example 7: finetune

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def finetune(self, mixture_or_task_name, finetune_steps, pretrained_model_dir,
               pretrained_checkpoint_step=-1, split="train"):
    """Finetunes a model from an existing checkpoint.

    Args:
      mixture_or_task_name: str, the name of the Mixture or Task to finetune on.
        Must be pre-registered in the global `TaskRegistry` or
        `MixtureRegistry`.
      finetune_steps: int, the number of additional steps to train for.
      pretrained_model_dir: str, directory with pretrained model checkpoints and
        operative config.
      pretrained_checkpoint_step: int, checkpoint to initialize weights from. If
        -1 (default), use the latest checkpoint from the pretrained model
        directory.
      split: str, the mixture/task split to finetune on.
    """
    if pretrained_checkpoint_step == -1:
      checkpoint_step = _get_latest_checkpoint_from_dir(pretrained_model_dir)
    else:
      checkpoint_step = pretrained_checkpoint_step
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(pretrained_model_dir))

    model_ckpt = "model.ckpt-" + str(checkpoint_step)
    self.train(mixture_or_task_name, checkpoint_step + finetune_steps,
               init_checkpoint=os.path.join(pretrained_model_dir, model_ckpt),
               split=split) 
Developer ID: google-research, Project: text-to-text-transfer-transformer, Lines of code: 29, Source file: mtf_model.py

Example 8: predict

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def predict(self, input_file, output_file, checkpoint_steps=-1,
              beam_size=1, temperature=1.0, vocabulary=None):
    """Predicts targets from the given inputs.

    Args:
      input_file: str, path to a text file containing newline-separated input
        prompts to predict from.
      output_file: str, path prefix of output file to write predictions to. Note
        the checkpoint step will be appended to the given filename.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        inference will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run
        inference continuously waiting for new checkpoints. If -1, get the
        latest checkpoint from the model directory.
      beam_size: int, a number >= 1 specifying the number of beams to use for
        beam search.
      temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1).
        0.0 means argmax, 1.0 means sample according to the predicted distribution.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.
    """
    # TODO(sharannarang) : It would be nice to have a function like
    # load_checkpoint that loads the model once and then call decode_from_file
    # multiple times without having to restore the checkpoint weights again.
    # This would be particularly useful in colab demo.

    if checkpoint_steps == -1:
      checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
      gin.bind_parameter("Bitransformer.decode.temperature", temperature)

    if vocabulary is None:
      vocabulary = t5.data.get_default_vocabulary()
    utils.infer_model(
        self.estimator(vocabulary), vocabulary, self._sequence_length,
        self.batch_size, self._model_type, self._model_dir, checkpoint_steps,
        input_file, output_file) 
Developer ID: google-research, Project: text-to-text-transfer-transformer, Lines of code: 42, Source file: mtf_model.py

Example 9: score

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def score(self,
            inputs,
            targets,
            scores_file=None,
            checkpoint_steps=-1,
            vocabulary=None):
    """Computes log-likelihood of target per example in targets.

    Args:
      inputs: optional - a string (filename), or a list of strings (inputs)
      targets: a string (filename), or a list of strings (targets)
      scores_file: str, path to write example scores to, one per line.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        inference will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run
        inference continuously waiting for new checkpoints. If -1, get the
        latest checkpoint from the model directory.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.
    """
    if checkpoint_steps == -1:
      checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      # The following config setting ensures we do scoring instead of inference.
      gin.bind_parameter("tpu_estimator_model_fn.score_in_predict_mode", True)

    if vocabulary is None:
      vocabulary = t5.data.get_default_vocabulary()

    utils.score_from_strings(self.estimator(vocabulary), vocabulary,
                             self._model_type, self.batch_size,
                             self._sequence_length, self._model_dir,
                             checkpoint_steps, inputs, targets, scores_file) 
Developer ID: google-research, Project: text-to-text-transfer-transformer, Lines of code: 37, Source file: mtf_model.py

Example 10: testSingleTrainingStepArchitectures

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def testSingleTrainingStepArchitectures(
      self, use_predictor, project_y=True, self_supervision="none"):
    parameters = {
        "architecture": c.RESNET_BIGGAN_ARCH,
        "lambda": 1,
        "z_dim": 120,
    }
    with gin.unlock_config():
      gin.bind_parameter("ModularGAN.conditional", True)
      gin.bind_parameter("loss.fn", loss_lib.hinge)
      gin.bind_parameter("S3GAN.use_predictor", use_predictor)
      gin.bind_parameter("S3GAN.project_y", project_y)
      gin.bind_parameter("S3GAN.self_supervision", self_supervision)
    # Fake ImageNet dataset by overriding the properties.
    dataset = datasets.get_dataset("imagenet_128")
    model_dir = self._get_empty_model_dir()
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
    gan = S3GAN(
        dataset=dataset,
        parameters=parameters,
        model_dir=model_dir,
        g_optimizer_fn=tf.train.AdamOptimizer,
        g_lr=0.0002,
        rotated_batch_fraction=2)
    estimator = gan.as_estimator(run_config, batch_size=8, use_tpu=False)
    estimator.train(gan.input_fn, steps=1) 
Developer ID: google, Project: compare_gan, Lines of code: 30, Source file: s3gan_test.py

Example 11: parse_cmdline_gin_configurations

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def parse_cmdline_gin_configurations():
  """Parse Gin configurations from all command-line sources."""
  with gin.unlock_config():
    gin.parse_config_files_and_bindings(
        FLAGS.gin_config, FLAGS.gin_bindings, finalize_config=True) 
Developer ID: google-research, Project: meta-dataset, Lines of code: 7, Source file: train.py

Example 12: load_operative_gin_configurations

# Required module: import gin [as alias]
# Or: from gin import unlock_config [as alias]
def load_operative_gin_configurations(operative_config_dir):
  """Load operative Gin configurations from the given directory."""
  gin_log_file = operative_config_path(operative_config_dir)
  with gin.unlock_config():
    gin.parse_config_file(gin_log_file)
  gin.finalize()
  logging.info('Operative Gin configurations loaded from %s.', gin_log_file) 
Developer ID: google-research, Project: meta-dataset, Lines of code: 9, Source file: train.py


Note: The gin.unlock_config method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their respective developers; copyright of the source code belongs to the original authors, and distribution or use should follow each project's License. Do not reproduce without permission.