当前位置: 首页>>代码示例>>Python>>正文


Python gin.configurable方法代码示例

本文整理汇总了Python中gin.configurable(Gin 配置库中用于把函数或类注册为"可配置对象"的装饰器)的典型用法代码示例。如果您想了解Python gin.configurable的具体用法、使用方式或实际使用示例,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解gin.configurable所在模块gin的其他用法示例。


在下文中一共展示了gin.configurable方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_random_number_generator_and_set_seed

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def get_random_number_generator_and_set_seed(seed=None):
  """Seed every random number generator in use and return a JAX PRNG.

  Python's `random` module is seeded first. `seed` may be None, in which case
  Python falls back to a time/OS-derived seed. When no seed was supplied, an
  integer is then drawn from that freshly seeded generator, because some of
  the other seeding functions require an integer rather than None.

  Args:
    seed: optional integer seed; None selects a nondeterministic seed.

  Returns:
    The JAX random number generator produced by `jax_random.get_prng`.
  """
  random.seed(seed)
  resolved_seed = random.randint(0, 2**31 - 1) if seed is None else seed
  tf.set_random_seed(resolved_seed)
  numpy.random.seed(resolved_seed)
  return jax_random.get_prng(resolved_seed)


# TODO(trax):
# * Make configurable:
#   * loss
#   * metrics
# * Training loop callbacks/hooks/...
# * Save/restore: pickle unsafe. Use np.array.savez + MessagePack?
# * Move metrics to metrics.py
# * Setup namedtuples for interfaces (e.g. lr fun constructors can take a
#   LearningRateInit, metric funs, etc.).
# * Allow disabling eval 
开发者ID:yyht,项目名称:BERT,代码行数:24,代码来源:trax.py

示例2: _step_impl

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def _step_impl(self, state, action):
    """Run one timestep of the environment's dynamics.

    At each timestep, x is flipped from zero to one or one to zero.

    Args:
      state: A `State` object containing the current state.
      action: An action in `action_space`.

    Returns:
      A `State` object containing the updated state.
    """
    del action  # Unused.
    state.x = [1 - x for x in state.x]
    return state


# TODO(): There isn't actually anything to configure in DummyMetric,
# but we mark it as configurable so that we can refer to it on the
# right-hand-side of expressions in gin configurations.  Find out whether
# there's a better way of indicating that than gin.configurable. 
开发者ID:google,项目名称:ml-fairness-gym,代码行数:23,代码来源:test_util.py

示例3: initialize

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def initialize(self):
    """Restore the teacher model's variables from its checkpoint.

    Meant to be called once the graph has been constructed. Does nothing when
    `self.fraction_soft` equals 0.0, since the teacher is then not needed.
    Otherwise every global variable in the "teacher" scope is mapped to the
    checkpoint entry whose name is the variable name with the leading
    "teacher/" prefix and the ":<output index>" suffix stripped.
    """
    if self.fraction_soft != 0.0:
      teacher_vars = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope="teacher")
      prefix_len = len("teacher/")
      assignment_map = {}
      for var in teacher_vars:
        checkpoint_name = var.name[prefix_len:].split(":")[0]
        assignment_map[checkpoint_name] = var
      tf.train.init_from_checkpoint(self.teacher_checkpoint, assignment_map)


# gin-configurable constructors 
开发者ID:tensorflow,项目名称:mesh,代码行数:18,代码来源:transformer.py

示例4: _output_dir_or_default

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def _output_dir_or_default():
  """Return the output directory path, building a default one if needed.

  If `--output_dir` was given, that path is logged and returned with `~`
  expanded. Otherwise the path is `~/trax/<model>_<dataset>_<timestamp>`
  (expanded). `<model>` is the name of the configurable bound to
  `train.model` and `<dataset>` is `data_streams.dataset_name`, or 'random'
  when gin raises ValueError for that query. `<timestamp>` is the current
  local time formatted as YYYYmmdd_HHMM.

  Returns:
    The output directory path as a string.

  Raises:
    Whatever `gin.query_parameter('train.model')` raises; that query is not
    guarded here.
  """
  explicit_dir = FLAGS.output_dir
  if explicit_dir:
    trainer_lib.log('Using --output_dir {}'.format(explicit_dir))
    return os.path.expanduser(explicit_dir)

  # No flag given: derive a default directory under the user's home.
  try:
    dataset_name = gin.query_parameter('data_streams.dataset_name')
  except ValueError:
    dataset_name = 'random'
  model_name = gin.query_parameter('train.model').configurable.name
  timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M')
  output_name = '{model_name}_{dataset_name}_{timestamp}'.format(
      model_name=model_name,
      dataset_name=dataset_name,
      timestamp=timestamp,
  )
  default_dir = os.path.expanduser(os.path.join('~', 'trax', output_name))
  print()
  trainer_lib.log('No --output_dir specified')
  trainer_lib.log('Using default output_dir: {}'.format(default_dir))
  return default_dir


# TODO(afrozm): Share between trainer.py and rl_trainer.py 
开发者ID:google,项目名称:trax,代码行数:28,代码来源:trainer.py

示例5: create_train_op

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def create_train_op(self,
                    loss,
                    optimizer,
                    update_ops=None,
                    train_outputs=None):
    """Build the outer-loop (meta-training) op.

    Only trainable variables whose op name starts with `self._var_scope` are
    optimized; all trainable variables are used when `_var_scope` is None.
    MAMLInnerLoopGradientDescent has its own var_scope for the *inner* loop,
    so to keep a set of variables fixed in both loops you must configure
    var_scope on MAMLModel *and* on MAMLInnerLoopGradientDescent.

    Gradient summaries follow `self._summarize_gradients`, except on TPU,
    where they are always disabled. An info message is logged there if they
    were requested.

    Args:
      loss: The loss computed within model_train_fn.
      optimizer: An instance of `tf.train.Optimizer`.
      update_ops: List of update ops to run alongside the training op.
      train_outputs: (Optional) dict of extra tensors produced by the
        training model. Not used by this implementation.

    Returns:
      train_op: Op for the training step.
    """
    all_trainable = tf.trainable_variables()
    scope_prefix = self._var_scope
    if scope_prefix is None:
      vars_to_train = all_trainable
    else:
      vars_to_train = [
          var for var in all_trainable
          if var.op.name.startswith(scope_prefix)]

    requested = self._summarize_gradients
    if self.is_device_tpu:
      # TPUs do not support summaries, so the user's option is overridden.
      if requested:
        logging.info('We cannot use summarize_gradients on TPUs.')
      summarize_gradients = False
    else:
      summarize_gradients = requested

    return contrib_training.create_train_op(
        loss,
        optimizer,
        variables_to_train=vars_to_train,
        summarize_gradients=summarize_gradients,
        update_ops=update_ops)
开发者ID:google-research,项目名称:tensor2robot,代码行数:43,代码来源:maml_model.py

示例6: loss_fn

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def loss_fn(self, labels, inference_outputs, mode, params=None):
    """Compute the training loss for this model.

    With more than one mixture component, the loss is the negative mean
    log-likelihood of `labels.action` under the mixture distribution built by
    `mdn.get_mixture_distribution` from `inference_outputs['dist_params']`.
    Otherwise it is the mean squared error between `labels.action` and
    `inference_outputs['inference_output']`, scaled by
    `self._outer_loss_multiplier`.

    Args:
      labels: object exposing the target actions as `labels.action`.
      inference_outputs: dict of model outputs; 'dist_params' is read in the
        mixture case and 'inference_output' otherwise.
      mode: unused.
      params: optional hyperparameter dict; unused. An 'is_outer_loss' entry
        does not change the computed loss.

    Returns:
      The loss tensor.
    """
    # The previous `if params and params.get('is_outer_loss', False): pass`
    # had no effect and has been removed.
    del mode, params  # Not used by this loss.
    if self._num_mixture_components > 1:
      # NOTE(review): unlike the MSE branch below, this loss is not scaled by
      # self._outer_loss_multiplier; confirm that this is intentional.
      gm = mdn.get_mixture_distribution(
          inference_outputs['dist_params'], self._num_mixture_components,
          self._action_size,
          self._output_mean if self._normalize_outputs else None)
      return -tf.reduce_mean(gm.log_prob(labels.action))
    return self._outer_loss_multiplier * tf.losses.mean_squared_error(
        labels=labels.action,
        predictions=inference_outputs['inference_output'])
开发者ID:google-research,项目名称:tensor2robot,代码行数:16,代码来源:vrgripper_env_models.py

示例7: parse_gin_defaults_and_flags

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def parse_gin_defaults_and_flags():
  """Load the packaged default gin config, then the user's files and bindings.

  Every prefix in `FLAGS.gin_location_prefix` is first registered as a gin
  config-file search path. The package's default config file is parsed next,
  followed by `FLAGS.gin_file` and `FLAGS.gin_param`. Because the user input
  is parsed after the defaults, it overrides them.
  """
  search_prefixes = FLAGS.gin_location_prefix
  for prefix in search_prefixes:
    gin.add_config_file_search_path(prefix)

  default_config_path = pkg_resources.resource_filename(
      __name__, _DEFAULT_CONFIG_FILE)
  gin.parse_config_file(default_config_path)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)


# TODO(noam): maybe add gin-config to mtf.get_variable so we can delete
#  this stupid VariableDtype class and stop passing it all over creation. 
开发者ID:tensorflow,项目名称:mesh,代码行数:16,代码来源:utils.py

示例8: separate_vocabularies

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def separate_vocabularies(inputs=gin.REQUIRED, targets=gin.REQUIRED):
  """Pair an inputs vocabulary with a targets vocabulary.

  Both arguments default to `gin.REQUIRED`, so a gin config must bind them.
  This lets a config supply different vocabularies for inputs and targets.

  Args:
    inputs: the vocabulary used for inputs.
    targets: the vocabulary used for targets.

  Returns:
    The tuple `(inputs, targets)`.
  """
  vocab_pair = (inputs, targets)
  return vocab_pair


# TODO(katherinelee): Update layout_rules string when noam updates the
# definition in run 
开发者ID:tensorflow,项目名称:mesh,代码行数:9,代码来源:utils.py

示例9: attention_internal

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def attention_internal(self, context, q, m, memory_length, bias):
    """Attend from queries `q` over memory `m` and project the result.

    Args:
      context: object providing `model.model_dim`, `length_dim` and `train`.
      q: query tensor, contracted with `m` over `context.model.model_dim`.
      m: memory tensor, used both to form the attention logits and as the
        values that are combined.
      memory_length: the dimension of `m` that attention is normalized and
        summed over.
      bias: optional tensor added to the logits, or None.

    Returns:
      The result of `self.compute_y(context, u)`, where `u` is the
      attention-weighted combination of `m`.
    """
    scores = mtf.layers.us_einsum(
        [q, m], reduced_dims=[context.model.model_dim])
    if bias is not None:
      scores += bias
    attn_weights = mtf.softmax(scores, memory_length)
    # TODO(noam): make dropout_broadcast_dims configurable
    # The noise shape omits context.length_dim, so the dropout mask is
    # broadcast along that dimension. The rate is 0.0 outside training.
    broadcast_dims = [context.length_dim]
    drop_rate = self.dropout_rate if context.train else 0.0
    attn_weights = mtf.dropout(
        attn_weights,
        rate=drop_rate,
        noise_shape=attn_weights.shape - broadcast_dims)
    attended = mtf.einsum([attn_weights, m], reduced_dims=[memory_length])
    return self.compute_y(context, attended)
开发者ID:tensorflow,项目名称:mesh,代码行数:15,代码来源:transformer_layers.py

示例10: load

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def load(
    environment_name: Text,
    discount: types.Float = 1.0,
    max_episode_steps: Optional[types.Int] = None,
    gym_env_wrappers: Sequence[types.GymEnvWrapper] = (),
    env_wrappers: Sequence[types.PyEnvWrapper] = (),
    spec_dtype_map: Optional[Dict[gym.Space, np.dtype]] = None
) -> py_environment.PyEnvironment:
  """Load a registered environment and wrap it, delegating to `suite_gym.load`.

  All arguments are forwarded positionally and unchanged to `suite_gym.load`.
  By default a TimeLimit wrapper limits episode lengths to the default
  benchmarks defined by the registered environments.

  Args:
    environment_name: Name of the environment to load.
    discount: Discount to use for the environment.
    max_episode_steps: If None, the default step limit from the environment's
      spec is used. No limit is applied if this is 0 or if the spec has no
      timestep_limit.
    gym_env_wrappers: Iterable of wrapper classes applied directly to the gym
      environment.
    env_wrappers: Iterable of wrapper classes applied to the gym-wrapped
      environment.
    spec_dtype_map: A dict mapping gym specs to the tf dtypes used as the
      default dtype for the tensors. To configure a custom mapping through
      Gin, define a gin-configurable function that returns the desired
      mapping and call it in your Gin config file, e.g.
      `suite_gym.load.spec_dtype_map = @get_custom_mapping()`.

  Returns:
    A PyEnvironmentBase instance.
  """
  forwarded_args = (environment_name, discount, max_episode_steps,
                    gym_env_wrappers, env_wrappers, spec_dtype_map)
  return suite_gym.load(*forwarded_args)
开发者ID:tensorflow,项目名称:agents,代码行数:36,代码来源:suite_mujoco.py

示例11: compute_optimal_action_with_classification_environment

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def compute_optimal_action_with_classification_environment(
    observation, environment):
  """Return the environment's optimal action for the SuboptimalArms metric.

  Args:
    observation: ignored; the environment supplies the optimal action itself.
    environment: object exposing `compute_optimal_action()`.

  Returns:
    Whatever `environment.compute_optimal_action()` returns.
  """
  del observation  # Not needed: the environment computes the answer.
  optimal_action = environment.compute_optimal_action()
  return optimal_action
开发者ID:tensorflow,项目名称:agents,代码行数:7,代码来源:environment_utilities.py

示例12: compute_optimal_reward_with_classification_environment

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def compute_optimal_reward_with_classification_environment(
    observation, environment):
  """Return the environment's optimal reward for the Regret metric.

  Args:
    observation: ignored; the environment supplies the optimal reward itself.
    environment: object exposing `compute_optimal_reward()`.

  Returns:
    Whatever `environment.compute_optimal_reward()` returns.
  """
  del observation  # Not needed: the environment computes the answer.
  optimal_reward = environment.compute_optimal_reward()
  return optimal_reward
开发者ID:tensorflow,项目名称:agents,代码行数:7,代码来源:environment_utilities.py

示例13: rate_unsupervised

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def rate_unsupervised(task, value=1e6):
  """Return the mixing rate for the unsupervised co-training task.

  The rate is a constant that can be overridden through gin.

  Args:
    task: the task whose rate is requested; ignored.
    value: the rate to return.

  Returns:
    `value`, unchanged.
  """
  del task  # The rate is independent of the task.
  mixing_rate = value
  return mixing_rate
开发者ID:google-research,项目名称:text-to-text-transfer-transformer,代码行数:6,代码来源:utils.py

示例14: default_input_fn_tmpl

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def default_input_fn_tmpl(
    file_patterns,
    batch_size,
    feature_spec,
    label_spec,
    num_parallel_calls=4,
    is_training=False,
    preprocess_fn=None,
    shuffle_buffer_size=500,
    prefetch_buffer_size=(tf.data.experimental.AUTOTUNE),
    parallel_shards=10):
  """Build a generic, gin-configurable tf.data input pipeline.

  `file_patterns` is either a single pattern or a dict mapping keys to
  patterns. A single pattern is treated as `{'': file_patterns}`. For each
  key the pipeline:
    * lists the matching files (shuffled when training);
    * interleaves their records through the reader chosen by
      `DATA_FORMAT[data_format]`. When training it uses
      `min(parallel_shards, len(filenames))` parallel readers and sloppy
      ordering; otherwise it uses a single reader;
    * shuffles (training only), repeats indefinitely, and batches with
      `drop_remainder=True`.
  The per-key datasets are then zipped into one dataset of dicts and parsed
  with `serialized_to_parsed`. `preprocess_fn` is mapped over the result if
  given, and the result is prefetched unless `prefetch_buffer_size` is None.

  Args:
    file_patterns: a file pattern, or a dict of key -> file pattern.
    batch_size: number of serialized examples per batch.
    feature_spec: feature specification passed to `serialized_to_parsed`.
    label_spec: label specification passed to `serialized_to_parsed`.
    num_parallel_calls: parallelism used by `serialized_to_parsed`.
    is_training: enables file shuffling, parallel sloppy interleave and
      record shuffling.
    preprocess_fn: optional function mapped over the parsed dataset.
    shuffle_buffer_size: buffer size for record shuffling when training.
    prefetch_buffer_size: prefetch buffer size, or None to skip prefetching.
    parallel_shards: maximum interleave cycle length when training. Also used
      as the parallelism of the `preprocess_fn` map.

  Returns:
    The resulting `tf.data.Dataset`.
  """
  if isinstance(file_patterns, dict):
    patterns_by_key = file_patterns
  else:
    patterns_by_key = {'': file_patterns}

  batched_by_key = {}
  for key, patterns in patterns_by_key.items():
    data_format, filenames = get_data_format_and_filenames(patterns)
    filenames_dataset = tf.data.Dataset.list_files(
        filenames, shuffle=is_training)
    cycle_length = min(parallel_shards, len(filenames)) if is_training else 1
    records = filenames_dataset.apply(
        tf.data.experimental.parallel_interleave(
            DATA_FORMAT[data_format],
            cycle_length=cycle_length,
            sloppy=is_training))
    if is_training:
      records = records.shuffle(buffer_size=shuffle_buffer_size)
    records = records.repeat()
    batched_by_key[key] = records.batch(batch_size, drop_remainder=True)

  # Combine the per-key batched serialized datasets into a single dataset of
  # dicts, then parse all of them together.
  dataset = tf.data.Dataset.zip(batched_by_key)
  dataset = serialized_to_parsed(
      dataset, feature_spec, label_spec, num_parallel_calls=num_parallel_calls)
  if preprocess_fn is not None:
    # TODO(psanketi): Consider adding num_parallel calls here.
    dataset = dataset.map(preprocess_fn, num_parallel_calls=parallel_shards)
  if prefetch_buffer_size is not None:
    dataset = dataset.prefetch(prefetch_buffer_size)
  return dataset
开发者ID:google-research,项目名称:tensor2robot,代码行数:52,代码来源:tfdata.py

示例15: make_bitransformer

# 需要导入模块: import gin [as 别名]
# 或者: from gin import configurable [as 别名]
def make_bitransformer(
    input_vocab_size=gin.REQUIRED,
    output_vocab_size=gin.REQUIRED,
    layout=None,
    mesh_shape=None,
    encoder_name="encoder",
    decoder_name="decoder"):
  """Build a Bitransformer from gin-configured encoder and decoder stacks.

  Each side's layer stack comes from `make_layer_stack()`, called inside the
  "encoder" or "decoder" gin config scope respectively. Set the layers in
  your config file like this:
  encoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.DenseReluDense,
  ]
  decoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.EncDecAttention,
    @transformer_layers.DenseReluDense,
  ]

  The encoder is non-autoregressive and has no output vocabulary. The decoder
  is autoregressive and uses `output_vocab_size` for both its input and
  output vocabularies.

  Args:
    input_vocab_size: an integer; the encoder's input vocabulary size.
    output_vocab_size: an integer; the decoder's vocabulary size.
    layout: optional - an input to mtf.convert_to_layout_rules.
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape.
    mesh_shape: optional - an input to mtf.convert_to_shape.
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape.
    encoder_name: optional - a string giving the Unitransformer encoder name.
    decoder_name: optional - a string giving the Unitransformer decoder name.

  Returns:
    a Bitransformer
  """
  mesh_kwargs = dict(layout=layout, mesh_shape=mesh_shape)
  with gin.config_scope("encoder"):
    encoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=input_vocab_size,
        output_vocab_size=None,
        autoregressive=False,
        name=encoder_name,
        **mesh_kwargs)
  with gin.config_scope("decoder"):
    decoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=output_vocab_size,
        output_vocab_size=output_vocab_size,
        autoregressive=True,
        name=decoder_name,
        **mesh_kwargs)
  return Bitransformer(encoder, decoder)
开发者ID:tensorflow,项目名称:mesh,代码行数:53,代码来源:transformer.py


注:本文中的gin.configurable方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。