当前位置: 首页>>代码示例>>Python>>正文


Python traverse.traverse函数代码示例

本文整理汇总了Python中tensorflow.tools.common.traverse.traverse函数的典型用法代码示例。如果您正苦于以下问题:Python traverse函数的具体用法?Python traverse怎么用?Python traverse使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了traverse函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: collect_function_renames

def collect_function_renames():
  """Collects TF 1.x API names that are not exported under that name in 2.0.

  Traverses `tf` with a PublicAPIVisitor (after adding 'contrib' to the
  do-not-descend entry for 'tf' and setting ['v1', 'v2'] for 'tf.compat').
  For every visited symbol, each name exported for v1 but not for v2 is
  paired with the result of `get_canonical_name`.

  Returns:
    Set of tuples of the form (current name, new name).
  """
  # Accumulates (deprecated v1 name, canonical name) pairs during traversal.
  collected = set()

  def visit(unused_path, unused_parent, children):
    """Visitor that records a rename pair for each v1-only name of a child."""
    for entry in children:
      _, symbol = tf_decorator.unwrap(entry[1])
      v1_names = tf_export.get_v1_names(symbol)
      v2_names = tf_export.get_v2_names(symbol)
      for old_name in set(v1_names) - set(v2_names):
        collected.add((old_name, get_canonical_name(v2_names, old_name)))

  api_visitor = public_api.PublicAPIVisitor(visit)
  api_visitor.do_not_descend_map['tf'].append('contrib')
  api_visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, api_visitor)

  # A different function may be exported in v2 under the same name (e.g. a
  # new function created in order to rename arguments). Names that are still
  # v2 names are therefore left out of the result.
  all_v2_names = get_all_v2_names()
  return {pair for pair in collected if pair[0] not in all_v2_names}
开发者ID:aritratony,项目名称:tensorflow,代码行数:32,代码来源:generate_v2_renames_map.py

示例2: update_renames_v2

def update_renames_v2(output_file_path):
  """Generates the `renames` dictionary source text and writes it to a file.

  The written text is `_FILE_HEADER` followed by a `renames = {...}` literal
  holding one sorted "'tf.<deprecated>': 'tf.<canonical>'" entry for every
  deprecated API name found while traversing `tf`.

  Args:
    output_file_path: File path to write output to. Any existing contents
      would be replaced.
  """
  # Name of the attribute that holds a symbol's TensorFlow API names
  # (_tf_api_names).
  api_names_attr = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names
  # Distinct dictionary-entry lines of the form:
  #   'tf.deprecated_name': 'tf.canonical_name'
  entry_lines = set()

  def visit(unused_path, unused_parent, children):
    """Visitor that adds one entry line per deprecated name of each child."""
    for entry in children:
      _, symbol = tf_decorator.unwrap(entry[1])
      # Only objects that carry a __dict__ can hold the export attributes.
      if hasattr(symbol, '__dict__'):
        members = symbol.__dict__
        exported = members.get(api_names_attr, [])
        deprecated = members.get('_tf_deprecated_api_names', [])
        canonical = tf_export.get_canonical_name(exported, deprecated)
        entry_lines.update(
            '    \'tf.%s\': \'tf.%s\'' % (old_name, canonical)
            for old_name in deprecated)

  api_visitor = public_api.PublicAPIVisitor(visit)
  api_visitor.do_not_descend_map['tf'].append('contrib')
  traverse.traverse(tf, api_visitor)

  # Sorting keeps the generated file stable between runs.
  dictionary_body = ',\n'.join(sorted(entry_lines))
  file_io.write_string_to_file(
      output_file_path,
      '%srenames = {\n%s\n}\n' % (_FILE_HEADER, dictionary_body))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:33,代码来源:generate_v2_renames_map.py

示例3: testAllAPIV1

  def testAllAPIV1(self):
    """Checks that upgrading any v1 symbol yields a symbol known to TF 1.x.

    `tf.compat.v1` is traversed twice with the same visitor. The first pass
    collects every exported v1 name. The second pass runs each name through
    `self._upgrade` and fails the test if the resulting text is neither a
    collected v1 symbol nor covered by one of the accepted exceptions.
    """
    # Phase flag read by the visitor: True while collecting names, False while
    # verifying conversions.
    collect = True
    v1_symbols = set()

    # Symbols which may be generated by the conversion script which do not exist
    # in TF 1.x. This should be a very short list of symbols which are
    # experimental in 1.x but stable for 2.x.
    whitelisted_v2_only_symbols = {"tf.saved_model.save"}

    # Converted text starting with one of these prefixes is accepted without
    # being looked up in v1_symbols.
    accepted_prefixes = ("tf.compat.v1", "tf.estimator")

    # Converts all symbols in the v1 namespace to the v2 namespace, raising
    # an error if the target of the conversion is not in the v1 namespace.
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        api_names = tf_export.get_v1_names(attr)
        for name in api_names:
          if collect:
            v1_symbols.add("tf." + name)
            continue
          _, _, _, text = self._upgrade("tf." + name)
          if (text and
              not text.startswith(accepted_prefixes) and
              text not in v1_symbols and
              text not in whitelisted_v2_only_symbols):
            # fail() reports the message directly instead of asserting on a
            # constant.
            self.fail(
                "Symbol %s generated from %s not in v1 API" % (text, name))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    # First pass: collect all v1 symbols.
    traverse.traverse(tf.compat.v1, visitor)
    # Second pass: verify the conversion result of every v1 symbol.
    collect = False
    traverse.traverse(tf.compat.v1, visitor)
开发者ID:rmlarsen,项目名称:tensorflow,代码行数:35,代码来源:tf_upgrade_v2_test.py

示例4: testAllAPI

  def testAllAPI(self):
    """Checks that upgrading any v1 symbol yields a symbol in the v2 API.

    Does nothing when `tf.compat` has no `v2` attribute. Otherwise every name
    exported for v1 is run through `self._upgrade`; the test fails if the
    resulting text is non-empty, does not start with "tf.compat.v1" and is not
    a key of `self.v2_symbols`.
    """
    if not hasattr(tf.compat, "v2"):
      return

    # Converts all symbols in the v1 namespace to the v2 namespace, raising
    # an error if the target of the conversion is not in the v2 namespace.
    # Please regenerate the renames file or edit any manual renames if this
    # test fails.
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        api_names = tf_export.get_v1_names(attr)
        for name in api_names:
          _, _, _, text = self._upgrade("tf." + name)
          if (text and
              not text.startswith("tf.compat.v1") and
              text not in self.v2_symbols):
            # fail() reports the message directly instead of asserting on a
            # constant.
            self.fail(
                "Symbol %s generated from %s not in v2 API" % (text, name))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
开发者ID:rmlarsen,项目名称:tensorflow,代码行数:25,代码来源:tf_upgrade_v2_test.py

示例5: testV1KeywordArgNames

  def testV1KeywordArgNames(self):
    """Verifies that keyword renames refer to real v1 argument names."""
    renames_by_function = (
        tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)

    def arg_test_visitor(unused_path, unused_parent, children):
      """Visitor that verifies V1 argument names."""
      for entry in children:
        _, symbol = tf_decorator.unwrap(entry[1])
        for short_name in tf_export.get_v1_names(symbol):
          full_name = "tf.%s" % short_name
          # Only functions that have keyword renames need to be checked.
          if full_name in renames_by_function:
            v1_args = tf_inspect.getargspec(symbol)[0]
            keyword_renames = renames_by_function[full_name]
            self.assertEqual(type(keyword_renames), dict)

            # Each keyword being renamed must be an argument of the v1
            # function.
            for old_keyword in keyword_renames:
              self.assertIn(
                  old_keyword, v1_args,
                  "%s not found in %s arguments: %s" %
                  (old_keyword, full_name, str(v1_args)))

    api_visitor = public_api.PublicAPIVisitor(arg_test_visitor)
    api_visitor.do_not_descend_map["tf"].append("contrib")
    api_visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, api_visitor)
开发者ID:rmlarsen,项目名称:tensorflow,代码行数:29,代码来源:tf_upgrade_v2_test.py

示例6: collect_function_arg_names

def collect_function_arg_names(function_names):
  """Collects the argument names of the requested functions.

  Traverses `tf`; whenever any 'tf.'-prefixed v1 name of a symbol is found in
  `function_names`, all of that symbol's 'tf.'-prefixed v1 names are mapped
  to its argument list.

  Args:
    function_names: Functions to collect arguments for.

  Returns:
    Dictionary mapping function name to its arguments.
  """
  # Maps a 'tf.'-prefixed function name to the function's argument names.
  args_by_function = {}

  def visit(unused_path, unused_parent, children):
    """Visitor that stores the arguments of every matching child."""
    for entry in children:
      _, symbol = tf_decorator.unwrap(entry[1])
      qualified_names = [
          'tf.%s' % short_name
          for short_name in tf_export.get_v1_names(symbol)]
      if not any(name in function_names for name in qualified_names):
        continue
      arg_list = tf_inspect.getargspec(symbol)[0]
      # Every v1 name of the symbol shares the same argument list object.
      args_by_function.update((name, arg_list) for name in qualified_names)

  api_visitor = public_api.PublicAPIVisitor(visit)
  api_visitor.do_not_descend_map['tf'].append('contrib')
  api_visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, api_visitor)

  return args_by_function
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:31,代码来源:generate_v2_reorders_map.py

示例7: testAPIBackwardsCompatibility

  def testAPIBackwardsCompatibility(self):
    """Compares protos built from the traversed `tf` API with golden files."""
    # Build protos for the current API by traversing `tf`.
    proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    api_visitor = public_api.PublicAPIVisitor(proto_visitor)
    api_visitor.do_not_descend_map['tf'].append('contrib')
    api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    traverse.traverse(tf, api_visitor)
    current_protos = proto_visitor.GetProtos()

    def _ReadFileToProto(filename):
      """Merges the text contents of `filename` into a new TFAPIObject."""
      proto = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), proto)
      return proto

    # Load every golden file matching the wildcard key path.
    golden_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_protos = {}
    for path in file_io.get_matching_files(golden_pattern):
      golden_protos[_FileNameToKey(path)] = _ReadFileToProto(path)

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_protos,
        current_protos,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens)
开发者ID:DILASSS,项目名称:tensorflow,代码行数:35,代码来源:api_compatibility_test.py

示例8: setUpClass

  @classmethod
  def setUpClass(cls):
    """Builds the class-level maps from "tf."-prefixed API names to symbols.

    `cls.v2_symbols` is filled from `tf.compat.v2` and `cls.v1_symbols` from
    `tf.compat.v1`; either map stays empty when the corresponding attribute is
    missing from `tf.compat`.

    unittest invokes this hook on the class itself, so it has to be a
    classmethod.
    """
    cls.v2_symbols = {}
    cls.v1_symbols = {}

    def collect_symbols(module, get_names, symbols):
      """Traverses `module`, storing each exported name's symbol in `symbols`."""

      def symbol_collector(unused_path, unused_parent, children):
        for child in children:
          _, attr = tf_decorator.unwrap(child[1])
          for name in get_names(attr):
            symbols["tf." + name] = attr

      visitor = public_api.PublicAPIVisitor(symbol_collector)
      traverse.traverse(module, visitor)

    if hasattr(tf.compat, "v2"):
      collect_symbols(tf.compat.v2, tf_export.get_v2_names, cls.v2_symbols)

    if hasattr(tf.compat, "v1"):
      collect_symbols(tf.compat.v1, tf_export.get_v1_names, cls.v1_symbols)
开发者ID:terrytangyuan,项目名称:tensorflow,代码行数:26,代码来源:tf_upgrade_v2_test.py

示例9: testNoSubclassOfMessageV2

 def testNoSubclassOfMessageV2(self):
   """Runs _VerifyNoSubclassOfMessageVisitor over `tf_v2` when v2 exists."""
   if hasattr(tf.compat, 'v2'):
     api_visitor = public_api.PublicAPIVisitor(
         _VerifyNoSubclassOfMessageVisitor)
     # Names listed here are registered as not to be descended into.
     skipped_under_tf = api_visitor.do_not_descend_map['tf']
     skipped_under_tf.append('contrib')
     if FLAGS.only_test_core_api:
       skipped_under_tf.extend(_NON_CORE_PACKAGES)
     traverse.traverse(tf_v2, api_visitor)
开发者ID:aeverall,项目名称:tensorflow,代码行数:8,代码来源:api_compatibility_test.py

示例10: test_cycle

  def test_cycle(self):
    """Runs traverse on a class whose attribute refers back to the class."""

    class Cyclist(object):
      pass

    # Make the class reachable from itself.
    setattr(Cyclist, 'cycle', Cyclist)

    cycle_visitor = TestVisitor()
    traverse.traverse(Cyclist, cycle_visitor)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:8,代码来源:traverse_test.py

示例11: test_module

  def test_module(self):
    """Checks that classes of both test modules appear as visited parents."""
    recorder = TestVisitor()
    traverse.traverse(test_module1, recorder)

    # The second element of each logged call is the parent that was visited.
    visited_parents = []
    for _, parent, _ in recorder.call_log:
      visited_parents.append(parent)

    self.assertIn(test_module1.ModuleClass1, visited_parents)
    self.assertIn(test_module2.ModuleClass2, visited_parents)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:8,代码来源:traverse_test.py

示例12: test_module

  def test_module(self):
    """Checks which objects are visited when traversing this test module."""
    recorder = TestVisitor()
    this_module = sys.modules[__name__]
    traverse.traverse(this_module, recorder)

    # The second element of each logged call is the parent that was visited.
    visited_parents = []
    for _, parent, _ in recorder.call_log:
      visited_parents.append(parent)

    self.assertIn(TestVisitor, visited_parents)
    self.assertIn(TraverseTest, visited_parents)
    self.assertIn(traverse, visited_parents)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:9,代码来源:traverse_test.py

示例13: test_class

 def test_class(self):
   """Checks the first logged call when traversing the TestVisitor class."""
   recorder = TestVisitor()
   traverse.traverse(TestVisitor, recorder)

   first_call = recorder.call_log[0]
   self.assertEqual(TestVisitor, first_call[1])

   # There are a bunch of other members, but make sure that the ones we know
   # about are there.
   member_names = [name for name, _ in first_call[2]]
   self.assertIn('__init__', member_names)
   self.assertIn('__call__', member_names)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:9,代码来源:traverse_test.py

示例14: testKeywordArgNames

  def testKeywordArgNames(self):
    """Verifies keyword renames against both the v1 and the v2 signatures."""
    if not hasattr(tf.compat, "v2"):
      return

    renames_by_function = (
        tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)
    # Names appended to the v2 argument list before the converted arguments
    # are checked against it.
    v2_name_exceptions = {"verify_shape_is_now_always_true"}

    def conversion_visitor(unused_path, unused_parent, children):
      """Visitor that checks V1 argument names, converts to V2, checks V2."""
      for entry in children:
        _, symbol = tf_decorator.unwrap(entry[1])
        for short_name in get_v1_names(symbol):
          full_name = "tf.%s" % short_name
          if full_name not in renames_by_function:
            continue
          v1_args = tf_inspect.getargspec(symbol)[0]
          keyword_renames = renames_by_function[full_name]
          self.assertEqual(type(keyword_renames), dict)

          # Each keyword being renamed must be an argument of the v1 function.
          for old_keyword in keyword_renames:
            self.assertIn(
                old_keyword, v1_args,
                "%s not found in %s arguments: %s" %
                (old_keyword, full_name, str(v1_args)))

          # Build a call of the form tf.foo(arg1=0, arg2=1, ...) that uses
          # every renamed keyword, with its position as a placeholder value.
          call_args = ",".join(
              "%s=%d" % (old_keyword, position)
              for position, old_keyword in enumerate(keyword_renames))
          call_text = "%s(%s)" % (full_name, call_args)

          # Convert the call to V2 and split it into function and arguments.
          _, _, _, upgraded = self._upgrade(call_text)
          new_function_name, new_args = get_func_and_args_from_str(upgraded)

          # Note: If we rename arguments, new function must be available in
          # 2.0. We should not be using compat.v1 in this case.
          self.assertIn(new_function_name, self.v2_symbols)
          v2_function = self.v2_symbols[new_function_name]
          v2_args = tf_inspect.getargspec(v2_function)[0]
          v2_args.extend(v2_name_exceptions)
          for new_arg in new_args:
            self.assertIn(new_arg, v2_args)

    api_visitor = public_api.PublicAPIVisitor(conversion_visitor)
    api_visitor.do_not_descend_map["tf"].append("contrib")
    api_visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, api_visitor)
开发者ID:aeverall,项目名称:tensorflow,代码行数:54,代码来源:tf_upgrade_v2_test.py

示例15: testNewAPIBackwardsCompatibility

  def testNewAPIBackwardsCompatibility(self):
    """Compares protos built from the traversed `api` with golden files."""
    # Build protos for the new API by traversing `api`.
    proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    api_visitor = public_api.PublicAPIVisitor(proto_visitor)
    api_visitor.do_not_descend_map['tf'].append('contrib')
    api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # TODO(annarev): Make slide_dataset available in API.
    api_visitor.private_map['tf'] = ['slide_dataset']
    traverse.traverse(api, api_visitor)
    current_protos = proto_visitor.GetProtos()

    def _ReadFileToProto(filename):
      """Merges the text contents of `filename` into a new TFAPIObject."""
      proto = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), proto)
      return proto

    # Load every golden file matching the wildcard key path.
    golden_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_protos = {}
    for path in file_io.get_matching_files(golden_pattern):
      golden_protos[_FileNameToKey(path)] = _ReadFileToProto(path)

    # user_ops is an empty module. It is currently available in TensorFlow API
    # but we don't keep empty modules in the new API.
    # We delete user_ops from the goldens to make sure assert passes
    # when diffing new API against goldens.
    # TODO(annarev): remove user_ops from goldens once we switch to new API.
    tf_module = golden_protos['tensorflow'].tf_module
    for index, member in enumerate(tf_module.member):
      if member.name == 'user_ops':
        # Only the first match is removed; the loop stops right after it.
        del tf_module.member[index]
        break

    # Diff them. Goldens are never updated by this test.
    self._AssertProtoDictEquals(
        golden_protos,
        current_protos,
        verbose=FLAGS.verbose_diffs,
        update_goldens=False,
        additional_missing_object_message=
        'Check if tf_export decorator/call is missing for this symbol.')
开发者ID:PuchatekwSzortach,项目名称:tensorflow,代码行数:50,代码来源:api_compatibility_test.py


注:本文中的tensorflow.tools.common.traverse.traverse函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。