当前位置: 首页>>代码示例>>Python>>正文


Python gin.REQUIRED属性代码示例

本文整理汇总了Python中gin.REQUIRED属性的典型用法代码示例。如果您正在为以下问题困扰：Python中gin.REQUIRED属性的具体用法是什么？gin.REQUIRED该怎么用？有哪些使用示例？那么，这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该属性所在模块gin的其他用法示例。


下文共展示了gin.REQUIRED属性的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。

示例1: __init__

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def __init__(self,
               embedding_model_class=gin.REQUIRED,
               reasoning_model_class=gin.REQUIRED,
               optimizer_fn=None):
    """Constructs a TwoStageModel from an embedding stage and a reasoning stage.

    Args:
      embedding_model_class: Either `values`, `onehot`, or a class that has a
        __call__ function that takes as input a two-tuple of
        (batch_size, num_nodes, height, width, num_channels) tensors and
        returns two (batch_size, num_nodes, num_embedding_dims) tensors for
        both the context panels and the answer panels.
      reasoning_model_class: Class that has a __call__ function that takes as
        input a two-tuple of (batch_size, num_nodes, num_embedding_dims)
        tensors and returns the solution in a (batch_size,) tensor.
      optimizer_fn: Function that creates a tf.train optimizer. When None,
        tf.train.AdamOptimizer is used.
    """
    # Default to Adam when no optimizer factory was supplied.
    self.optimizer_fn = (
        tf.train.AdamOptimizer if optimizer_fn is None else optimizer_fn)
    self.embedding_model_class = embedding_model_class
    self.reasoning_model_class = reasoning_model_class
开发者ID:google-research,项目名称:disentanglement_lib,代码行数:24,代码来源:models.py

示例2: product_learning_rate

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def product_learning_rate(step,
                          total_train_steps,
                          factors=gin.REQUIRED,
                          offset=0):
  """Multiply together a sequence of learning-rate factors.

  Each entry of `factors` is either a plain number, used as-is, or a callable
  learning-rate function invoked as f(step, total_train_steps). Callables see
  both arguments shifted down by `offset`.

  Args:
    step: a tf.Scalar
    total_train_steps: a number
    factors: a list of numbers and/or functions
    offset: an optional float subtracted from step and total_train_steps
      before they are passed to callable factors

  Returns:
    a tf.Scalar, the learning rate for the step.
  """
  result = 1.0
  for factor in factors:
    if callable(factor):
      # Evaluate the schedule on the offset-shifted step counters.
      factor = factor(step - offset, total_train_steps - offset)
    result *= factor
  return result
开发者ID:tensorflow,项目名称:mesh,代码行数:27,代码来源:learning_rate_schedules.py

示例3: evaluate_metrics

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def evaluate_metrics():
  """Evaluate the gin-configured metrics for every (algo, task) pair.

  Reads its settings from the module-level `p` object (presumably parsed
  flags; confirm against the caller): gin_file, algos, tasks, data_dir, runs
  and metric_values_dir. For each algo/task it collects the 'train' run
  directories under data_dir/algo/task, writes the metric parameters, and
  writes evaluated metric values under metric_values_dir/algo/task/.
  """
  gin.parse_config_files_and_bindings([p.gin_file], [])

  for algo in p.algos:
    for task in p.tasks:
      run_dirs = eval_metrics.get_run_dirs(
          os.path.join(p.data_dir, algo, task), 'train', p.runs)
      prefix = os.path.join(p.metric_values_dir, algo, task) + '/'
      # gin.REQUIRED at the call site means the metrics value must come from
      # the gin config parsed above.
      evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
      evaluator.write_metric_params(prefix)
      evaluator.evaluate(run_dirs=run_dirs, outfile_prefix=prefix)
开发者ID:google-research,项目名称:rl-reliability-metrics,代码行数:18,代码来源:evaluate_metrics.py

示例4: batcher

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def batcher(data_streams=gin.REQUIRED, variable_shapes=True,
            batch_size_per_device=32, batch_size=None, eval_batch_size=32,
            bucket_length=32, buckets=None,
            buckets_include_inputs_in_length=False,
            batch_shuffle_size=None, max_eval_length=None,
            # TODO(afrozm): Unify padding logic.
            id_to_mask=None, strict_pad_on_len=False):
  """Batcher: create trax Inputs from single-example data-streams.

  Args:
    data_streams: a pair (train_stream, eval_stream), or a callable returning
      such a pair. Each stream is called with no arguments to obtain the
      examples to batch.
    variable_shapes, batch_size_per_device, batch_size, eval_batch_size,
    bucket_length, buckets, buckets_include_inputs_in_length,
    batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len:
      forwarded unchanged to batch_fn.

  Returns:
    An Inputs whose train, eval and train-eval streams are functions of
    n_devices. Each call batches a fresh call of the underlying stream with
    batch_fn: the train stream in training mode, and the eval stream and the
    train stream in non-training mode.
  """
  # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming.
  # For now leaving the arguments as in batch_fn to reduce gin config changes.
  # A callable (e.g. one bound through gin) is invoked to get the pair.
  streams = data_streams() if callable(data_streams) else data_streams
  train_stream, eval_stream = streams

  def _batched(make_stream, training):
    """Return a function of n_devices that batches a fresh make_stream()."""
    def stream_for(n_devices):
      return batch_fn(
          make_stream(), training, n_devices, variable_shapes,
          batch_size_per_device, batch_size, eval_batch_size,
          bucket_length, buckets, buckets_include_inputs_in_length,
          batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len)
    return stream_for

  return Inputs(train_stream=_batched(train_stream, True),
                eval_stream=_batched(eval_stream, False),
                train_eval_stream=_batched(train_stream, False))
开发者ID:google,项目名称:trax,代码行数:36,代码来源:inputs.py

示例5: random_inputs

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def random_inputs(
    input_shape=gin.REQUIRED, input_dtype=jnp.int32, input_range=(0, 255),
    output_shape=gin.REQUIRED, output_dtype=jnp.int32, output_range=(0, 9)):
  """Make random Inputs for debugging.

  Inputs and outputs are each sampled according to their own dtype.
  Floating-point dtypes are drawn uniformly from [low, high). Every other
  dtype is drawn as integers from the closed range [low, high] and then cast.

  Args:
    input_shape: the shape of inputs (including batch dimension). Its first
      entry must be divisible by the number of devices.
    input_dtype: the type of the inputs (int32 by default).
    input_range: the range of inputs (defaults to (0, 255)).
    output_shape: the shape of outputs (including batch dimension).
    output_dtype: the type of the outputs (int32 by default).
    output_range: the range of outputs (defaults to (0, 9)).

  Returns:
    trax.inputs.Inputs
  """
  def make_sampler(dtype):
    """Return a (low, high, shape) -> np.ndarray sampler suited to dtype."""
    if jnp.issubdtype(dtype, jnp.floating):
      return np.random.uniform
    # The deprecated np.random.random_integers sampled the closed range
    # [low, high]. randint's upper bound is exclusive, hence int(high) + 1.
    return lambda low, high, shape: np.random.randint(
        low, int(high) + 1, shape)

  def random_minibatches(n_devices):
    """Generate a stream of random mini-batches.

    Raises:
      ValueError: if the batch dimension input_shape[0] is not divisible by
        n_devices (raised on the first draw from the generator).
    """
    # Each device presumably receives an equal slice of the batch, so the
    # leading (batch) dimension must divide evenly. Checking input_range[0]
    # here would be a no-op for the default range.
    if input_shape[0] % n_devices != 0:
      raise ValueError(
          f'Batch dimension {input_shape[0]} of input_shape {input_shape} '
          f'is not divisible by n_devices={n_devices}.')
    # Pick the sampler per tensor so outputs are not sampled with the
    # inputs' dtype rules.
    sample_input = make_sampler(input_dtype)
    sample_output = make_sampler(output_dtype)
    while True:
      inp = sample_input(input_range[0], input_range[1], input_shape)
      inp = inp.astype(input_dtype)
      out = sample_output(output_range[0], output_range[1], output_shape)
      out = out.astype(output_dtype)
      yield inp, out

  return Inputs(random_minibatches)
开发者ID:google,项目名称:trax,代码行数:33,代码来源:inputs.py

示例6: get_tfds_vocabulary

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def get_tfds_vocabulary(dataset_name=gin.REQUIRED):
  """Build a TFDSVocabulary from a TFDS dataset's target-feature encoder.

  Args:
    dataset_name: name passed to tfds.builder.

  Returns:
    a TFDSVocabulary wrapping the encoder of the feature named by the
    dataset's second supervised key.
  """
  dataset_info = tfds.builder(dataset_name).info
  # Only the target (second supervised key) encoder is used. This assumes
  # either that there are no inputs, or that the inputs and targets share the
  # same vocabulary.
  target_key = dataset_info.supervised_keys[1]
  return TFDSVocabulary(dataset_info.features[target_key].encoder)
开发者ID:tensorflow,项目名称:mesh,代码行数:7,代码来源:vocabulary.py

示例7: separate_vocabularies

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def separate_vocabularies(inputs=gin.REQUIRED, targets=gin.REQUIRED):
  """Gin-configurable helper that pairs an inputs and a targets vocabulary.

  Args:
    inputs: the vocabulary for inputs.
    targets: the vocabulary for targets.

  Returns:
    the tuple (inputs, targets).
  """
  vocabularies = (inputs, targets)
  return vocabularies


# TODO(katherinelee): Update layout_rules string when noam updates the
# definition in run 
开发者ID:tensorflow,项目名称:mesh,代码行数:9,代码来源:utils.py

示例8: get_t2t_vocabulary

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def get_t2t_vocabulary(data_dir=gin.REQUIRED,
                       vocabulary_filename=gin.REQUIRED):
  """Load a T2tVocabulary from a file located inside data_dir.

  Args:
    data_dir: directory that contains the vocabulary file.
    vocabulary_filename: vocabulary file name, relative to data_dir.

  Returns:
    a T2tVocabulary built from os.path.join(data_dir, vocabulary_filename).
  """
  vocabulary_path = os.path.join(data_dir, vocabulary_filename)
  return T2tVocabulary(vocabulary_path)
开发者ID:tensorflow,项目名称:mesh,代码行数:5,代码来源:t2t_vocabulary.py

示例9: constant_learning_rate

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def constant_learning_rate(step, total_train_steps, learning_rate=gin.REQUIRED):
  """Return the same learning rate regardless of the training step.

  DEPRECATED: use constant() or pass a float directly to utils.run.learning_rate

  Args:
    step: a tf.Scalar (ignored)
    total_train_steps: a number (ignored)
    learning_rate: a number or tf.Scalar

  Returns:
    a tf.Scalar, learning_rate cast to tf.float32.
  """
  del step, total_train_steps  # The rate does not depend on progress.
  return tf.cast(learning_rate, dtype=tf.float32)
开发者ID:tensorflow,项目名称:mesh,代码行数:17,代码来源:learning_rate_schedules.py

示例10: __init__

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def __init__(self,
               layers_per_encoder_module=gin.REQUIRED,
               layers_per_decoder_module=gin.REQUIRED,
               encoder_num_modules=gin.REQUIRED,
               decoder_num_modules=gin.REQUIRED,
               dropout_rate=0.0,
               **kwargs):
    """Create a transparent attention EncDec Layer.

    Args:
      layers_per_encoder_module: positive integer, how many layers are in each
        repeated module of the encoder
      layers_per_decoder_module: positive integer, how many layers are in each
        repeated module of the decoder
      encoder_num_modules: positive integer, how many repeated modules the
        encoder has
      decoder_num_modules: positive integer, how many repeated modules the
        decoder has
      dropout_rate: non-negative float (default 0.0), the dropout rate for the
        matrix relating encoder outputs to decoder inputs
      **kwargs: additional constructor params passed to the base class
    """
    super(TransparentEncDecAttention, self).__init__(**kwargs)
    # Encoder module layout.
    self.layers_per_encoder_module = layers_per_encoder_module
    self.encoder_num_modules = encoder_num_modules
    # Decoder module layout.
    self.layers_per_decoder_module = layers_per_decoder_module
    self.decoder_num_modules = decoder_num_modules
    self.dropout_rate = dropout_rate
开发者ID:tensorflow,项目名称:mesh,代码行数:30,代码来源:transformer_layers.py

示例11: make_text_line_dataset

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def make_text_line_dataset(glob=gin.REQUIRED):
  """Text-line dataset drawn from a single file glob.

  Args:
    glob: file pattern, used as the only source with weight 1.0.

  Returns:
    the result of sample_from_text_line_datasets for that single source.
  """
  sources = [(glob, 1.0)]  # One source carrying the full sampling weight.
  return sample_from_text_line_datasets(sources)
开发者ID:tensorflow,项目名称:mesh,代码行数:4,代码来源:dataset.py

示例12: simple_text_line_dataset

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def simple_text_line_dataset(glob=gin.REQUIRED, shuffle_buffer_size=100000):
  """Read the lines of every file matching glob, shuffled.

  Args:
    glob: file pattern expanded with tf.gfile.Glob.
    shuffle_buffer_size: buffer size passed to Dataset.shuffle.

  Returns:
    a shuffled tf.data.TextLineDataset.
  """
  filenames = tf.gfile.Glob(glob)
  lines = tf.data.TextLineDataset(filenames)
  return lines.shuffle(shuffle_buffer_size)
开发者ID:tensorflow,项目名称:mesh,代码行数:5,代码来源:dataset.py

示例13: untokenized_tfds_dataset

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def untokenized_tfds_dataset(dataset_name=gin.REQUIRED,
                             text2self=gin.REQUIRED,
                             tfds_data_dir=gin.REQUIRED,
                             dataset_split=gin.REQUIRED,
                             batch_size=None,
                             sequence_length=gin.REQUIRED,
                             vocabulary=gin.REQUIRED,
                             pack=gin.REQUIRED):
  """Reads a tensorflow_datasets dataset and encodes it with vocabulary.

  Returns a tf.data.Dataset containing single tokenized examples where each
  feature ends in EOS=1. The "train" split is repeated indefinitely and
  shuffled with a buffer of 1000.

  Args:
    dataset_name: a string, the name passed to tfds.load
    text2self: a boolean. If true, run unsupervised LM-style training. If
      false, the dataset must support supervised mode.
    tfds_data_dir: a string, the data_dir passed to tfds.load
    dataset_split: a string
    batch_size: ignored (accepted but unused)
    sequence_length: an integer
    vocabulary: a vocabulary.Vocabulary
    pack: if True, multiple examples emitted by load_internal() are
      concatenated to form one combined example.
  Returns:
    a tf.data.Dataset: the encoded examples after pack_or_pad
  """
  del batch_size
  is_train = dataset_split == "train"
  dataset = tfds.load(dataset_name,
                      split=dataset_split,
                      as_supervised=not text2self,
                      data_dir=tfds_data_dir)
  if is_train:
    dataset = dataset.repeat().shuffle(1000)
  if not text2self:
    # Supervised (input, target) pairs are converted to feature dicts.
    dataset = supervised_to_dict(dataset, text2self)
  encoded = encode_all_features(dataset, vocabulary)
  return pack_or_pad(encoded, sequence_length, pack)
开发者ID:tensorflow,项目名称:mesh,代码行数:40,代码来源:dataset.py

示例14: select_random_chunk

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def select_random_chunk(dataset,
                        max_length=gin.REQUIRED,
                        feature_key='targets',
                        **unused_kwargs):
  """Token-preprocessor to extract one span of at most `max_length` tokens.

  The token sequence is cut into consecutive windows of `max_length` tokens
  (the last one may be shorter), and one window is chosen uniformly at
  random. A sequence no longer than `max_length` is therefore returned whole.
  Examples whose `feature_key` is empty are dropped. The returned examples
  contain only `feature_key`.

  This is generally followed by split_tokens.

  Args:
    dataset: a tf.data.Dataset with dictionaries containing the key feature_key.
    max_length: an integer
    feature_key: a string

  Returns:
    a dataset
  """
  def _random_chunk(features):
    """Keep one randomly chosen max_length-aligned window of the tokens.

    Args:
      features: a dict whose `feature_key` entry is a 1d Tensor
    Returns:
      a dict mapping `feature_key` to the selected 1d Tensor slice
    """
    tokens = features[feature_key]
    n_tokens = tf.size(tokens)
    # Number of max_length-sized windows needed to cover the sequence.
    n_chunks = tf.cast(
        tf.ceil(tf.cast(n_tokens, tf.float32)
                / tf.cast(max_length, tf.float32)),
        tf.int32)
    start = max_length * tf.random_uniform(
        [], maxval=n_chunks, dtype=tf.int32)
    end = tf.minimum(start + max_length, n_tokens)
    return {feature_key: tokens[start:end]}

  def _non_empty(features):
    return tf.not_equal(tf.size(features[feature_key]), 0)

  # Empty sequences would yield zero windows to sample from.
  non_empty = dataset.filter(_non_empty)
  return non_empty.map(_random_chunk, num_parallel_calls=num_parallel_calls())
开发者ID:google-research,项目名称:text-to-text-transfer-transformer,代码行数:42,代码来源:preprocessors.py

示例15: __init__

# 需要导入模块: import gin [as 别名]
# 或者: from gin import REQUIRED [as 别名]
def __init__(self,
               self_supervision="rotation_gan",
               rotated_batch_size=gin.REQUIRED,
               weight_rotation_loss_d=1.0,
               weight_rotation_loss_g=0.2,
               **kwargs):
    """Creates a new Self-Supervised GAN.

    Args:
      self_supervision: One of [rotation_gan, rotation_only, None]. When it is
        rotation_only, no GAN loss is used and the model degenerates to a pure
        rotation model.
      rotated_batch_size: The total number of images per batch for the
        rotation loss. This must be a multiple of (4 * #CORES) since we
        consider 4 rotations of each image on each TPU core. For GPU training
        #CORES is 1.
      weight_rotation_loss_d: Weight for the rotation loss for the
        discriminator on real images.
      weight_rotation_loss_g: Weight for the rotation loss for the generator
        on fake images.
      **kwargs: Additional arguments passed to `ModularGAN` constructor.

    Raises:
      AssertionError: if the base class enabled split discriminator calls.
    """
    super(SSGAN, self).__init__(**kwargs)

    self._self_supervision = self_supervision
    self._rotated_batch_size = rotated_batch_size
    # Rotation-loss weights for the discriminator (d) and generator (g).
    self._weight_rotation_loss_d = weight_rotation_loss_d
    self._weight_rotation_loss_g = weight_rotation_loss_g

    # To save memory, ModularGAN supports feeding real and fake samples
    # separately through the discriminator. SSGAN does not support this, to
    # avoid additional complexity in create_loss().
    assert not self._deprecated_split_disc_calls, (
        "Splitting discriminator calls is not supported in SSGAN.")
开发者ID:google,项目名称:compare_gan,代码行数:35,代码来源:ssgan.py


注:本文中的gin.REQUIRED属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。