本文整理汇总了Python中tensorflow.python.util.tf_decorator.unwrap函数的典型用法代码示例。如果您正苦于以下问题：Python unwrap函数的具体用法？Python unwrap怎么用？Python unwrap使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了unwrap函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __call__
def __call__(self, func):
    """Attach this decorator's API names to `func` and hand `func` back.

    Args:
      func: The symbol (function or class) being exported; it may already be
        wrapped by other decorators.

    Returns:
      `func` itself, after the API-names attribute has been set on the object
      obtained by unwrapping it.

    Raises:
      SymbolAlreadyExposedError: If the unwrapped symbol already carries the
        API-names attribute in its own `__dict__`.
    """
    names_attr = API_ATTRS[self._api_name].names

    # Strip the names attribute from every symbol this export overrides.
    for overridden in self._overrides:
        bare_overridden = tf_decorator.unwrap(overridden)[1]
        delattr(bare_overridden, names_attr)

    bare_func = tf_decorator.unwrap(func)[1]

    # Membership in __dict__ (rather than hasattr) is deliberate: a subclass
    # must own its names attribute, not merely inherit one from a parent.
    already_exported = names_attr in bare_func.__dict__
    if already_exported:
        raise SymbolAlreadyExposedError(
            'Symbol %s is already exposed as %s.' % (
                bare_func.__name__,
                getattr(bare_func, names_attr)))  # pylint: disable=protected-access

    setattr(bare_func, names_attr, self._names)
    return func
示例2: __call__
def __call__(self, func):
    """Record this decorator's names on `func` as `_tf_api_names`.

    Args:
      func: The symbol (function or class) being exported; it may already be
        wrapped by other decorators.

    Returns:
      `func` itself, after `_tf_api_names` has been set on the object
      obtained by unwrapping it.

    Raises:
      SymbolAlreadyExposedError: If the unwrapped symbol already has
        `_tf_api_names` in its own `__dict__`.
    """
    # pylint: disable=protected-access
    # Remove the attribute from every symbol this export overrides.
    for overridden in self._overrides:
        bare_overridden = tf_decorator.unwrap(overridden)[1]
        delattr(bare_overridden, '_tf_api_names')

    bare_func = tf_decorator.unwrap(func)[1]

    # Membership in __dict__ (rather than hasattr) is deliberate: a subclass
    # must own its _tf_api_names, not merely inherit one from a parent.
    if '_tf_api_names' in bare_func.__dict__:
        raise SymbolAlreadyExposedError(
            'Symbol %s is already exposed as %s.' % (
                bare_func.__name__, bare_func._tf_api_names))

    # Finish the export by creating (or overriding) the attribute.
    setattr(bare_func, '_tf_api_names', self._names)
    # pylint: enable=protected-access
    return func
示例3: __call__
def __call__(self, func):
    """Attach both the current and the V1 API names to `func`.

    Args:
      func: The symbol (function or class) being exported; it may already be
        wrapped by other decorators.

    Returns:
      `func` itself, after both names attributes have been set (through
      `self.set_attr`) on the object obtained by unwrapping it.

    Raises:
      SymbolAlreadyExposedError: Not raised directly in this body;
        presumably raised by `self.set_attr` when the symbol already has
        API names -- confirm against `set_attr`.
    """
    names_attr = API_ATTRS[self._api_name].names
    names_attr_v1 = API_ATTRS_V1[self._api_name].names

    # Strip both attributes from every symbol this export overrides.
    for overridden in self._overrides:
        bare_overridden = tf_decorator.unwrap(overridden)[1]
        for attr_name in (names_attr, names_attr_v1):
            delattr(bare_overridden, attr_name)

    bare_func = tf_decorator.unwrap(func)[1]
    self.set_attr(bare_func, names_attr, self._names)
    self.set_attr(bare_func, names_attr_v1, self._names_v1)
    return func
示例4: _op_is_in_tf_version
def _op_is_in_tf_version(op, version):
    """Report whether `op` is available under the given TF major version.

    Args:
      op: The op to look up; it is unwrapped before its export names are read.
      version: Either 1 or 2.

    Returns:
      For version 1, the V1 export names if truthy, otherwise whether `op` is
      in `_V1_OPS_THAT_DELEGATE_TO_V2_OPS`. For version 2, the V2 export
      names. The result is meant to be used for its truthiness.

    Raises:
      ValueError: If `version` is neither 1 nor 2.
    """
    if version == 1:
        undecorated_op = tf_decorator.unwrap(op)[1]
        v1_names = tf_export.get_v1_names(undecorated_op)
        return v1_names or op in _V1_OPS_THAT_DELEGATE_TO_V2_OPS

    if version == 2:
        undecorated_op = tf_decorator.unwrap(op)[1]
        return tf_export.get_v2_names(undecorated_op)

    raise ValueError('Expected version 1 or 2.')
示例5: fn_args
def fn_args(fn):
    """Get argument names for a function-like object.

    Args:
      fn: Function, or function-like object (e.g., result of
        `functools.partial`), possibly wrapped by decorators.

    Returns:
      `tuple` of string argument names. For partial-like objects (anything
      exposing `func`, `args` and `keywords`), the leading arguments already
      bound positionally and the arguments bound by keyword are left out.
    """
    _, fn = tf_decorator.unwrap(fn)

    # Callable objects: inspect `__call__` when it is reported as a method.
    if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
        return tuple(tf_inspect.getargspec(fn.__call__).args)

    # functools.partial and similar objects.
    if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
        # Recurse so that nested partials are resolved from the inside out.
        original_args = fn_args(fn.func)
        if not original_args:
            return tuple()
        # Build the bound-keyword set once rather than once per argument.
        bound_keywords = set(fn.keywords or {})
        return tuple(
            arg for arg in original_args[len(fn.args):]
            if arg not in bound_keywords)

    # Plain function.
    return tuple(tf_inspect.getargspec(fn).args)
示例6: testUnwrapBoundMethods
def testUnwrapBoundMethods(self):
    """Unwrapping a decorated bound method yields its decorator and target."""
    instance = TestDecoratedClass()

    # Through the decorator the call returns [2, 2, 3] for inputs (1, 2, 3).
    decorated_result = instance.return_params(1, 2, 3)
    self.assertEqual([2, 2, 3], decorated_result)

    unwrapped = tf_decorator.unwrap(instance.return_params)
    decorators, target = unwrapped
    outermost = decorators[0]
    self.assertEqual('test_decorator_increment_first_int_arg',
                     outermost.decorator_name)

    # The unwrapped target takes the instance explicitly and returns the
    # arguments unchanged.
    undecorated_result = target(instance, 1, 2, 3)
    self.assertEqual([1, 2, 3], undecorated_result)
示例7: testUnwrapReturnsDecoratorListFromOutermostToInnermost
def testUnwrapReturnsDecoratorListFromOutermostToInnermost(self):
    """The decorator list from `unwrap` is ordered outermost first."""
    decorators = tf_decorator.unwrap(test_decorated_function)[0]

    outermost = decorators[0]
    self.assertEqual('decorator 1', outermost.decorator_name)

    middle = decorators[1]
    self.assertEqual('test_decorator_increment_first_int_arg',
                     middle.decorator_name)

    innermost = decorators[2]
    self.assertEqual('decorator 3', innermost.decorator_name)
    self.assertEqual('decorator 3 documentation', innermost.decorator_doc)
示例8: visit
def visit(unused_path, unused_parent, children):
    """Visitor that collects TF 2.0 names into the enclosing `v2_names`.

    Args:
      unused_path: Ignored.
      unused_parent: Ignored.
      children: Iterable of items whose element at index 1 is the symbol to
        inspect.
    """
    for child in children:
        undecorated = tf_decorator.unwrap(child[1])[1]
        v2_names.update(tf_export.get_v2_names(undecorated))
示例9: testReorderFileNeedsUpdate
def testReorderFileNeedsUpdate(self):
    """Check that the reorders map and the reordered-name set agree."""
    known_reordered = tf_upgrade_v2.TFAPIChangeSpec().reordered_function_names
    reorders_map = tf_upgrade_v2.TFAPIChangeSpec().function_reorders

    missing_from_map_msg = """Some function names in
self.reordered_function_names are not in reorders_v2.py.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
    stale_in_map_msg = """%s in self.reorders_v2 does not match
any name in self.reordered_function_names.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""

    # Every reordered function name must have an entry in the reorders map.
    self.assertTrue(known_reordered.issubset(reorders_map),
                    missing_from_map_msg)

    # Conversely, the map should hold only the reordered function names and
    # their TensorFlow V1 aliases.
    for name in reorders_map:
        # Resolve the symbol and collect every V1 name it is exported under.
        symbol = get_symbol_for_name(tf.compat.v1, name)
        undecorated = tf_decorator.unwrap(symbol)[1]
        exported = tf_export.get_v1_names(undecorated)
        self.assertTrue(exported)

        qualified = ["tf.%s" % alias for alias in exported]
        # At least one alias has to be a known reordered function name.
        has_known_alias = any(
            alias in known_reordered for alias in qualified)
        self.assertTrue(has_known_alias, stale_in_map_msg % name)
示例10: getfullargspec
def getfullargspec(obj):  # pylint: disable=redefined-builtin
    """TFDecorator-aware replacement for `inspect.getfullargspec`/`getargspec`.

    This wrapper uses `inspect.getfullargspec` if available and falls back to
    `inspect.getargspec` in Python 2.

    Args:
      obj: A callable, possibly decorated.

    Returns:
      The argspec stored on the outermost decorator whose `decorator_argspec`
      is not None. If no decorator provides one (or the callable is not
      decorated), the `FullArgSpec` obtained by inspecting the unwrapped
      target.
    """
    decorators, target = tf_decorator.unwrap(obj)

    # Decorators come outermost first; the first one that overrides the
    # signature wins. Returning here means the target is only inspected when
    # it has to be, so a target that cannot be inspected no longer causes an
    # error when a decorator already supplies the argspec.
    for decorator in decorators:
        if decorator.decorator_argspec is not None:
            return decorator.decorator_argspec

    if six.PY2:
        # Python 2 only has getargspec; widen its result to a FullArgSpec
        # with empty keyword-only and annotation fields.
        argspecs = _inspect.getargspec(target)
        return FullArgSpec(
            args=argspecs.args,
            varargs=argspecs.varargs,
            varkw=argspecs.keywords,
            defaults=argspecs.defaults,
            kwonlyargs=[],
            kwonlydefaults=None,
            annotations={})

    return _inspect.getfullargspec(target)
示例11: get_canonical_name_for_symbol
def get_canonical_name_for_symbol(symbol, api_name=TENSORFLOW_API_NAME):
    """Look up the canonical API name recorded for `symbol`.

    The exported names and deprecated names stored on the unwrapped symbol
    are passed to `get_canonical_name`, which picks the result.

    Args:
      symbol: API function or class.
      api_name: API name (tensorflow or estimator).

    Returns:
      The value returned by `get_canonical_name` (for e.g.
      initializers.zeros), or None when `symbol` has no `__dict__` or the
      unwrapped symbol has no API-names attribute of its own.
    """
    if not hasattr(symbol, '__dict__'):
        return None

    names_attr = API_ATTRS[api_name].names
    bare_symbol = tf_decorator.unwrap(symbol)[1]
    own_attrs = bare_symbol.__dict__
    if names_attr not in own_attrs:
        return None

    exported_names = getattr(bare_symbol, names_attr)
    # TODO(annarev): may be add a separate deprecated attribute
    # for estimator names.
    deprecated_names = own_attrs.get('_tf_deprecated_api_names', [])
    return get_canonical_name(exported_names, deprecated_names)
示例12: get_api_init_text
def get_api_init_text(packages,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None):
    """Get a map from destination module to __init__.py code for that module.

    Walks every module currently in `sys.modules` whose name contains one of
    `packages`, and registers each of its attributes with a
    `_ModuleInitCodeBuilder` via `add_imports_for_symbol`.

    Args:
      packages: Base python packages containing python with target tf_export
        decorators.
      output_package: Base output python package where generated API will be
        added.
      api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
      api_version: API version you want to generate (1 or 2).
      compat_api_versions: Additional API versions to generate under compat/
        directory. None is treated as an empty list.

    Returns:
      A dictionary where
        key: (string) destination module (for e.g. tf or tf.consts).
        value: (string) text that should be in __init__.py files for
          corresponding modules.
    """
    if compat_api_versions is None:
        compat_api_versions = []
    builder = _ModuleInitCodeBuilder(output_package)

    def in_packages(m):
        """Return True if any requested package name occurs inside `m`."""
        for package in packages:
            if package in m:
                return True
        return False

    # Snapshot sys.modules so the walk is unaffected by imports triggered
    # while attributes are being read.
    for module in list(sys.modules.values()):
        # Only look at tensorflow modules.
        if not module or not hasattr(module, '__name__'):
            continue
        module_name = module.__name__
        if module_name is None or not in_packages(module_name):
            continue

        # Do not generate __init__.py files for contrib modules for now.
        is_contrib = ('.contrib.' in module_name or
                      module_name.endswith('.contrib'))
        if is_contrib and '.lite' not in module_name:
            continue

        for symbol_name in dir(module):
            qualified_name = module.__name__ + '.' + symbol_name
            if qualified_name in _SYMBOLS_TO_SKIP_EXPLICITLY:
                continue

            symbol = tf_decorator.unwrap(getattr(module, symbol_name))[1]

            # Register the symbol for the main API version first, then for
            # each compat version under its compat module prefix.
            add_imports_for_symbol(
                builder, symbol, module.__name__, symbol_name,
                api_name, api_version)
            for compat_version in compat_api_versions:
                add_imports_for_symbol(
                    builder, symbol, module.__name__, symbol_name,
                    api_name, compat_version,
                    _COMPAT_MODULE_TEMPLATE % compat_version)

    return builder.build()
示例13: visit
def visit(unused_path, unused_parent, children):
    """Visitor that collects TF 2.0 names into the enclosing `v2_names`.

    Args:
      unused_path: Ignored.
      unused_parent: Ignored.
      children: Iterable of items whose element at index 1 is the symbol to
        inspect.
    """
    for child in children:
        undecorated = tf_decorator.unwrap(child[1])[1]
        # Objects without a __dict__ cannot carry export names.
        if hasattr(undecorated, '__dict__'):
            v2_names.update(
                undecorated.__dict__.get(_TENSORFLOW_API_ATTR, []))
示例14: testUnwrapReturnsListOfUniqueTFDecorators
def testUnwrapReturnsListOfUniqueTFDecorators(self):
    """`unwrap` returns three distinct TFDecorator instances."""
    decorators = tf_decorator.unwrap(test_decorated_function)[0]
    self.assertEqual(3, len(decorators))

    first, second, third = decorators[0], decorators[1], decorators[2]

    # Each entry must be a TFDecorator ...
    for decorator in (first, second, third):
        self.assertTrue(isinstance(decorator, tf_decorator.TFDecorator))

    # ... and no two entries may be the same object.
    for left, right in ((first, second), (second, third), (third, first)):
        self.assertIsNot(left, right)
示例15: testRewrapMutatesAffectedFunction
def testRewrapMutatesAffectedFunction(self):
    """Rewrapping swaps the innermost target while keeping the decorators."""

    def new_target(x):
        return x * 3

    # Before rewrapping, the expected value corresponds to doubling the
    # input, adding one, then squaring.
    self.assertEqual((1 * 2 + 1) ** 2, test_rewrappable_decorated(1))

    # unwrap returns (decorators, target); the target is the second element.
    # The original unpacked the first element (the decorator list) into
    # prev_target and passed that to rewrap as the previous target.
    _, prev_target = tf_decorator.unwrap(test_rewrappable_decorated)
    tf_decorator.rewrap(test_rewrappable_decorated, prev_target, new_target)

    # After rewrapping, the same decorators apply to the tripling target.
    self.assertEqual((1 * 3 + 1) ** 2, test_rewrappable_decorated(1))