当前位置: 首页>>代码示例>>Python>>正文


Python tf_decorator.unwrap函数代码示例

本文整理汇总了Python中tensorflow.python.util.tf_decorator.unwrap函数的典型用法代码示例。如果您正苦于以下问题：Python unwrap函数的具体用法？Python unwrap怎么用？Python unwrap使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了unwrap函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __call__

  def __call__(self, func):
    """Exports `func` under this decorator's API names.

    Symbols listed in `self._overrides` first have the API-names attribute
    removed from their undecorated targets; the attribute is then set on the
    undecorated target of `func`.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The same `func` object that was passed in.

    Raises:
      SymbolAlreadyExposedError: If the undecorated target of `func` already
        defines the API-names attribute in its own `__dict__`.
    """
    names_attr = API_ATTRS[self._api_name].names

    # Strip the export from every symbol that this export overrides.
    for overridden in self._overrides:
      overridden_target = tf_decorator.unwrap(overridden)[1]
      delattr(overridden_target, names_attr)

    target = tf_decorator.unwrap(func)[1]

    # Membership in __dict__ (rather than hasattr) ensures that a subclass is
    # not treated as exported merely because it inherits the attribute.
    already_exported = names_attr in target.__dict__
    if already_exported:
      symbol_name = target.__name__
      existing_names = getattr(target, names_attr)
      raise SymbolAlreadyExposedError(
          'Symbol %s is already exposed as %s.' % (symbol_name, existing_names))

    setattr(target, names_attr, self._names)
    return func
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:32,代码来源:tf_export.py

示例2: __call__

  def __call__(self, func):
    """Records this decorator's API names on the undecorated `func`.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The same `func` object; `_tf_api_names` is set on its undecorated
      target.

    Raises:
      SymbolAlreadyExposedError: If the undecorated target already has its
        own `_tf_api_names` entry in `__dict__`.
    """
    # pylint: disable=protected-access
    # Remove the export from the symbols that this export replaces.
    for overridden in self._overrides:
      overridden_target = tf_decorator.unwrap(overridden)[1]
      del overridden_target._tf_api_names

    target = tf_decorator.unwrap(func)[1]

    # Inspect __dict__ directly so that an inherited `_tf_api_names` on a
    # subclass does not count as an existing export.
    if '_tf_api_names' not in target.__dict__:
      # Complete the export by creating the attribute.
      target._tf_api_names = self._names
      return func

    raise SymbolAlreadyExposedError(
        'Symbol %s is already exposed as %s.' %
        (target.__name__, target._tf_api_names))
    # pylint: enable=protected-access
开发者ID:keveman,项目名称:tensorflow,代码行数:34,代码来源:tf_export.py

示例3: __call__

  def __call__(self, func):
    """Records both the current and the V1 API names on the undecorated `func`.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The same `func` object; the name attributes are written to its
      undecorated target through `self.set_attr`.

    Raises:
      SymbolAlreadyExposedError: Documented as raised when a symbol already
        has API names and kwarg `allow_multiple_exports` is not set. The check
        presumably lives in `self.set_attr`, which is not visible here.
    """
    names_attr = API_ATTRS[self._api_name].names
    names_attr_v1 = API_ATTRS_V1[self._api_name].names

    # Drop both the current and the V1 exports from the overridden symbols.
    for overridden in self._overrides:
      overridden_target = tf_decorator.unwrap(overridden)[1]
      for attr_name in (names_attr, names_attr_v1):
        delattr(overridden_target, attr_name)

    target = tf_decorator.unwrap(func)[1]
    self.set_attr(target, names_attr, self._names)
    self.set_attr(target, names_attr_v1, self._names_v1)
    return func
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:25,代码来源:tf_export.py

示例4: _op_is_in_tf_version

def _op_is_in_tf_version(op, version):
  """Returns a truthy value if `op` is exported in TF API `version`.

  Args:
    op: The (possibly tf_decorator-wrapped) op to check.
    version: The TF API major version; 1 or 2.

  Returns:
    For version 1: the V1 export names of the undecorated op if they are
    truthy, otherwise whether `op` is in `_V1_OPS_THAT_DELEGATE_TO_V2_OPS`.
    For version 2: the V2 export names of the undecorated op.

  Raises:
    ValueError: If `version` is neither 1 nor 2.
  """
  if version == 1:
    undecorated_op = tf_decorator.unwrap(op)[1]
    v1_names = tf_export.get_v1_names(undecorated_op)
    # Ops without V1 names may still be reachable in V1 by delegation.
    return v1_names or op in _V1_OPS_THAT_DELEGATE_TO_V2_OPS
  if version == 2:
    undecorated_op = tf_decorator.unwrap(op)[1]
    return tf_export.get_v2_names(undecorated_op)
  raise ValueError('Expected version 1 or 2.')
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:8,代码来源:ragged_dispatch.py

示例5: fn_args

def fn_args(fn):
  """Get argument names for function-like object.

  `tf_decorator` wrappers are removed first, so the names come from the
  undecorated target.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.
    - For an object whose `__call__` is a method, the names are those of
      `__call__`.
    - For partial-like objects (anything with `func`, `args` and `keywords`
      attributes), the names consumed by the bound positional arguments are
      dropped, and so are the names bound as keywords.
  """
  _, fn = tf_decorator.unwrap(fn)

  # Handle callables.
  if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
    return tuple(tf_inspect.getargspec(fn.__call__).args)

  # Handle functools.partial and similar objects.
  if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
    # Recurse to handle nested partials.
    original_args = fn_args(fn.func)
    if not original_args:
      return tuple()

    # Build the set of keyword-bound names once, not once per argument.
    bound_keywords = set((fn.keywords or {}).keys())
    # Positionally bound arguments fill the leading parameters, so skip them.
    return tuple(
        arg for arg in original_args[len(fn.args):]
        if arg not in bound_keywords)

  # Handle function.
  return tuple(tf_inspect.getargspec(fn).args)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:32,代码来源:util.py

示例6: testUnwrapBoundMethods

 def testUnwrapBoundMethods(self):
   instance = TestDecoratedClass()
   # Through the decorator, the first int argument comes back incremented.
   self.assertEqual([2, 2, 3], instance.return_params(1, 2, 3))

   decorators, target = tf_decorator.unwrap(instance.return_params)
   outermost = decorators[0]
   self.assertEqual('test_decorator_increment_first_int_arg',
                    outermost.decorator_name)
   # The unwrapped target is called with the instance passed explicitly, and
   # the arguments come back unmodified.
   self.assertEqual([1, 2, 3], target(instance, 1, 2, 3))
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:7,代码来源:tf_decorator_test.py

示例7: testUnwrapReturnsDecoratorListFromOutermostToInnermost

 def testUnwrapReturnsDecoratorListFromOutermostToInnermost(self):
   decorators = tf_decorator.unwrap(test_decorated_function)[0]
   # Expected decorator names, listed outermost first.
   expected_names = ('decorator 1',
                     'test_decorator_increment_first_int_arg',
                     'decorator 3')
   for position, expected_name in enumerate(expected_names):
     self.assertEqual(expected_name, decorators[position].decorator_name)
   self.assertEqual('decorator 3 documentation', decorators[2].decorator_doc)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:7,代码来源:tf_decorator_test.py

示例8: visit

 def visit(unused_path, unused_parent, children):
   """Visitor that collects TF 2.0 names."""
   # Only the second element of each child entry is inspected; its V2 export
   # names are added to the enclosing `v2_names` collection.
   for entry in children:
     symbol = tf_decorator.unwrap(entry[1])[1]
     for v2_name in tf_export.get_v2_names(symbol):
       v2_names.add(v2_name)
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:generate_v2_renames_map.py

示例9: testReorderFileNeedsUpdate

  def testReorderFileNeedsUpdate(self):
    reordered_function_names = (
        tf_upgrade_v2.TFAPIChangeSpec().reordered_function_names)
    function_reorders = (
        tf_upgrade_v2.TFAPIChangeSpec().function_reorders)

    added_names_message = """Some function names in
self.reordered_function_names are not in reorders_v2.py.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
    removed_names_message = """%s in self.reorders_v2 does not match
any name in self.reordered_function_names.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
    self.assertTrue(
        reordered_function_names.issubset(function_reorders),
        added_names_message)
    # function_reorders should contain reordered_function_names
    # and their TensorFlow V1 aliases.
    for name in function_reorders:
      # get other names for this function
      attr = get_symbol_for_name(tf.compat.v1, name)
      _, attr = tf_decorator.unwrap(attr)
      v1_names = tf_export.get_v1_names(attr)
      self.assertTrue(v1_names)
      v1_names = ["tf.%s" % n for n in v1_names]
      # check if any other name is in
      self.assertTrue(
          any(n in reordered_function_names for n in v1_names),
          removed_names_message % name)
开发者ID:rmlarsen,项目名称:tensorflow,代码行数:34,代码来源:tf_upgrade_v2_test.py

示例10: getfullargspec

def getfullargspec(obj):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for `inspect.getfullargspec`/`getargspec`.

  This wrapper uses `inspect.getfullargspec` if available and falls back to
  `inspect.getargspec` in Python 2.

  Args:
    obj: A callable, possibly decorated.

  Returns:
    The `FullArgSpec` that describes the signature of the outermost decorator
    that changes the callable's signature. If no decorator supplies an argspec
    (including when the callable is not decorated), the spec of the
    undecorated target is computed directly.
  """
  if six.PY2:
    def spec_fn(target):
      # Python 2 lacks getfullargspec; adapt the legacy ArgSpec.
      argspecs = _inspect.getargspec(target)
      fullargspecs = FullArgSpec(
          args=argspecs.args,
          varargs=argspecs.varargs,
          varkw=argspecs.keywords,
          defaults=argspecs.defaults,
          kwonlyargs=[],
          kwonlydefaults=None,
          annotations={})
      return fullargspecs
  else:
    spec_fn = _inspect.getfullargspec

  decorators, target = tf_decorator.unwrap(obj)
  # Prefer the first (outermost) decorator that overrides the signature.
  for decorator in decorators:
    if decorator.decorator_argspec is not None:
      return decorator.decorator_argspec
  # Inspect the target only when needed: it can be expensive, and it can
  # raise for callables that `inspect` does not support.
  return spec_fn(target)
开发者ID:moses-sun,项目名称:tensorflow,代码行数:33,代码来源:tf_inspect.py

示例11: get_canonical_name_for_symbol

def get_canonical_name_for_symbol(symbol, api_name=TENSORFLOW_API_NAME):
  """Get canonical name for the API symbol.

  Canonical name is the first non-deprecated endpoint name.

  Args:
    symbol: API function or class.
    api_name: API name (tensorflow or estimator).

  Returns:
    Canonical name for the API symbol (for e.g. initializers.zeros) if
    canonical name could be determined. Otherwise, returns None. In
    particular, returns None if either `symbol` or its undecorated target has
    no `__dict__`, or if the target has no API-names attribute of its own.
  """
  if not hasattr(symbol, '__dict__'):
    return None
  api_names_attr = API_ATTRS[api_name].names
  _, undecorated_symbol = tf_decorator.unwrap(symbol)
  # The undecorated target may be a different object from `symbol` and need
  # not have a `__dict__` (e.g. a builtin function); treat that as unexported
  # instead of raising AttributeError.
  undecorated_dict = getattr(undecorated_symbol, '__dict__', None)
  # Only the symbol's own attributes count, not ones inherited from a base.
  if undecorated_dict is None or api_names_attr not in undecorated_dict:
    return None
  api_names = getattr(undecorated_symbol, api_names_attr)
  # TODO(annarev): may be add a separate deprecated attribute
  # for estimator names.
  deprecated_api_names = undecorated_dict.get('_tf_deprecated_api_names', [])
  return get_canonical_name(api_names, deprecated_api_names)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:25,代码来源:tf_export.py

示例12: get_api_init_text

def get_api_init_text(packages,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None):
  """Get a map from destination module to __init__.py code for that module.

  Every entry of `sys.modules` whose name contains one of `packages` is
  scanned. Each attribute, after removing tf_decorator wrappers, is passed to
  `add_imports_for_symbol`: once for `api_version`, and once per entry of
  `compat_api_versions` (placed under the compat module template).

  Args:
    packages: Base python packages containing python with target tf_export
      decorators.
    output_package: Base output python package where generated API will be
      added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version you want to generate (1 or 2).
    compat_api_versions: Additional API versions to generate under compat/
      directory.

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: (string) text that should be in __init__.py files for
        corresponding modules.
  """
  compat_versions = (
      [] if compat_api_versions is None else compat_api_versions)
  builder = _ModuleInitCodeBuilder(output_package)

  def in_packages(module_name):
    # Substring (not prefix) match against each requested package.
    return any(package in module_name for package in packages)

  # Iterate over a list copy of sys.modules; presumably this guards against
  # the dict changing if attribute access below imports new modules.
  for module in list(sys.modules.values()):
    # Only look at tensorflow modules.
    if not module:
      continue
    if not hasattr(module, '__name__') or module.__name__ is None:
      continue
    module_name = module.__name__
    if not in_packages(module_name):
      continue

    # contrib modules get no generated __init__.py for now, except lite.
    is_contrib = ('.contrib.' in module_name or
                  module_name.endswith('.contrib'))
    if is_contrib and '.lite' not in module_name:
      continue

    for member_name in dir(module):
      if '.'.join((module_name, member_name)) in _SYMBOLS_TO_SKIP_EXPLICITLY:
        continue
      symbol = tf_decorator.unwrap(getattr(module, member_name))[1]

      add_imports_for_symbol(
          builder, symbol, module_name, member_name, api_name, api_version)
      for compat_version in compat_versions:
        add_imports_for_symbol(
            builder, symbol, module_name, member_name, api_name,
            compat_version, _COMPAT_MODULE_TEMPLATE % compat_version)

  return builder.build()
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:59,代码来源:create_python_api.py

示例13: visit

 def visit(unused_path, unused_parent, children):
   """Visitor that collects TF 2.0 names."""
   for entry in children:
     symbol = tf_decorator.unwrap(entry[1])[1]
     # The export names are read from the symbol's own __dict__, so symbols
     # without a __dict__ are skipped.
     if hasattr(symbol, '__dict__'):
       for v2_name in symbol.__dict__.get(_TENSORFLOW_API_ATTR, []):
         v2_names.add(v2_name)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:9,代码来源:generate_v2_renames_map.py

示例14: testUnwrapReturnsListOfUniqueTFDecorators

 def testUnwrapReturnsListOfUniqueTFDecorators(self):
   decorators = tf_decorator.unwrap(test_decorated_function)[0]
   self.assertEqual(3, len(decorators))
   # Every entry must be a TFDecorator instance...
   for index in range(3):
     self.assertTrue(
         isinstance(decorators[index], tf_decorator.TFDecorator))
   # ...and no two entries may be the same object.
   for first, second in ((0, 1), (1, 2), (2, 0)):
     self.assertIsNot(decorators[first], decorators[second])
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:9,代码来源:tf_decorator_test.py

示例15: testRewrapMutatesAffectedFunction

  def testRewrapMutatesAffectedFunction(self):

    def new_target(x):
      return x * 3

    self.assertEqual((1 * 2 + 1) ** 2, test_rewrappable_decorated(1))
    # unwrap() returns (decorators, target). The previous target handed to
    # rewrap() is the second element, not the list of decorators.
    _, prev_target = tf_decorator.unwrap(test_rewrappable_decorated)
    tf_decorator.rewrap(test_rewrappable_decorated, prev_target, new_target)
    self.assertEqual((1 * 3 + 1) ** 2, test_rewrappable_decorated(1))
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:9,代码来源:tf_decorator_test.py


注:本文中的tensorflow.python.util.tf_decorator.unwrap函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。