本文整理汇总了Python中tensorflow.python.util.function_utils.fn_args函数的典型用法代码示例。如果您正苦于以下问题:Python fn_args函数的具体用法?Python fn_args怎么用?Python fn_args使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了fn_args函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: add
def add(self, layer_func):
    """Append a layer object or a plain callable to the stored sequence.

    Args:
      layer_func: A `base.Layer` instance or any other callable.

    Raises:
      TypeError: If `layer_func` is neither a `base.Layer` nor callable.
    """
    is_layer = isinstance(layer_func, base.Layer)
    if not is_layer and not callable(layer_func):
        raise TypeError(
            "Sequential.add() takes only tf.layers.Layer objects or callables; "
            "not '%s' of type '%s'." % (layer_func, type(layer_func)))

    if is_layer:
        # Layers are inspected through their `call` method and are also
        # handed to `track_layer`.
        arg_names = function_utils.fn_args(layer_func.call)
        self.track_layer(layer_func)
    else:
        arg_names = function_utils.fn_args(layer_func)

    # Record, next to the callable, whether it accepts a `training` argument.
    accepts_training = "training" in arg_names
    self._layers_funcs.append((accepts_training, layer_func))
示例2: eval_step
def eval_step():
    """Run one evaluation step and return the loss plus metric tensors.

    Captures the spec's `scaffold_fn` (or None) and its metric function as
    side effects.

    Returns:
      A tuple whose first element is `estimator_spec.loss`, followed by the
      evaluation metric tensors.
    """
    estimator_spec = self._call_model_fn(
        features, labels, model_fn_lib.ModeKeys.EVAL, params)

    # Specs without a `scaffold_fn` attribute are captured as None.
    try:
        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
    except AttributeError:
        captured_scaffold_fn.capture(None)

    metric_fn, metric_tensors = None, []
    try:
        if estimator_spec.eval_metrics:
            metric_fn, metric_tensors = estimator_spec.eval_metrics
    except AttributeError:
        # The spec has no `eval_metrics`; keep the empty defaults.
        pass

    # If a dictionary is provided, we need to convert it into a list sorted
    # according to order of eval_metric_fn positional arguments.
    if isinstance(metric_tensors, dict):
        ordered_names = function_utils.fn_args(metric_fn)
        metric_tensors = [metric_tensors[arg_name] for arg_name in ordered_names]

    captured_eval_metric_fn.capture(metric_fn)
    return tuple([estimator_spec.loss] + metric_tensors)
示例3: call
def call(*args):
    """Create a `tf.py_func` op that forwards `args` through `self._out`.

    Positional `args` are matched to the parameter names of the attribute
    `name` on `self._type`, and `_tensor_specs` supplies the output specs.

    Returns:
      The `tf.Operation` itself if `tf.py_func` produced one; otherwise the
      outputs, with shapes set, packed into the structure of the specs.

    Raises:
      ValueError: If `_tensor_specs` returns None.
    """
    target = getattr(self._type, name)
    # Drop the first parameter name (presumably `self`) before pairing.
    param_names = function_utils.fn_args(target)[1:]
    kwargs = dict(zip(param_names, args))

    specs = self._type._tensor_specs(name, kwargs, self._constructor_kwargs)
    if specs is None:
        raise ValueError(
            'No tensor specifications were provided for: %s' % name)

    dtype_structure = nest.map_structure(lambda spec: spec.dtype, specs)
    flat_dtypes = nest.flatten(dtype_structure)
    shape_structure = nest.map_structure(lambda spec: spec.shape, specs)
    flat_shapes = nest.flatten(shape_structure)

    def py_call(*call_args):
        """Send the arguments over `self._out` and return the reply."""
        try:
            self._out.send(call_args)
            reply = self._out.recv()
            # An exception object received as the reply is re-raised here.
            if isinstance(reply, Exception):
                raise reply
        except IOError:
            raise StopIteration()  # Clean exit.
        else:
            return reply

    outputs = tf.py_func(py_call, (name,) + tuple(args), flat_dtypes,
                         name=name)
    if isinstance(outputs, tf.Operation):
        return outputs

    for tensor, static_shape in zip(outputs, flat_shapes):
        tensor.set_shape(static_shape)
    return nest.pack_sequence_as(specs, outputs)
示例4: run_step_fn
def run_step_fn(self, step_fn):
    """Run ops using a step function.

    Args:
      step_fn: A function or a method with a single argument of type
        `StepContext`. The function may use methods of the argument to
        perform computations with access to a raw session.

        The returned value of the `step_fn` will be returned from
        `run_step_fn`, unless a stop is requested. In that case, the next
        `should_stop` call will return True.

        Example usage:

        ```python
        with tf.Graph().as_default():
          c = tf.placeholder(dtypes.float32)
          v = tf.add(c, 4.0)
          w = tf.add(c, 0.5)

          def step_fn(step_context):
            a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
            if a <= 4.5:
              step_context.request_stop()
            return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

          with tf.MonitoredSession() as session:
            while not session.should_stop():
              a = session.run_step_fn(step_fn)
        ```

        Hooks interact with the `run_with_hooks()` call inside the `step_fn`
        as they do with a `MonitoredSession.run` call.

    Returns:
      Returns the returned value of `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`. It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn` doesn't have a single argument called
        `step_context`. It may also optionally have `self` for cases when it
        belongs to an object.
    """
    # The only accepted signatures: a plain function, or an instance method.
    accepted_signatures = (('step_context',), ('self', 'step_context'))
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments not in accepted_signatures:
        raise ValueError(
            '`step_fn` may either have one `step_context` argument, or'
            ' `self` and `step_context` arguments if it\'s an instance'
            ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
    # Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
    # `_CoordinatedSession.run` downstream in either case. This allows
    # `_PREEMPTION_ERRORS` to propagate from within `step_fn` to
    # `_RecoverableSession.run_step_fn`.
    raw_session = self._tf_sess()
    return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks=None)
示例5: test_bounded_method
def test_bounded_method(self):
    """`fn_args` on a bound method reports its arguments without `self`."""

    class Foo(object):

        def bar(self, a, b):
            return a + b

    bound_method = Foo().bar
    reported_args = function_utils.fn_args(bound_method)
    self.assertEqual(('a', 'b'), reported_args)
示例6: test_callable
def test_callable(self):
    """`fn_args` on a callable object reports `__call__` arguments sans `self`."""

    class Foo(object):

        def __call__(self, a, b):
            return a + b

    callable_instance = Foo()
    reported_args = function_utils.fn_args(callable_instance)
    self.assertEqual(('a', 'b'), reported_args)
示例7: __init__
def __init__(self, type_, *constructor_args, **constructor_kwargs):
    """Record the wrapped type and its constructor arguments as keywords.

    Args:
      type_: The class whose `__init__` signature names the positional
        arguments.
      *constructor_args: Positional constructor arguments; each is paired
        with the matching parameter name of `type_.__init__`.
      **constructor_kwargs: Keyword constructor arguments; these override
        same-named entries derived from `constructor_args`.
    """
    self._type = type_

    # Skip the first parameter name (presumably `self`) of `type_.__init__`.
    init_param_names = function_utils.fn_args(type_.__init__)[1:]
    merged_kwargs = dict(zip(init_param_names, constructor_args))
    merged_kwargs.update(constructor_kwargs)
    self._constructor_kwargs = merged_kwargs

    tf.add_to_collection(PyProcess.COLLECTION, self)
    self._proxy = _TFProxy(type_, self._constructor_kwargs)
示例8: _get_standardized_predicate_fn
def _get_standardized_predicate_fn(predicate_fn):
    """Return a predicate that accepts `(eval_results, checkpoint_path)`.

    Args:
      predicate_fn: A callable taking `eval_results` and, optionally, an
        argument named `checkpoint_path`.

    Returns:
      `predicate_fn` itself if it already declares `checkpoint_path`;
      otherwise a wrapper that accepts and ignores `checkpoint_path`.
    """
    if "checkpoint_path" in function_utils.fn_args(predicate_fn):
        return predicate_fn

    # pylint: disable=unused-argument
    def _pred_fn_wrapper(eval_results, checkpoint_path):
        return predicate_fn(eval_results)

    return _pred_fn_wrapper
示例9: _verify_estimator_spec
def _verify_estimator_spec(self, estimator_spec):
    """Verifies estimator spec contains correct data.

    Args:
      estimator_spec: The spec object to check.

    Returns:
      The same `estimator_spec`, unchanged.

    Raises:
      ValueError: If `eval_metric_ops` is set, or if in EVAL mode the
        dictionary of evaluation tensors does not match the metric
        function's argument names.
    """
    # TODO(ycao): Implement estimator spec verification for other modes.
    try:
        if estimator_spec.scaffold:
            logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
                            '. Please use TPUEstimatorSpec.scaffold_fn instead.')
    except AttributeError:
        # The spec has no `scaffold` attribute; nothing to warn about.
        pass

    try:
        if estimator_spec.eval_metric_ops:
            raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
                             'XLA compilation. Please use '
                             'TPUEstimatorSpec.eval_metrics instead.')
    except AttributeError:
        # The spec has no `eval_metric_ops` attribute.
        pass

    # The remaining checks only apply to evaluation mode.
    if estimator_spec.mode != model_fn_lib.ModeKeys.EVAL:
        return estimator_spec

    # If estimator_spec is of type TPUEstimatorSpec and contains eval_metrics,
    # check that eval_metrics contains eval_metric_fn and
    # eval_metric_fn_tensors with matching arguments.
    try:
        eval_metrics = estimator_spec.eval_metrics
    except AttributeError:
        eval_metrics = None
    if not eval_metrics:
        return estimator_spec

    metric_fn, metric_tensors = eval_metrics
    metric_fn_args = function_utils.fn_args(metric_fn)
    # Only a dictionary of tensors can be matched against argument names.
    if not isinstance(metric_tensors, dict):
        return estimator_spec

    missing_tensors = [
        arg for arg in metric_fn_args if arg not in metric_tensors
    ]
    additional_tensors = [
        key for key in metric_tensors if key not in metric_fn_args
    ]
    if missing_tensors:
        raise ValueError('Arguments %s are needed by metric_fn (first '
                         'element of TPUEstimatorSpec.eval_metrics) but '
                         'they are not provided by evaluation tensors '
                         '(second element of TPUEstimatorSpec.eval_metrics)'
                         '.' % missing_tensors)
    if additional_tensors:
        raise ValueError('Arguments %s are provided by evaluation tensors '
                         '(second element of TPUEstimatorSpec.eval_metrics)'
                         ' but they are not needed by metric_fn (first '
                         'element of TPUEstimatorSpec.eval_metrics).' %
                         additional_tensors)
    return estimator_spec
示例10: test_partial_function
def test_partial_function(self):
    """`fn_args` omits arguments already bound by keyword via a partial."""
    expected_test_arg = 123

    def fn(a, test_arg):
        if test_arg != expected_test_arg:
            # Raise rather than return the exception, so a mis-bound
            # argument fails loudly if `fn` is ever invoked.
            raise ValueError('partial fn does not work correctly')
        return a

    wrapped_fn = functools.partial(fn, test_arg=expected_test_arg)
    self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
示例11: _call_metric_fn
def _call_metric_fn(metric_fn, features, labels, predictions, config):
    """Calls metric fn with proper arguments.

    Only the arguments whose names appear in `metric_fn`'s signature are
    passed, each as a keyword argument.

    Args:
      metric_fn: The metric function to call.
      features: Value passed as `features` if `metric_fn` accepts it.
      labels: Value passed as `labels` if `metric_fn` accepts it.
      predictions: Value passed as `predictions` if `metric_fn` accepts it.
      config: Value passed as `config` if `metric_fn` accepts it.

    Returns:
      Whatever `metric_fn` returns.
    """
    accepted_names = function_utils.fn_args(metric_fn)
    candidates = (
        ('features', features),
        ('labels', labels),
        ('predictions', predictions),
        ('config', config),
    )
    kwargs = {
        arg_name: value
        for arg_name, value in candidates
        if arg_name in accepted_names
    }
    return metric_fn(**kwargs)
示例12: test_double_partial
def test_double_partial(self):
    """`fn_args` sees through two nested keyword-binding partials."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(a, test_arg1, test_arg2):
        if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
            # Raise rather than return the exception, so a mis-bound
            # argument fails loudly if `fn` is ever invoked.
            raise ValueError('partial does not work correctly')
        return a

    wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
    double_wrapped_fn = functools.partial(
        wrapped_fn, test_arg1=expected_test_arg1)
    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
示例13: test_double_partial_with_positional_args_in_both_layers
def test_double_partial_with_positional_args_in_both_layers(self):
    """`fn_args` sees through nested partials that bind positionally."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, test_arg2, a):
        if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
            # Raise rather than return the exception: `fn` is called below,
            # and returning the error object would just fail the equality
            # assertion with a confusing message instead of the real cause.
            raise ValueError('partial fn does not work correctly')
        return a

    # The first positional binding goes to test_arg1.
    wrapped_fn = functools.partial(fn, expected_test_arg1)
    # The second positional binding goes to test_arg2.
    double_wrapped_fn = functools.partial(wrapped_fn, expected_test_arg2)

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
    # The remaining argument can be supplied positionally or by keyword.
    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
示例14: _call_model_fn
def _call_model_fn(self, features, labels, mode, params):
    """Calls the model_fn with required parameters.

    `features` is always passed; `labels`, `mode` and `params` are passed
    only if `self._model_fn` declares arguments with those names.

    Args:
      features: Passed to the model function as `features`.
      labels: Passed as `labels` if accepted; must be None otherwise.
      mode: Passed as `mode` if accepted.
      params: Passed as `params` if accepted.

    Returns:
      The result of `self._verify_estimator_spec` applied to the model
      function's output.

    Raises:
      ValueError: If labels are given but the model function has no
        `labels` argument.
    """
    accepted_names = function_utils.fn_args(self._model_fn)

    if 'labels' not in accepted_names and labels is not None:
        raise ValueError(
            'model_fn does not take labels, but input_fn returns labels.')

    candidates = (('labels', labels), ('mode', mode), ('params', params))
    kwargs = {
        arg_name: value
        for arg_name, value in candidates
        if arg_name in accepted_names
    }
    estimator_spec = self._model_fn(features=features, **kwargs)
    return self._verify_estimator_spec(estimator_spec)
示例15: call_logit_fn
def call_logit_fn(logit_fn, features, mode, params, config):
    """Calls logit_fn.

    A utility function that calls the provided logit_fn with the relevant
    subset of provided arguments. Similar to
    tf.estimator._call_model_fn().

    Args:
      logit_fn: A logit_fn as defined above.
      features: The features dict.
      mode: TRAIN / EVAL / PREDICT ModeKeys.
      params: The hyperparameter dict.
      config: The configuration object.

    Returns:
      A logit Tensor, the output of logit_fn.

    Raises:
      ValueError: if logit_fn does not return a Tensor or a dictionary
        mapping strings to Tensors.
    """
    accepted_names = function_utils.fn_args(logit_fn)
    # `features` is always passed; the rest only when logit_fn declares them.
    candidates = (('mode', mode), ('params', params), ('config', config))
    kwargs = {
        arg_name: value
        for arg_name, value in candidates
        if arg_name in accepted_names
    }
    logit_fn_results = logit_fn(features=features, **kwargs)

    if isinstance(logit_fn_results, ops.Tensor):
        return logit_fn_results

    if isinstance(logit_fn_results, dict) and all(
            isinstance(key, six.string_types) and isinstance(value, ops.Tensor)
            for key, value in six.iteritems(logit_fn_results)):
        return logit_fn_results

    raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
                     'strings to Tensors. logit_fn returned: %s' %
                     logit_fn_results)