本文整理汇总了Python中tensorflow.tools.common.traverse.traverse函数的典型用法代码示例。如果您正苦于以下问题:Python traverse函数的具体用法?Python traverse怎么用?Python traverse使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了traverse函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: collect_function_renames
def collect_function_renames():
    """Gathers API names that need to be renamed in TF 2.0.

    Traverses `tf` with a `PublicAPIVisitor` and, for every visited symbol,
    records each name that is exported for v1 but not for v2 together with
    the canonical name computed from the symbol's v2 names.

    Returns:
      Set of tuples of the form (current name, new name).
    """
    # Pairs of (v1-only name, canonical v2 name) gathered by the visitor.
    collected = set()

    def visit(unused_path, unused_parent, children):
        """Visitor that adds (old name, new name) pairs to `collected`."""
        for entry in children:
            _, attr = tf_decorator.unwrap(entry[1])
            v1_names = tf_export.get_v1_names(attr)
            v2_names = tf_export.get_v2_names(attr)
            v1_only_names = set(v1_names) - set(v2_names)
            collected.update(
                (old_name, get_canonical_name(v2_names, old_name))
                for old_name in v1_only_names)

    visitor = public_api.PublicAPIVisitor(visit)
    visitor.do_not_descend_map['tf'].append('contrib')
    visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)

    # A different function may be exported under the same name, e.g. when a
    # new function was created to rename arguments. Such names must not be
    # reported as renames.
    all_v2_names = get_all_v2_names()
    return {
        (old_name, new_name)
        for old_name, new_name in collected
        if old_name not in all_v2_names
    }
示例2: update_renames_v2
def update_renames_v2(output_file_path):
    """Writes a Python dictionary mapping deprecated to canonical API names.

    Args:
      output_file_path: File path to write output to. Any existing contents
        would be replaced.
    """
    # Name of the attribute under which a symbol's TensorFlow API names are
    # stored (the `_tf_api_names` attribute name).
    api_names_attr = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names
    # Dictionary entry lines of the form:
    #   'tf.deprecated_name': 'tf.canonical_name'
    rename_lines = set()

    def visit(unused_path, unused_parent, children):
        """Visitor that collects rename strings to add to `rename_lines`."""
        for entry in children:
            _, attr = tf_decorator.unwrap(entry[1])
            # Only objects that carry a __dict__ can hold the API attributes.
            if hasattr(attr, '__dict__'):
                attr_dict = attr.__dict__
                api_names = attr_dict.get(api_names_attr, [])
                deprecated_names = attr_dict.get(
                    '_tf_deprecated_api_names', [])
                canonical_name = tf_export.get_canonical_name(
                    api_names, deprecated_names)
                for deprecated_name in deprecated_names:
                    rename_lines.add(
                        ' \'tf.%s\': \'tf.%s\''
                        % (deprecated_name, canonical_name))

    visitor = public_api.PublicAPIVisitor(visit)
    visitor.do_not_descend_map['tf'].append('contrib')
    traverse.traverse(tf, visitor)

    # Sort the entries so the generated file content is deterministic.
    dictionary_body = ',\n'.join(sorted(rename_lines))
    renames_file_text = '%srenames = {\n%s\n}\n' % (
        _FILE_HEADER, dictionary_body)
    file_io.write_string_to_file(output_file_path, renames_file_text)
示例3: testAllAPIV1
def testAllAPIV1(self):
    """Checks that upgrading any v1 symbol yields a name known to TF 1.x.

    `tf.compat.v1` is traversed twice with the same visitor: the first pass
    collects every exported v1 name, the second pass runs each name through
    `self._upgrade` and fails if the result is neither a v1 name, a
    `tf.compat.v1` / `tf.estimator` name, nor an explicitly whitelisted symbol.
    """
    # Toggled between the two traversals; read by the visitor closure.
    collect = True
    v1_symbols = set()

    # Symbols which may be generated by the conversion script which do not
    # exist in TF 1.x. This should be a very short list of symbols which are
    # experimental in 1.x but stable for 2.x.
    whitelisted_v2_only_symbols = {"tf.saved_model.save"}

    # Upgrade results starting with one of these prefixes are always accepted.
    accepted_prefixes = ("tf.compat.v1", "tf.estimator")

    # Converts all symbols in the v1 namespace to the v2 namespace, raising
    # an error if the target of the conversion is not in the v1 namespace.
    def conversion_visitor(unused_path, unused_parent, children):
        for child in children:
            _, attr = tf_decorator.unwrap(child[1])
            for name in tf_export.get_v1_names(attr):
                symbol = "tf." + name
                if collect:
                    v1_symbols.add(symbol)
                    continue
                _, _, _, text = self._upgrade(symbol)
                if (text and
                        not text.startswith(accepted_prefixes) and
                        text not in v1_symbols and
                        text not in whitelisted_v2_only_symbols):
                    self.fail(
                        "Symbol %s generated from %s not in v1 API" % (
                            text, name))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]

    # First pass: collect the v1 symbols.
    traverse.traverse(tf.compat.v1, visitor)
    # Second pass: verify the upgrade result of every collected symbol.
    collect = False
    traverse.traverse(tf.compat.v1, visitor)
示例4: testAllAPI
def testAllAPI(self):
    """Checks that upgrading any v1 symbol yields a name in the v2 API.

    Every name exported for v1 is run through `self._upgrade`; the test fails
    if the result is non-empty, does not start with "tf.compat.v1" and is not
    a key of `self.v2_symbols`. Returns without checking anything when
    `tf.compat` has no `v2` attribute.
    """
    if not hasattr(tf.compat, "v2"):
        return

    # Converts all symbols in the v1 namespace to the v2 namespace, raising
    # an error if the target of the conversion is not in the v2 namespace.
    # Please regenerate the renames file or edit any manual renames if this
    # test fails.
    def conversion_visitor(unused_path, unused_parent, children):
        for child in children:
            _, attr = tf_decorator.unwrap(child[1])
            for name in tf_export.get_v1_names(attr):
                _, _, _, text = self._upgrade("tf." + name)
                if (text and
                        not text.startswith("tf.compat.v1") and
                        text not in self.v2_symbols):
                    self.fail(
                        "Symbol %s generated from %s not in v2 API" % (
                            text, name))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
示例5: testV1KeywordArgNames
def testV1KeywordArgNames(self):
    """Verifies that renamed keyword arguments exist on the v1 functions.

    For every v1 name that has an entry in the upgrade spec's
    `function_keyword_renames`, each source argument name of that entry must
    appear in the argument list reported by `tf_inspect.getargspec`.
    """
    renames_by_function = (
        tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)

    # Visitor that verifies V1 argument names.
    def arg_test_visitor(unused_path, unused_parent, children):
        for entry in children:
            _, attr = tf_decorator.unwrap(entry[1])
            for short_name in tf_export.get_v1_names(attr):
                full_name = "tf.%s" % short_name
                if full_name in renames_by_function:
                    v1_arg_names = tf_inspect.getargspec(attr)[0]
                    keyword_renames = renames_by_function[full_name]
                    self.assertEqual(type(keyword_renames), dict)
                    # Every rename source must be a real v1 argument name.
                    for old_arg in keyword_renames:
                        self.assertIn(
                            old_arg, v1_arg_names,
                            "%s not found in %s arguments: %s" %
                            (old_arg, full_name, str(v1_arg_names)))

    visitor = public_api.PublicAPIVisitor(arg_test_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
示例6: collect_function_arg_names
def collect_function_arg_names(function_names):
    """Determines argument names for reordered function signatures.

    Args:
      function_names: Functions to collect arguments for.

    Returns:
      Dictionary mapping function name to its arguments. When any v1 name of
      a symbol is in `function_names`, every 'tf.'-prefixed v1 name of that
      symbol is mapped to the same argument list.
    """
    # Map from reordered function name to its arguments.
    args_by_function = {}

    def visit(unused_path, unused_parent, children):
        """Visitor that collects arguments for reordered functions."""
        for entry in children:
            _, attr = tf_decorator.unwrap(entry[1])
            qualified_names = [
                'tf.%s' % name for name in tf_export.get_v1_names(attr)]
            if not any(name in function_names for name in qualified_names):
                continue
            arg_list = tf_inspect.getargspec(attr)[0]
            # All aliases of the symbol share the one argument list object.
            args_by_function.update(dict.fromkeys(qualified_names, arg_list))

    visitor = public_api.PublicAPIVisitor(visit)
    visitor.do_not_descend_map['tf'].append('contrib')
    visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)
    return args_by_function
示例7: testAPIBackwardsCompatibility
def testAPIBackwardsCompatibility(self):
    """Diffs protos extracted from the `tf` API against the golden files."""
    # Extract all API stuff.
    proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    api_visitor = public_api.PublicAPIVisitor(proto_visitor)
    api_visitor.do_not_descend_map['tf'].append('contrib')
    api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    traverse.traverse(tf, api_visitor)
    actual_protos = proto_visitor.GetProtos()

    # Read all golden files.
    golden_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_protos = {}
    for golden_path in file_io.get_matching_files(golden_pattern):
        key = _FileNameToKey(golden_path)
        # Parse the text-format file contents into a fresh proto message.
        message = api_objects_pb2.TFAPIObject()
        text_format.Merge(file_io.read_file_to_string(golden_path), message)
        golden_protos[key] = message

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_protos,
        actual_protos,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens)
示例8: setUpClass
@classmethod
def setUpClass(cls):
    """Builds the `cls.v1_symbols` and `cls.v2_symbols` lookup tables.

    Each table maps "tf." + exported name to the unwrapped API object. A
    table stays empty when the matching `tf.compat` submodule is absent.

    unittest invokes this hook on the class itself, so it has to be a
    classmethod; the decorator is required for the hook to run.
    """
    cls.v2_symbols = {}
    cls.v1_symbols = {}

    def collect_symbols(module, get_names, symbols):
        """Traverses `module`, storing every exported name in `symbols`."""
        def symbol_collector(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                for name in get_names(attr):
                    symbols["tf." + name] = attr

        visitor = public_api.PublicAPIVisitor(symbol_collector)
        traverse.traverse(module, visitor)

    if hasattr(tf.compat, "v2"):
        collect_symbols(tf.compat.v2, tf_export.get_v2_names, cls.v2_symbols)
    if hasattr(tf.compat, "v1"):
        collect_symbols(tf.compat.v1, tf_export.get_v1_names, cls.v1_symbols)
示例9: testNoSubclassOfMessageV2
def testNoSubclassOfMessageV2(self):
    """Runs `_VerifyNoSubclassOfMessageVisitor` over the `tf_v2` API.

    Returns without checking anything when `tf.compat` has no `v2` attribute.
    """
    if not hasattr(tf.compat, 'v2'):
        return

    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    # Names listed here for 'tf' go into the visitor's do-not-descend map.
    skipped_under_tf = visitor.do_not_descend_map['tf']
    skipped_under_tf.append('contrib')
    if FLAGS.only_test_core_api:
        skipped_under_tf.extend(_NON_CORE_PACKAGES)
    traverse.traverse(tf_v2, visitor)
示例10: test_cycle
def test_cycle(self):
    """Traverses a class that refers to itself through an attribute.

    No assertion is made; the test only needs the traversal call to return.
    """
    class Cyclist(object):
        pass

    # Make the class reachable from itself to form a reference cycle.
    setattr(Cyclist, 'cycle', Cyclist)

    traverse.traverse(Cyclist, TestVisitor())
示例11: test_module
def test_module(self):
    """Checks that traversing `test_module1` visits both module classes."""
    visitor = TestVisitor()
    traverse.traverse(test_module1, visitor)

    # The second element of every call-log entry is the visited parent.
    visited_parents = [
        parent for _path, parent, _children in visitor.call_log]
    for expected in (test_module1.ModuleClass1, test_module2.ModuleClass2):
        self.assertIn(expected, visited_parents)
示例12: test_module
def test_module(self):
    """Checks that traversing this module visits its known members."""
    visitor = TestVisitor()
    traverse.traverse(sys.modules[__name__], visitor)

    # The second element of every call-log entry is the visited parent.
    visited_parents = [
        parent for _path, parent, _children in visitor.call_log]
    for expected in (TestVisitor, TraverseTest, traverse):
        self.assertIn(expected, visited_parents)
self.assertIn(traverse, called)
示例13: test_class
def test_class(self):
    """Checks the first visitor call made when traversing a class."""
    visitor = TestVisitor()
    traverse.traverse(TestVisitor, visitor)

    first_call = visitor.call_log[0]
    self.assertEqual(TestVisitor, first_call[1])

    # There are a bunch of other members, but make sure that the ones we know
    # about are there.
    member_names = [name for name, _ in first_call[2]]
    for expected in ('__init__', '__call__'):
        self.assertIn(expected, member_names)
示例14: testKeywordArgNames
def testKeywordArgNames(self):
    """Verifies keyword-argument renames against the v1 and v2 signatures.

    For every v1 name with an entry in the upgrade spec's
    `function_keyword_renames`:
      * each rename source must be an argument of the v1 function, and
      * after upgrading a synthetic call that uses those keywords, the
        resulting function must be in `self.v2_symbols` and each resulting
        argument must be an argument of that function (or a listed exception).

    Returns without checking anything when `tf.compat` has no `v2` attribute.
    """
    if not hasattr(tf.compat, "v2"):
        return

    all_keyword_renames = (
        tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)
    # Argument names accepted in upgraded calls even though they are not part
    # of the v2 signature.
    v2_name_exceptions = {"verify_shape_is_now_always_true"}

    # Visitor that verifies V1 argument names, converts to V2 and checks
    # V2 argument names.
    def conversion_visitor(unused_path, unused_parent, children):
        for child in children:
            _, attr = tf_decorator.unwrap(child[1])
            names_v1 = get_v1_names(attr)
            for name in names_v1:
                name = "tf.%s" % name
                if name not in all_keyword_renames:
                    continue
                arg_names_v1 = tf_inspect.getargspec(attr)[0]
                keyword_renames = all_keyword_renames[name]
                self.assertEqual(type(keyword_renames), dict)

                # Assert that v1 function has valid v1 argument names.
                for from_name in keyword_renames:
                    self.assertIn(
                        from_name, arg_names_v1,
                        "%s not found in %s arguments: %s" %
                        (from_name, name, str(arg_names_v1)))

                # Assert that arg names after converting to v2 are present in
                # v2 function.
                # 1. First, create an input of the form:
                #    tf.foo(arg1=val1, arg2=val2, ...)
                args = ",".join(
                    "%s=%d" % (from_name, from_index)
                    for from_index, from_name in enumerate(keyword_renames))
                text_input = "%s(%s)" % (name, args)

                # 2. Convert the input to V2.
                _, _, _, text = self._upgrade(text_input)
                new_function_name, new_args = get_func_and_args_from_str(text)

                # 3. Verify V2 function and arguments.
                # Note: If we rename arguments, new function must be available
                # in 2.0. We should not be using compat.v1 in this case.
                self.assertIn(new_function_name, self.v2_symbols)
                # Copy the argument list before extending it: it comes from
                # tf_inspect.getargspec, and mutating it in place could alter
                # an argspec object shared with other callers.
                args_v2 = list(tf_inspect.getargspec(
                    self.v2_symbols[new_function_name])[0])
                args_v2.extend(v2_name_exceptions)
                for new_arg in new_args:
                    self.assertIn(new_arg, args_v2)

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
示例15: testNewAPIBackwardsCompatibility
def testNewAPIBackwardsCompatibility(self):
    """Diffs protos extracted from the `api` module against the golden files.

    The `user_ops` member is dropped from the golden 'tensorflow' module
    before diffing, and goldens are never updated by this test.
    """
    # Extract all API stuff.
    proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    api_visitor = public_api.PublicAPIVisitor(proto_visitor)
    api_visitor.do_not_descend_map['tf'].append('contrib')
    api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # TODO(annarev): Make slide_dataset available in API.
    api_visitor.private_map['tf'] = ['slide_dataset']
    traverse.traverse(api, api_visitor)
    actual_protos = proto_visitor.GetProtos()

    # Read all golden files.
    golden_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_protos = {}
    for golden_path in file_io.get_matching_files(golden_pattern):
        key = _FileNameToKey(golden_path)
        # Parse the text-format file contents into a fresh proto message.
        message = api_objects_pb2.TFAPIObject()
        text_format.Merge(file_io.read_file_to_string(golden_path), message)
        golden_protos[key] = message

    # user_ops is an empty module. It is currently available in TensorFlow API
    # but we don't keep empty modules in the new API.
    # We delete user_ops from the goldens to make sure assert passes
    # when diffing new API against goldens.
    # TODO(annarev): remove user_ops from goldens once we switch to new API.
    tf_module = golden_protos['tensorflow'].tf_module
    for index, member in enumerate(tf_module.member):
        if member.name == 'user_ops':
            # Stop right after deleting so the iteration never continues over
            # the modified container.
            del tf_module.member[index]
            break

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_protos,
        actual_protos,
        verbose=FLAGS.verbose_diffs,
        update_goldens=False,
        additional_missing_object_message=
        'Check if tf_export decorator/call is missing for this symbol.')