当前位置: 首页>>代码示例>>Python>>正文


Python function_utils.fn_args函数代码示例

本文整理汇总了 Python 中 tensorflow.python.util.function_utils.fn_args 函数的典型用法代码示例。如果您正苦于以下问题：Python fn_args 函数的具体用法是什么？Python fn_args 怎么用？有哪些 Python fn_args 的使用示例？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


下文共展示了 fn_args 函数的 15 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐出更好的 Python 代码示例。

示例1: add

 def add(self, layer_func):
   """Append a layer object or a plain callable to this container.

   Args:
     layer_func: Either a `base.Layer` instance or any other callable.

   Raises:
     TypeError: If `layer_func` is neither a `base.Layer` nor callable.
   """
   is_layer = isinstance(layer_func, base.Layer)
   if not is_layer and not callable(layer_func):
     raise TypeError(
         "Sequential.add() takes only tf.layers.Layer objects or callables; "
         "not '%s' of type '%s'." % (layer_func, type(layer_func)))

   if is_layer:
     # For layer objects the signature of `call` is inspected, and the layer
     # is additionally registered through `track_layer`.
     arg_names = function_utils.fn_args(layer_func.call)
     self.track_layer(layer_func)
   else:
     arg_names = function_utils.fn_args(layer_func)

   # Each entry records whether the callable declares a `training` argument.
   accepts_training = "training" in arg_names
   self._layers_funcs.append((accepts_training, layer_func))
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:11,代码来源:network.py

示例2: eval_step

    def eval_step():
      """Run one evaluation step.

      Calls the model function in EVAL mode, records the spec's `scaffold_fn`
      and metric function in the enclosing capture objects, and returns the
      loss together with the metric tensors.

      Returns:
        A tuple whose first element is the spec's loss, followed by the metric
        tensors in the order expected by the metric function.
      """
      spec = self._call_model_fn(features, labels,
                                 model_fn_lib.ModeKeys.EVAL, params)

      # Specs that have no `scaffold_fn` attribute capture `None` instead.
      try:
        captured_scaffold_fn.capture(spec.scaffold_fn)
      except AttributeError:
        captured_scaffold_fn.capture(None)

      # Defaults apply when the spec has no (or empty) `eval_metrics`.
      metric_fn, metric_tensors = None, []
      try:
        if spec.eval_metrics:
          metric_fn, metric_tensors = spec.eval_metrics
      except AttributeError:
        pass

      # A dict of tensors is turned into a list ordered like the positional
      # arguments of the metric function.
      if isinstance(metric_tensors, dict):
        ordered_names = function_utils.fn_args(metric_fn)
        metric_tensors = [metric_tensors[arg] for arg in ordered_names]

      captured_eval_metric_fn.capture(metric_fn)

      return tuple([spec.loss] + metric_tensors)
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:29,代码来源:xla.py

示例3: call

    def call(*args):
      """Build a `tf.py_func` op that forwards a call of method `name`.

      The positional `args` are paired with the method's argument names
      (skipping the first one) to look up the tensor specs of the result.

      Args:
        *args: Positional arguments for the proxied method.

      Returns:
        The `tf.Operation` produced by `tf.py_func` if that is what it
        returns; otherwise the results packed into the structure of the specs,
        with static shapes set from the specs.

      Raises:
        ValueError: If no tensor specs are available for `name`.
      """
      method = getattr(self._type, name)
      arg_names = function_utils.fn_args(method)[1:]
      kwargs = dict(zip(arg_names, args))
      specs = self._type._tensor_specs(name, kwargs, self._constructor_kwargs)

      if specs is None:
        raise ValueError(
            'No tensor specifications were provided for: %s' % name)

      dtype_structure = nest.map_structure(lambda spec: spec.dtype, specs)
      flat_dtypes = nest.flatten(dtype_structure)
      shape_structure = nest.map_structure(lambda spec: spec.shape, specs)
      flat_shapes = nest.flatten(shape_structure)

      def py_call(*call_args):
        """Send `call_args` over `self._out` and return the reply, if any."""
        try:
          self._out.send(call_args)
          reply = self._out.recv()
          # An exception object received as reply is re-raised locally.
          if isinstance(reply, Exception):
            raise reply
          if reply is not None:
            return reply
        except IOError:
          raise StopIteration()  # Clean exit.

      py_func_inputs = (name,) + tuple(args)
      result = tf.py_func(py_call, py_func_inputs, flat_dtypes, name=name)

      if isinstance(result, tf.Operation):
        return result

      for tensor, static_shape in zip(result, flat_shapes):
        tensor.set_shape(static_shape)
      return nest.pack_sequence_as(specs, result)
开发者ID:reinforcementdriving,项目名称:scalable_agent,代码行数:35,代码来源:py_process.py

示例4: run_step_fn

  def run_step_fn(self, step_fn):
    """Run ops by delegating to a user-supplied step function.

    Args:
      step_fn: A function or a method taking a single argument of type
        `StepContext`.  Through that argument the function can perform
        computations with access to a raw session.

        Whatever `step_fn` returns is returned from `run_step_fn`, unless a
        stop is requested.  In that case, the next `should_stop` call will
        return True.

        Example usage:

        ```python
           with tf.Graph().as_default():
             c = tf.placeholder(dtypes.float32)
             v = tf.add(c, 4.0)
             w = tf.add(c, 0.5)

             def step_fn(step_context):
               a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
               if a <= 4.5:
                 step_context.request_stop()
               return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

             with tf.MonitoredSession() as session:
               while not session.should_stop():
                 a = session.run_step_fn(step_fn)
        ```

        Hooks interact with the `run_with_hooks()` call inside the `step_fn`
        in the same way as they do with a `MonitoredSession.run` call.

    Returns:
      The value returned by `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`.  It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if the arguments of `step_fn` are not exactly
        `step_context`, optionally preceded by `self` for the case where it
        belongs to an object.
    """
    accepted_signatures = (('step_context',), ('self', 'step_context'))
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments not in accepted_signatures:
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
    # Passing `run_with_hooks=None` makes `run_with_hooks` resolve to
    # `_CoordinatedSession.run` downstream in either case, which lets
    # `_PREEMPTION_ERRORS` propagate from within `step_fn` to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(
        step_fn, self._tf_sess(), run_with_hooks=None)
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:60,代码来源:monitored_session.py

示例5: test_bounded_method

  def test_bounded_method(self):
    """`fn_args` of a bound method is expected to be `('a', 'b')`."""

    class Foo(object):

      def bar(self, a, b):
        return a + b

    bound_method = Foo().bar
    reported_args = function_utils.fn_args(bound_method)
    self.assertEqual(('a', 'b'), reported_args)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:8,代码来源:function_utils_test.py

示例6: test_callable

  def test_callable(self):
    """`fn_args` of a callable instance is expected to be `('a', 'b')`."""

    class Foo(object):

      def __call__(self, a, b):
        return a + b

    callable_instance = Foo()
    reported_args = function_utils.fn_args(callable_instance)
    self.assertEqual(('a', 'b'), reported_args)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:8,代码来源:function_utils_test.py

示例7: __init__

  def __init__(self, type_, *constructor_args, **constructor_kwargs):
    """Record the wrapped type and its constructor arguments by name.

    Positional constructor arguments are paired with the argument names of
    `type_.__init__` (skipping the first one), so that all constructor
    arguments end up in a single keyword dictionary.

    Args:
      type_: The type to wrap.
      *constructor_args: Positional arguments for `type_.__init__`.
      **constructor_kwargs: Keyword arguments for `type_.__init__`; these
        override positional arguments mapped to the same name.
    """
    self._type = type_

    init_arg_names = function_utils.fn_args(type_.__init__)[1:]
    named_args = dict(zip(init_arg_names, constructor_args))
    named_args.update(constructor_kwargs)
    self._constructor_kwargs = named_args

    tf.add_to_collection(PyProcess.COLLECTION, self)

    self._proxy = _TFProxy(type_, self._constructor_kwargs)
开发者ID:reinforcementdriving,项目名称:scalable_agent,代码行数:9,代码来源:py_process.py

示例8: _get_standardized_predicate_fn

def _get_standardized_predicate_fn(predicate_fn):
  """Return a predicate that accepts `(eval_results, checkpoint_path)`.

  Args:
    predicate_fn: A predicate function.  If it does not declare a
      `checkpoint_path` argument, it is wrapped so the extra argument is
      accepted and dropped.

  Returns:
    `predicate_fn` itself when it already declares `checkpoint_path`,
    otherwise a wrapper that calls `predicate_fn(eval_results)`.
  """
  if "checkpoint_path" in function_utils.fn_args(predicate_fn):
    return predicate_fn

  def _pred_fn_wrapper(eval_results, checkpoint_path):  # pylint: disable=unused-argument
    return predicate_fn(eval_results)

  return _pred_fn_wrapper
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:10,代码来源:experiment.py

示例9: _verify_estimator_spec

  def _verify_estimator_spec(self, estimator_spec):
    """Check that `estimator_spec` carries data usable with XLA compilation.

    A truthy `scaffold` attribute only produces a warning, while a truthy
    `eval_metric_ops` attribute is rejected.  In EVAL mode, when the second
    element of `eval_metrics` is a dict, its keys must match the argument
    names of the metric function (the first element) exactly.  Attributes
    missing from the spec are treated as not set.

    Args:
      estimator_spec: The spec to verify.

    Returns:
      The same `estimator_spec`, unmodified.

    Raises:
      ValueError: If `eval_metric_ops` is set, or if the dict of evaluation
        tensors and the metric function's arguments do not match.
    """
    # TODO(ycao): Implement estimator spec verification for other modes.

    try:
      if estimator_spec.scaffold:
        logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
                        '. Please use TPUEstimatorSpec.scaffold_fn instead.')
    except AttributeError:
      pass

    try:
      has_eval_metric_ops = bool(estimator_spec.eval_metric_ops)
    except AttributeError:
      has_eval_metric_ops = False
    if has_eval_metric_ops:
      raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
                       'XLA compilation. Please use '
                       'TPUEstimatorSpec.eval_metrics instead.')

    if estimator_spec.mode != model_fn_lib.ModeKeys.EVAL:
      return estimator_spec

    # Only specs that expose a non-empty `eval_metrics` need further checks.
    eval_metrics = getattr(estimator_spec, 'eval_metrics', None)
    if not eval_metrics:
      return estimator_spec

    metric_fn, metric_tensors = eval_metrics
    metric_fn_args = function_utils.fn_args(metric_fn)
    if not isinstance(metric_tensors, dict):
      return estimator_spec

    missing_tensors = [
        arg for arg in metric_fn_args if arg not in metric_tensors
    ]
    additional_tensors = [
        key for key in metric_tensors if key not in metric_fn_args
    ]

    if missing_tensors:
      raise ValueError('Arguments %s are needed by metric_fn (first '
                       'element of TPUEstimatorSpec.eval_metrics) but '
                       'they are not provided by evaluation tensors '
                       '(second element of TPUEstimatorSpec.eval_metrics)'
                       '.' % missing_tensors)

    if additional_tensors:
      raise ValueError('Arguments %s are provided by evaluation tensors '
                       '(second element of TPUEstimatorSpec.eval_metrics)'
                       ' but they are not needed by metric_fn (first '
                       'element of TPUEstimatorSpec.eval_metrics).' %
                       additional_tensors)

    return estimator_spec
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:55,代码来源:xla.py

示例10: test_partial_function

  def test_partial_function(self):
    """`fn_args` of a partial omits the keyword argument that was bound."""
    expected_test_arg = 123

    def fn(a, test_arg):
      # Raise (rather than return) the error so that a wrongly bound argument
      # can never be mistaken for a regular return value.
      if test_arg != expected_test_arg:
        raise ValueError('partial fn does not work correctly')
      return a

    # Bind with the same constant the helper checks against.
    wrapped_fn = functools.partial(fn, test_arg=expected_test_arg)

    self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:11,代码来源:function_utils_test.py

示例11: _call_metric_fn

def _call_metric_fn(metric_fn, features, labels, predictions, config):
  """Call `metric_fn` with the subset of arguments it declares.

  Args:
    metric_fn: The metric function to call.
    features: Passed as `features` if `metric_fn` declares that argument.
    labels: Passed as `labels` if `metric_fn` declares that argument.
    predictions: Passed as `predictions` if `metric_fn` declares that
      argument.
    config: Passed as `config` if `metric_fn` declares that argument.

  Returns:
    Whatever `metric_fn` returns.
  """
  declared_args = function_utils.fn_args(metric_fn)
  available = (
      ('features', features),
      ('labels', labels),
      ('predictions', predictions),
      ('config', config),
  )
  kwargs = {
      arg_name: value
      for arg_name, value in available
      if arg_name in declared_args
  }
  return metric_fn(**kwargs)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:13,代码来源:extenders.py

示例12: test_double_partial

  def test_double_partial(self):
    """`fn_args` of a nested partial omits keyword arguments of both layers."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(a, test_arg1, test_arg2):
      # Raise (rather than return) the error so that a wrongly bound argument
      # can never be mistaken for a regular return value.
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        raise ValueError('partial does not work correctly')
      return a

    # Bind with the same constants the helper checks against.
    wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
    double_wrapped_fn = functools.partial(
        wrapped_fn, test_arg1=expected_test_arg1)

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:13,代码来源:function_utils_test.py

示例13: test_double_partial_with_positional_args_in_both_layers

  def test_double_partial_with_positional_args_in_both_layers(self):
    """`fn_args` of a nested partial omits positionally bound arguments."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, test_arg2, a):
      # Raise (rather than return) the error so that a wrongly bound argument
      # can never be mistaken for a regular return value.
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        raise ValueError('partial fn does not work correctly')
      return a

    # The first partial binds `test_arg1`, the second one binds `test_arg2`.
    wrapped_fn = functools.partial(fn, expected_test_arg1)
    double_wrapped_fn = functools.partial(wrapped_fn, expected_test_arg2)

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))

    # The remaining argument can be supplied positionally or by keyword.
    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:16,代码来源:function_utils_test.py

示例14: _call_model_fn

  def _call_model_fn(self, features, labels, mode, params):
    """Call `self._model_fn` with the optional arguments it declares.

    `features` is always passed; `labels`, `mode` and `params` are passed only
    if the model function declares an argument of that name.  The result is
    run through `self._verify_estimator_spec` before being returned.

    Raises:
      ValueError: If `labels` is not None but the model function does not
        declare a `labels` argument.
    """
    declared_args = function_utils.fn_args(self._model_fn)

    if 'labels' not in declared_args and labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')

    optional_args = (('labels', labels), ('mode', mode), ('params', params))
    kwargs = {
        arg_name: value
        for arg_name, value in optional_args
        if arg_name in declared_args
    }

    return self._verify_estimator_spec(
        self._model_fn(features=features, **kwargs))
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:18,代码来源:xla.py

示例15: call_logit_fn

def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments.  Similar to tf.estimator._call_model_fn().

  Args:
    logit_fn: A logit_fn as defined above.
    features: The features dict.
    mode: TRAIN / EVAL / PREDICT ModeKeys.
    params: The hyperparameter dict.
    config: The configuration object.

  Returns:
    A logit Tensor, the output of logit_fn.

  Raises:
    ValueError: if logit_fn does not return a Tensor or a dictionary mapping
      strings to Tensors.
  """
  logit_fn_args = function_utils.fn_args(logit_fn)
  # `features` is always passed; the remaining arguments are passed only when
  # logit_fn declares them.
  optional_args = (('mode', mode), ('params', params), ('config', config))
  kwargs = {
      arg_name: value
      for arg_name, value in optional_args
      if arg_name in logit_fn_args
  }
  logit_fn_results = logit_fn(features=features, **kwargs)

  result_is_valid_dictionary = (
      isinstance(logit_fn_results, dict) and
      all(isinstance(k, six.string_types) and isinstance(v, ops.Tensor)
          for k, v in six.iteritems(logit_fn_results)))
  result_is_tensor = isinstance(logit_fn_results, ops.Tensor)

  if not (result_is_valid_dictionary or result_is_tensor):
    # Wrap the result in a 1-tuple: if logit_fn returned a tuple, passing it
    # directly to `%` would unpack it as the argument list and raise a
    # TypeError instead of the documented ValueError.
    raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
                     'strings to Tensors.  logit_fn returned: %s' %
                     (logit_fn_results,))

  return logit_fn_results
开发者ID:AnishShah,项目名称:tensorflow,代码行数:42,代码来源:logit_fns.py


注:本文中的tensorflow.python.util.function_utils.fn_args函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。