当前位置: 首页>>代码示例>>Python>>正文


Python util.fn_args函数代码示例

本文整理汇总了Python中tensorflow.python.estimator.util.fn_args函数的典型用法代码示例。如果您正苦于以下问题：Python fn_args函数的具体用法？Python fn_args怎么用？Python fn_args使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了fn_args函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: add

 def add(self, layer_func):
   """Appends `layer_func` to the sequence of layers/callables.

   Alongside each entry, records whether its signature declares a `training`
   argument.

   Args:
     layer_func: either a `base.Layer` (its `call` method's signature is
       inspected, and the layer is registered through `track_layer`) or any
       other callable (its own signature is inspected).

   Raises:
     TypeError: if `layer_func` is neither a `base.Layer` nor callable.
   """
   is_layer = isinstance(layer_func, base.Layer)
   if not is_layer and not callable(layer_func):
     raise TypeError(
         "Sequential.add() takes only tf.layers.Layer objects or callables; "
         "not '%s' of type '%s'." % (layer_func, type(layer_func)))
   # For layers, the interesting signature is that of `call`, not `__call__`.
   inspected = layer_func.call if is_layer else layer_func
   arg_names = estimator_util.fn_args(inspected)
   if is_layer:
     self.track_layer(layer_func)
   accepts_training = "training" in arg_names
   self._layers_funcs.append((accepts_training, layer_func))
开发者ID:bikong2,项目名称:tensorflow,代码行数:11,代码来源:network.py

示例2: _call_model_fn

  def _call_model_fn(self, features, labels, add_batch_size_in_params=False):
    """Invokes `self._model_fn`, passing only the arguments it declares.

    `features` is always passed. `labels`, `mode`, `config` and `params` are
    passed as keyword arguments only when they appear in the model_fn's
    signature. `config` and `params` are deep copies of `self._config` and
    `self._params`.

    Args:
      features: forwarded to the model_fn as `features`.
      labels: forwarded as `labels`. Must be None when the model_fn has no
        `labels` argument.
      add_batch_size_in_params: when True, the model_fn must declare `params`.
        In TRAIN mode, the per-shard batch size is then stored in
        `params[_BATCH_SIZE_KEY]`.

    Returns:
      Whatever `self._model_fn` returns.

    Raises:
      ValueError: if labels are given but the model_fn does not take them, or
        if `add_batch_size_in_params` is set but the model_fn has no `params`.
    """
    declared = util.fn_args(self._model_fn)

    config = copy.deepcopy(self._config)
    params = copy.deepcopy(self._params)

    kwargs = {}
    if 'labels' in declared:
      kwargs['labels'] = labels
    elif labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in declared:
      kwargs['mode'] = self._mode
    kwargs.update(
        (name, value)
        for name, value in (('config', config), ('params', params))
        if name in declared)

    if add_batch_size_in_params:
      if 'params' not in declared:
        raise ValueError(
            'model_fn ({}) does not include params argument, '
            'required by TPUEstimator to pass batch size as '
            'params[\'batch_size\']'.format(self._model_fn))
      if self._mode == model_fn_lib.ModeKeys.TRAIN:
        # For TPU training, `params` is never `None`. It is the same object as
        # kwargs['params'], so updating it in place reaches the model_fn.
        params[_BATCH_SIZE_KEY] = _per_shard_batch_size(
            self._train_batch_size, config)

    return self._model_fn(features=features, **kwargs)
开发者ID:awisbith,项目名称:tensorflow,代码行数:33,代码来源:tpu_estimator.py

示例3: _call_loss_fn

def _call_loss_fn(loss_fn, labels, logits, features):
  """Calls loss_fn and checks the returned shape.

  `loss_fn` receives `labels` and `logits` as keyword arguments, plus
  `features` only if its signature declares a `features` argument.

  Args:
    loss_fn: The loss function.
    labels: Processed labels Tensor.
    logits: Logits Tensor of shape [batch_size, logits_dimension].
    features: Features dict.
  Returns:
    Loss Tensor with shape [batch_size, 1]. It has a control dependency on an
    assertion, so a loss of any other shape is rejected. The rejection may
    already happen at graph construction if `equal` cannot broadcast the two
    shapes.
  """
  loss_fn_args = util.fn_args(loss_fn)
  kwargs = {}
  if 'features' in loss_fn_args:
    kwargs['features'] = features
  unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs)
  batch_size = array_ops.shape(logits)[0]
  loss_shape = array_ops.shape(unweighted_loss)
  # `equal` broadcasts its operands. Comparing the shape vector alone would
  # therefore accept a rank-1 loss of shape [1] when batch_size == 1, so the
  # rank is required to be 2 as well.
  has_rank_2 = math_ops.equal(array_ops.rank(unweighted_loss), 2)
  has_expected_dims = math_ops.reduce_all(
      math_ops.equal(loss_shape, [batch_size, 1]))
  check_shape_op = control_flow_ops.Assert(
      math_ops.logical_and(has_rank_2, has_expected_dims),
      data=[
          'loss_fn must return Tensor of shape [batch_size, 1]. Given: ',
          loss_shape])
  with ops.control_dependencies([check_shape_op]):
    return array_ops.identity(unweighted_loss)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:25,代码来源:head.py

示例4: export

  def export(self,
             estimator,
             export_path,
             checkpoint_path=None,
             eval_result=None):
    """Exports `estimator` to `export_path` by delegating to `self.export_fn`.

    `self.export_fn` is called as `export_fn(estimator, export_path, ...)`.
    `checkpoint_path` and `eval_result` are forwarded as keyword arguments only
    when its signature declares them.

    Args:
      estimator: the Estimator to export.
      export_path: A string containing a directory where to write the export.
      checkpoint_path: The checkpoint path to export. If None (the default),
        the strategy may locate a checkpoint (e.g. the most recent) by itself.
      eval_result: The output of Estimator.evaluate on this checkpoint. This
        should be set only if checkpoint_path is provided (otherwise it is
        unclear which checkpoint this eval refers to).

    Returns:
      Whatever `self.export_fn` returns. This is documented as the string path
      to the exported directory.

    Raises:
      ValueError: if the export_fn accepts `eval_result` but not
        `checkpoint_path`.
    """
    declared = util.fn_args(self.export_fn)
    takes_checkpoint_path = 'checkpoint_path' in declared
    takes_eval_result = 'eval_result' in declared

    kwargs = {}
    if takes_checkpoint_path:
      kwargs['checkpoint_path'] = checkpoint_path
    if takes_eval_result:
      # eval_result is only meaningful relative to a specific checkpoint.
      if not takes_checkpoint_path:
        raise ValueError('An export_fn accepting eval_result must also accept '
                         'checkpoint_path.')
      kwargs['eval_result'] = eval_result

    return self.export_fn(estimator, export_path, **kwargs)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:33,代码来源:export_strategy.py

示例5: _call_model_fn

  def _call_model_fn(self, features, labels, mode, config):
    """Invokes `self._model_fn` with the subset of arguments it declares.

    `features` is always passed. `labels`, `mode`, `params` (from
    `self.params`) and `config` are passed as keyword arguments only when the
    model_fn's signature names them.

    Args:
      features: features dict.
      labels: labels dict. Must be None when the model_fn takes no `labels`.
      mode: ModeKeys
      config: RunConfig

    Returns:
      The `EstimatorSpec` produced by the model_fn.

    Raises:
      ValueError: if labels are given to a model_fn without a `labels`
        argument, or if the model_fn does not return an `EstimatorSpec`.
    """
    declared = util.fn_args(self._model_fn)

    kwargs = {}
    if 'labels' in declared:
      kwargs['labels'] = labels
    elif labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in declared:
      kwargs['mode'] = mode
    if 'params' in declared:
      # Read lazily: only touch `self.params` when the model_fn wants it.
      kwargs['params'] = self.params
    if 'config' in declared:
      kwargs['config'] = config

    spec = self._model_fn(features=features, **kwargs)
    if isinstance(spec, model_fn_lib.EstimatorSpec):
      return spec
    raise ValueError('model_fn should return an EstimatorSpec.')
开发者ID:ilya-edrenkin,项目名称:tensorflow,代码行数:35,代码来源:estimator.py

示例6: call_logit_fn

def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments. `features` is always passed. `mode`, `params` and
  `config` are passed as keyword arguments only when logit_fn's signature
  declares them. Similar to tf.estimator._call_model_fn().

  Args:
    logit_fn: A logit_fn as defined above.
    features: The features dict.
    mode: TRAIN / EVAL / PREDICT ModeKeys.
    params: The hyperparameter dict.
    config: The configuration object.

  Returns:
    A logit Tensor, the output of logit_fn.

  Raises:
    ValueError: if logit_fn does not return a Tensor.
  """
  logit_fn_args = util.fn_args(logit_fn)
  kwargs = {}
  if 'mode' in logit_fn_args:
    kwargs['mode'] = mode
  if 'params' in logit_fn_args:
    kwargs['params'] = params
  if 'config' in logit_fn_args:
    kwargs['config'] = config
  logit_fn_results = logit_fn(features=features, **kwargs)

  if not isinstance(logit_fn_results, ops.Tensor):
    # Name the function that actually misbehaved (previously reported as
    # "model_fn") and show what it returned, to aid debugging.
    raise ValueError('logit_fn should return a Tensor. '
                     'logit_fn returned: {}'.format(logit_fn_results))

  return logit_fn_results
开发者ID:1000sprites,项目名称:tensorflow,代码行数:33,代码来源:logit_fns.py

示例7: _call_input_fn

  def _call_input_fn(self, input_fn, mode):
    """Invokes `input_fn` on '/cpu:0' with the arguments it declares.

    `mode`, `params` (from `self.params`) and `config` (from `self.config`)
    are passed as keyword arguments only when `input_fn`'s signature names
    them.

    Args:
      input_fn: The input function.
      mode: ModeKeys

    Returns:
      Either features or (features, labels) where features and labels are:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` or dictionary of `Tensor` with labels.

    Raises:
      Any exception raised by `util.fn_args` or by `input_fn` itself
      propagates unchanged.
    """
    # Each entry produces its value lazily, so attributes such as
    # `self.params` are read only when input_fn actually requests them.
    providers = (
        ('mode', lambda: mode),
        ('params', lambda: self.params),
        ('config', lambda: self.config),
    )
    declared = util.fn_args(input_fn)
    kwargs = {name: provide() for name, provide in providers
              if name in declared}
    with ops.device('/cpu:0'):
      return input_fn(**kwargs)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:25,代码来源:estimator.py

示例8: run_step_fn

  def run_step_fn(self, step_fn):
    """Runs `step_fn` against the underlying session and returns its result.

    Args:
      step_fn: A function, or an instance method, whose only parameter (apart
        from `self` for a method) is named `step_context` and receives a
        `StepContext`. Through that object the function can run computations
        with access to the raw session.

        Whatever `step_fn` returns is returned from `run_step_fn`, unless a
        stop is requested; in that case the next `should_stop` call will
        return True.

        Example usage:

        ```python
           with tf.Graph().as_default():
             c = tf.placeholder(dtypes.float32)
             v = tf.add(c, 4.0)
             w = tf.add(c, 0.5)

             def step_fn(step_context):
               a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
               if a <= 4.5:
                 step_context.request_stop()
               return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

             with tf.MonitoredSession() as session:
               while not session.should_stop():
                 a = session.run_step_fn(step_fn)
        ```

        Hooks see a `run_with_hooks()` call made inside `step_fn` the same
        way they see a `MonitoredSession.run` call.

    Returns:
      The value returned by `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`. It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if the argument names of `step_fn` are neither
        `('step_context',)` nor `('self', 'step_context')`.
    """
    allowed_signatures = (('step_context',), ('self', 'step_context'))
    step_fn_arguments = util.fn_args(step_fn)
    if step_fn_arguments not in allowed_signatures:
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either a `_RecoverableSession` or a `_CoordinatedSession`.
    # Passing `run_with_hooks=None` makes `run_with_hooks` resolve to
    # `_CoordinatedSession.run` downstream in both cases. This lets
    # `_PREEMPTION_ERRORS` propagate from within `step_fn` to
    # `_RecoverableSession.run_step_fn`.
    raw_session = self._tf_sess()
    return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks=None)
开发者ID:SylChan,项目名称:tensorflow,代码行数:60,代码来源:monitored_session.py

示例9: test_callable

  def test_callable(self):
    """fn_args on a callable instance reports `self` plus `__call__`'s args."""

    class Adder(object):

      def __call__(self, a, b):
        return a + b

    expected_args = ('self', 'a', 'b')
    self.assertEqual(expected_args, util.fn_args(Adder()))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:8,代码来源:util_test.py

示例10: test_bounded_method

  def test_bounded_method(self):
    """fn_args on a bound method omits the bound `self` argument."""

    class Adder(object):

      def add(self, a, b):
        return a + b

    bound_method = Adder().add
    self.assertEqual(('a', 'b'), util.fn_args(bound_method))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:util_test.py

示例11: _call_input_fn

  def _call_input_fn(self, input_fn, mode):
    """Calls the input function.

    `input_fn` must declare `params`. `config` is passed only if declared. In
    TRAIN mode, `params[_BATCH_SIZE_KEY]` is set to the per-shard batch size,
    or to the full train batch size when
    `config.tpu_config.per_host_input_for_training` is set.

    When not training on TPU, `input_fn` is called once on '/cpu:0'. When
    training on TPU, it is called once per shard, with each call placed on a
    host CPU. With per-host input enabled, it is instead called once, on the
    first host.

    Args:
      input_fn: The input function.
      mode: ModeKeys

    Returns:
      Either features or (features, labels) where features and labels are:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` or dictionary of `Tensor` with labels.
      For per-shard TPU training, the result of
      `_InputsHolder.as_features_and_labels_tuple()`.

    Raises:
      ValueError: if input_fn does not have a `params` argument.
    """
    input_fn_args = util.fn_args(input_fn)
    config = self.config  # a deep copy.
    kwargs = {}
    if 'params' in input_fn_args:
      kwargs['params'] = self.params  # a deep copy.
    else:
      raise ValueError('input_fn ({}) does not include params argument, '
                       'required by TPUEstimator to pass batch size as '
                       'params["batch_size"]'.format(input_fn))
    if 'config' in input_fn_args:
      kwargs['config'] = config

    # Now for TPU training.
    if mode == model_fn_lib.ModeKeys.TRAIN:
      kwargs['params'][_BATCH_SIZE_KEY] = (
          _per_shard_batch_size(self._train_batch_size, config, self._use_tpu)
          if not config.tpu_config.per_host_input_for_training else
          self._train_batch_size)

    if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
      with ops.device('/cpu:0'):
        return input_fn(**kwargs)

    job = _tpu_job(config)

    def placement_function(index):
      # Shard `index` is fed from host task `index // 8`. The 8 is presumably
      # the number of shards (TPU cores) per host; NOTE(review): confirm.
      # Integer division is intended: `index / 8` is a float under Python 3
      # and only produced a valid task id because '%d' truncates.
      if job is None:
        return '/replica:0/task:0/device:CPU:0'
      return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index // 8)

    if not config.tpu_config.per_host_input_for_training:
      num_shards = config.tpu_config.num_shards
      inputs = _InputsHolder(num_shards=num_shards)
      for i in range(num_shards):
        with ops.device(placement_function(i)):
          inputs.append_tuple(input_fn(**kwargs))

      return inputs.as_features_and_labels_tuple()

    # TODO(xiejw): Extend this to multi-host support.
    with ops.device(placement_function(0)):
      return input_fn(**kwargs)
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:57,代码来源:tpu_estimator.py

示例12: _call_input_fn

  def _call_input_fn(self, input_fn, mode):
    """Calls the input function.

    When not training on TPU, this defers to the parent class implementation.
    For TPU training, `input_fn` must declare `params`. It is called once per
    shard with `params[_BATCH_SIZE_KEY]` set to the per-shard batch size, and
    each call is placed on a host CPU.

    Args:
      input_fn: The input function.
      mode: ModeKeys

    Returns:
      When not training on TPU, whatever the parent implementation returns.
      For TPU training, a `(features, labels)` pair in which each element is a
      `_PerShardOutput` wrapping the per-shard values. `labels` is None when
      no shard produced labels.

    Raises:
      ValueError: if input_fn does not have a `params` argument.
    """
    if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
      return super(TpuEstimator, self)._call_input_fn(input_fn, mode)

    input_fn_args = util.fn_args(input_fn)
    config = self.config  # a deep copy.
    kwargs = {}
    if 'params' in input_fn_args:
      kwargs['params'] = self.params  # a deep copy.
    else:
      raise ValueError('input_fn ({}) does not include params argument, '
                       'required by TPUEstimator to pass batch size as '
                       'params["batch_size"]'.format(input_fn))
    if 'config' in input_fn_args:
      kwargs['config'] = config

    # Now for TPU training.
    per_shard_batch_size = _per_shard_batch_size(self._train_batch_size, config)
    kwargs['params'][_BATCH_SIZE_KEY] = per_shard_batch_size

    job = _tpu_job(config)

    def placement_function(index):
      # Shard `index` is fed from host task `index // 8`. The 8 is presumably
      # the number of shards (TPU cores) per host; NOTE(review): confirm.
      # Integer division is intended: `index / 8` is a float under Python 3
      # and only produced a valid task id because '%d' truncates.
      if job is None:
        return '/replica:0/task:0/device:CPU:0'
      return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index // 8)

    features = []
    labels = []
    for i in range(config.tpu_config.num_shards):
      with ops.device(placement_function(i)):
        result = input_fn(**kwargs)
        # input_fn may return either features or (features, labels)
        if isinstance(result, tuple):
          features.append(result[0])
          labels.append(result[1])
        else:
          features.append(result)

    if not labels or all(l is None for l in labels):
      return _PerShardOutput(features), None

    return _PerShardOutput(features), _PerShardOutput(labels)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:57,代码来源:tpu_estimator.py

示例13: _verify_metric_fn_args

def _verify_metric_fn_args(metric_fn):
  """Checks that `metric_fn` declares only arguments in `_VALID_METRIC_FN_ARGS`.

  A `self` argument is ignored when `metric_fn` is a method.

  Args:
    metric_fn: the metric function to validate.

  Raises:
    ValueError: if `metric_fn` declares any other argument. The offending
      names are listed in sorted order, so the message is the same on every
      run regardless of set iteration order.
  """
  args = set(estimator_util.fn_args(metric_fn))
  if tf_inspect.ismethod(metric_fn):
    args.discard('self')
  invalid_args = sorted(args - _VALID_METRIC_FN_ARGS)
  if invalid_args:
    raise ValueError('metric_fn (%s) has following not expected args: %s' %
                     (metric_fn, invalid_args))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:9,代码来源:extenders.py

示例14: _get_standardized_predicate_fn

def _get_standardized_predicate_fn(predicate_fn):
  """Returns a predicate callable as `fn(eval_results, checkpoint_path)`.

  If `predicate_fn` already declares a `checkpoint_path` argument, it is
  returned unchanged. Otherwise it is wrapped so that `checkpoint_path` is
  accepted and ignored, and only `eval_results` is forwarded.
  """
  if "checkpoint_path" in estimator_util.fn_args(predicate_fn):
    return predicate_fn

  # pylint: disable=unused-argument
  def _pred_fn_wrapper(eval_results, checkpoint_path):
    return predicate_fn(eval_results)

  return _pred_fn_wrapper
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:10,代码来源:experiment.py

示例15: test_partial_function

  def test_partial_function(self):
    """fn_args on a functools.partial omits arguments bound by the partial."""
    expected_test_arg = 123

    def fn(a, test_arg):
      # `raise`, not `return`: returning the exception would silently pass.
      if test_arg != expected_test_arg:
        raise ValueError('partial fn does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg=expected_test_arg)

    self.assertEqual(('a',), util.fn_args(wrapped_fn))
    # Calling the partial exercises the guard above.
    self.assertEqual(1, wrapped_fn(1))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:11,代码来源:util_test.py


注：本文中的tensorflow.python.estimator.util.fn_args函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。