当前位置: 首页>>代码示例>>Python>>正文


Python gin.tf方法代码示例

本文整理汇总了Python中gin.tf方法的典型用法代码示例。如果您想了解Python gin.tf方法的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块gin的其他用法示例。


下文共展示了gin.tf方法的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。

示例1: get_eval_hooks

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_eval_hooks(self, config, params):
    """Build the list of hooks handed to the eval EstimatorSpec.

    Deprecated: a warning is logged on every call.

    Args:
      config: run config; its `model_dir` and `save_summary_steps` attributes
        are read only when a merged summary op exists.
      params: optional mapping; an 'eval_name' entry overrides the default
        output sub-directory name 'eval'.

    Returns:
      An empty list when `tf.summary.merge_all()` returns None; otherwise a
      one-element list holding a `tf.train.SummarySaverHook` that writes to
      `os.path.join(config.model_dir, eval_name)`.
    """
    logging.warning('This function is deprecated and will be replaced.')
    merged_summaries = tf.summary.merge_all()
    if merged_summaries is None:
        # Nothing to save, so no hook is needed.
        return []
    subdir = 'eval' if params is None else params.get('eval_name', 'eval')
    summary_hook = tf.train.SummarySaverHook(
        output_dir=os.path.join(config.model_dir, subdir),
        save_steps=config.save_summary_steps,
        summary_op=merged_summaries)
    return [summary_hook]

  #############################################################################
  # END DEPRECATED functions which will be removed soon.
  ############################################################################# 
开发者ID:google-research,项目名称:tensor2robot,代码行数:22,代码来源:abstract_model.py

示例2: get_variable_dtype

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_variable_dtype(
    master_dtype=tf.bfloat16,
    slice_dtype=tf.float32,
    activation_dtype=tf.float32):
  """Build the mtf.VariableDType describing the datatypes used for the run.

  Each argument is normalized with `tf.as_dtype`, so it may be a tf.DType
  (as the defaults are) or anything else `tf.as_dtype` accepts, such as a
  dtype name string.

  Args:
    master_dtype: datatype for checkpoints; keep this the same between
      training and eval/inference.
    slice_dtype: datatype for variables in memory; must be tf.float32 for
      training.
    activation_dtype: datatype for activations; tf.bfloat16 uses less memory
      but may cause numerical issues.

  Returns:
    a mtf.VariableDType.
  """
  master, slices, activations = (
      tf.as_dtype(dtype)
      for dtype in (master_dtype, slice_dtype, activation_dtype))
  return mtf.VariableDType(
      master_dtype=master,
      slice_dtype=slices,
      activation_dtype=activations)
开发者ID:tensorflow,项目名称:mesh,代码行数:22,代码来源:utils.py

示例3: clean_decodes

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
  """Replace every token after the first EOS with PAD along `length_axis`.

  The first EOS token itself is kept. Only positions strictly after it are
  overwritten, because the cumulative sum below is exclusive.

  Args:
    ids: an integer Tensor of token ids.
    eos_id: int, EOS id.
    pad_id: int, PAD id.
    length_axis: int, the axis along which sequences run (the last axis by
      default).

  Returns:
    a Tensor of type int of ids.
  """
  is_eos = tf.cast(tf.equal(ids, eos_id), tf.int32)
  # Count of EOS tokens strictly before each position. Zero means no EOS has
  # been passed yet, so the token at that position is kept.
  eos_seen_before = tf.cumsum(is_eos, exclusive=True, axis=length_axis)
  keep = tf.equal(eos_seen_before, 0)
  return tf.where_v2(keep, ids, pad_id)
开发者ID:tensorflow,项目名称:mesh,代码行数:18,代码来源:utils.py

示例4: _get_latest_checkpoint_from_dir

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def _get_latest_checkpoint_from_dir(model_dir):
  """Return the step number of the newest checkpoint in `model_dir`.

  Args:
    model_dir: str, directory holding checkpoint files.

  Returns:
    an int parsed from the text after the last "ckpt-" in the checkpoint path
    reported by tf.train.latest_checkpoint.

  Raises:
    ValueError: if no checkpoint is found in `model_dir`, or (raised by int())
      if the text after "ckpt-" is not an integer.
  """
  latest = tf.train.latest_checkpoint(model_dir)
  if latest is not None:
    # Greedy ".*" drops everything up to and including the last "ckpt-".
    step_suffix = re.sub(".*ckpt-", "", latest)
    return int(step_suffix)
  raise ValueError("No checkpoints found in model directory: %s" % model_dir)
开发者ID:google-research,项目名称:text-to-text-transfer-transformer,代码行数:18,代码来源:mtf_model.py

示例5: default_init_from_checkpoint_fn

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def default_init_from_checkpoint_fn(checkpoint, allow_partial_restore=False):
  """Initialize variables from a TF checkpoint, matching by op name.

  Each variable returned by contrib_framework.get_variables() is mapped to the
  checkpoint tensor with the same op name. The resulting mapping is passed to
  tf.train.init_from_checkpoint.

  Args:
    checkpoint: str, path of the TF checkpoint.
    allow_partial_restore: if True, variables missing from the checkpoint are
      logged and skipped; otherwise a missing variable is an error.

  Raises:
    ValueError: if a variable is missing from the checkpoint and partial
      restore is not explicitly enabled.
  """
  logging.info('Initializing model weights from %s', checkpoint)
  ckpt_reader = tf.train.load_checkpoint(checkpoint)
  restore_map = {}
  for variable in contrib_framework.get_variables():
    var_name = variable.op.name
    if ckpt_reader.has_tensor(var_name):
      logging.info('Loading variable %s from checkpoint', var_name)
      restore_map[var_name] = variable
      continue
    if not allow_partial_restore:
      raise ValueError('Attempting to restore variable {} which is '
                       'not in the checkpoint.'.format(var_name))
    logging.warning('Variable %s is not in the checkpoint, skipping.',
                    var_name)

  tf.train.init_from_checkpoint(checkpoint, restore_map)
开发者ID:google-research,项目名称:tensor2robot,代码行数:32,代码来源:abstract_model.py

示例6: get_feature_specification

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_feature_specification(self, mode):
    """Describe the features required by model_fn / model_inference_fn.

    The model_fn may consume extra features for debugging or development. The
    create_export_outputs_fn, however, requires only the features listed here.
    Only this subset is used to generate the automatic tf.Example extractors
    and numpy placeholders for the serving models.

    This body defines no specification and returns None. It appears to be an
    interface hook meant to be overridden (TODO: confirm whether the original
    is decorated as abstract).

    Args:
      mode: the mode for which the feature specification is requested.

    Returns:
      None in this implementation.
    """
    del mode  # Not used by this base implementation.
    return None
开发者ID:google-research,项目名称:tensor2robot,代码行数:15,代码来源:abstract_model.py

示例7: get_run_config

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_run_config(self):
    """Build the Estimator RunConfig for this model.

    Returns:
      The object produced by gin_configurable_run_config_cls when called with
      this model's session config (documented upstream as a
      tf.estimator.RunConfig).
    """
    session_config = self.get_session_config()
    return gin_configurable_run_config_cls(session_config=session_config)
开发者ID:google-research,项目名称:tensor2robot,代码行数:10,代码来源:abstract_model.py

示例8: get_session_config

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_session_config(self):
    """Return the session config for the Estimator model.

    This implementation always yields None, which tells tf.Estimator to use
    its default session config. It is not used in TPU jobs at the moment.

    Returns:
      None here; overriding implementations may return a session config.
    """
    default_config = None  # Defer to tf.Estimator's own default.
    return default_config
开发者ID:google-research,项目名称:tensor2robot,代码行数:12,代码来源:abstract_model.py

示例9: get_feature_specification

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_feature_specification(self, mode):
    """Return the wrapped model's feature spec with float32 replaced by bfloat16.

    Args:
      mode: forwarded to the wrapped model's get_feature_specification.

    Returns:
      The result of tensorspec_utils.replace_dtype on
      self._t2r_model.get_feature_specification(mode), with tf.float32 mapped
      to tf.bfloat16.
    """
    wrapped_spec = self._t2r_model.get_feature_specification(mode)
    return tensorspec_utils.replace_dtype(
        wrapped_spec, from_dtype=tf.float32, to_dtype=tf.bfloat16)
开发者ID:google-research,项目名称:tensor2robot,代码行数:9,代码来源:tpu_model_wrapper.py

示例10: get_label_specification

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_label_specification(self, mode):
    """Return the wrapped model's label spec with float32 replaced by bfloat16.

    Args:
      mode: forwarded to the wrapped model's get_label_specification.

    Returns:
      The result of tensorspec_utils.replace_dtype on
      self._t2r_model.get_label_specification(mode), with tf.float32 mapped
      to tf.bfloat16.
    """
    wrapped_spec = self._t2r_model.get_label_specification(mode)
    return tensorspec_utils.replace_dtype(
        wrapped_spec, from_dtype=tf.float32, to_dtype=tf.bfloat16)
开发者ID:google-research,项目名称:tensor2robot,代码行数:8,代码来源:tpu_model_wrapper.py

示例11: metric_sum

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def metric_sum(values, name=None, **kwargs):
  """Streaming sum: accumulate the float32 sum of `values` across updates.

  Args:
    values: Tensor whose elements are cast to float32 and summed.
    name: optional variable-scope name; "metric_sum" is the default scope.
    **kwargs: ignored. Presumably accepted so this can be used where a
      metric function receives extra keyword arguments (confirm with callers).

  Returns:
    (accum, update_op): a scalar float32 local variable starting at 0 that
    holds the running total, and the op that adds the sum of `values` to it.
  """
  del kwargs  # Unused.
  with tf.variable_scope(name, "metric_sum", [values]):
    running_total = tf.get_variable(
        "accum",
        shape=[],
        dtype=tf.float32,
        trainable=False,
        collections=[tf.GraphKeys.LOCAL_VARIABLES],
        initializer=tf.zeros_initializer())
    batch_total = tf.reduce_sum(tf.cast(values, tf.float32))
    update_op = tf.assign_add(running_total, batch_total)
  return running_total, update_op
开发者ID:tensorflow,项目名称:mesh,代码行数:11,代码来源:utils.py

示例12: metric_max

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def metric_max(values, name=None, **kwargs):
  """Streaming max: track the largest float32 element of `values` seen so far.

  Args:
    values: Tensor whose elements are cast to float32 before the max.
    name: optional variable-scope name; "metric_max" is the default scope.
    **kwargs: ignored. Presumably accepted so this can be used where a
      metric function receives extra keyword arguments (confirm with callers).

  Returns:
    (accum, update_op): a scalar float32 local variable holding the running
    maximum (-inf before any update), and the op that folds the max of
    `values` into it.
  """
  del kwargs  # Unused.
  with tf.variable_scope(name, "metric_max", [values]):
    # Start from -inf, the identity element of max, rather than 0. With a zero
    # start, a stream of all-negative values (e.g. log-probabilities) would
    # wrongly report 0 as its maximum.
    accum = tf.get_variable(
        "accum", shape=[], dtype=tf.float32, trainable=False,
        collections=[tf.GraphKeys.LOCAL_VARIABLES],
        initializer=tf.constant_initializer(float("-inf")))
    update_op = tf.assign(
        accum, tf.maximum(accum, tf.reduce_max(tf.cast(values, tf.float32))))
    return accum, update_op
开发者ID:tensorflow,项目名称:mesh,代码行数:12,代码来源:utils.py

示例13: get_inputs_from_file

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_inputs_from_file(input_filename, ignore_comments=False):
  """Read lines from a file, stripping trailing whitespace from each line.

  Note that `str.rstrip()` removes all trailing whitespace, not just the
  newline character.

  Args:
    input_filename: str, path opened with tf.io.gfile.GFile.
    ignore_comments: bool, if True drop lines that start with "#".

  Returns:
    list of str, one entry per line. A final empty line is dropped, and an
    empty file yields an empty list.
  """
  # Use a context manager so the file handle is closed deterministically.
  with tf.io.gfile.GFile(input_filename) as input_file:
    inputs = [line.rstrip() for line in input_file]

  # Strip the last empty line. The `inputs` guard avoids an IndexError on an
  # empty file, which has no last line.
  if inputs and not inputs[-1]:
    inputs.pop()

  if ignore_comments:
    inputs = [line for line in inputs if not line.startswith("#")]

  return inputs
开发者ID:tensorflow,项目名称:mesh,代码行数:14,代码来源:utils.py

示例14: decode

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def decode(estimator,
           input_fn,
           vocabulary,
           checkpoint_path=None):
  """Run estimator.predict over `input_fn` and detokenize the outputs.

  Args:
    estimator: a TPUEstimator
    input_fn: function that returns a tf.Dataset
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple
    checkpoint_path: an optional string

  Returns:
    list of decoded output strings, one per prediction, in prediction order.
  """
  def _maybe_detokenize(value, vocab):
    # Bytes values pass through unchanged. Anything else is treated as a
    # sequence of token ids and decoded with `vocab`.
    if not isinstance(value, six.binary_type):
      value = vocab.decode([int(token_id) for token_id in value])
    return value

  predictions = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
  decodes = []
  for index, prediction in enumerate(predictions):
    input_string = _maybe_detokenize(
        prediction["inputs"], inputs_vocabulary(vocabulary))
    output_string = _maybe_detokenize(
        prediction["outputs"], targets_vocabulary(vocabulary))
    decodes.append(output_string)
    # index & (index - 1) is zero exactly for 0 and for powers of two, so
    # this logs the first example and then every power-of-2 index.
    if (index & (index - 1)) == 0:
      tf.logging.info("decoded {}: {}".format(index, input_string))
      tf.logging.info("            -> {}".format(output_string))
  return decodes
开发者ID:tensorflow,项目名称:mesh,代码行数:38,代码来源:utils.py

示例15: write_lines_to_file

# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def write_lines_to_file(lines, filename):
  """Write `lines` to `filename`, one per line, replacing any existing file.

  Each element is rendered with "{}\\n".format, so non-str items are written
  in their formatted (str-like) form.

  Args:
    lines: list of str, lines to write out.
    filename: str, path to filename.
  """
  gfile = tf.io.gfile
  if gfile.exists(filename):
    # Remove any existing file before writing the new contents.
    gfile.remove(filename)
  with gfile.GFile(filename, "w") as out:
    for text in lines:
      out.write("{}\n".format(text))
开发者ID:tensorflow,项目名称:mesh,代码行数:14,代码来源:utils.py


注:本文中的gin.tf方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。