This article collects typical usage examples of the Python function tensorflow.python.framework.function._from_library. If you have been wondering how _from_library is used in practice and what calling it looks like, the curated code examples below may help.
Seven code examples of the _from_library function are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
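Before the examples, a quick orientation: function._from_library is an internal TensorFlow 1.x helper that converts a FunctionDefLibrary proto back into _DefinedFunction objects, re-linking any gradient functions recorded in the library. The following minimal sketch (not taken from the examples below) round-trips a single Defun through a graph's library; the Square function and the printed name are illustrative placeholders, and the imports are the usual TF 1.x internal modules used throughout this page.

# A minimal sketch of what _from_library does, assuming the TF 1.x internal
# modules used in the examples below. "Square" is a hypothetical example
# function, not part of the original snippets.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops

@function.Defun(dtypes.float32)
def Square(x):
  return math_ops.square(x)

g = ops.Graph()
with g.as_default():
  # Instantiating the Defun adds its FunctionDef to the graph's library.
  Square(constant_op.constant(2.0, dtypes.float32))

library = g.as_graph_def().library          # a FunctionDefLibrary proto
funcs = function._from_library(library)     # back to _DefinedFunction objects
print([f.name for f in funcs])              # e.g. ['Square_<hash>']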
Example 1: testFromLibraryMissingFuncDef
def testFromLibraryMissingFuncDef(self):

  @function.Defun(dtypes.float32, dtypes.float32)
  def G1(x, dy):
    return x * dy

  @function.Defun(dtypes.float32)
  def F1(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  gradient = function_pb2.GradientDef()
  gradient.function_name = F1.name
  gradient.gradient_func = G1.name

  # Create invalid function def that is missing G1 function def
  library = function_pb2.FunctionDefLibrary()
  library.gradient.extend([gradient])
  library.function.extend([F1.definition])
  with self.assertRaisesRegexp(
      ValueError,
      "FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"):
    function._from_library(library)

  # Create invalid function def that is missing F1 function def
  library = function_pb2.FunctionDefLibrary()
  library.gradient.extend([gradient])
  library.function.extend([G1.definition])
  with self.assertRaisesRegexp(
      ValueError,
      "FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"):
    function._from_library(library)
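For contrast, a small sketch of the valid case, reusing the F1, G1, and gradient objects from the test above: when both FunctionDefs are present in the library, _from_library succeeds and reconstructs both functions, with G1 linked as F1's gradient. This is a sketch, not part of the original test.

# Sketch only: same objects as in the test above, but with both
# FunctionDefs included, so _from_library no longer raises.
library = function_pb2.FunctionDefLibrary()
library.gradient.extend([gradient])
library.function.extend([F1.definition, G1.definition])
funcs = function._from_library(library)
assert len(funcs) == 2  # both F1 and G1 are reconstructed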
Example 2: testFromLibraryCyclicGradFuncs
def testFromLibraryCyclicGradFuncs(self):

  @function.Defun(dtypes.float32)
  def F1(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  @function.Defun(dtypes.float32)
  def F2(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  # Create invalid function def library where F1 has gradient function F2 and
  # F2 has gradient function F1
  library = function_pb2.FunctionDefLibrary()
  library.function.extend([F1.definition, F2.definition])

  gradient1 = function_pb2.GradientDef()
  gradient1.function_name = F1.name
  gradient1.gradient_func = F2.name

  gradient2 = function_pb2.GradientDef()
  gradient2.function_name = F2.name
  gradient2.gradient_func = F1.name

  library.gradient.extend([gradient1, gradient2])

  with self.assertRaisesRegexp(
      ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
    function._from_library(library)
Example 3: testFromLibrary
def testFromLibrary(self):
  # Define some functions with different gradient functions. Note that many of
  # the below functions are identical since function bodies don't matter for
  # this test.

  @function.Defun(dtypes.float32, dtypes.float32)
  def G1(x, dy):
    return x * dy

  @function.Defun(dtypes.float32, dtypes.float32)
  def G2(x, dy):
    return x * dy

  # F1 and F2 have the same gradient function
  @function.Defun(dtypes.float32, grad_func=G1)
  def F1(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  @function.Defun(dtypes.float32, grad_func=G1)
  def F2(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  # F3 has a different gradient function
  @function.Defun(dtypes.float32, grad_func=G2)
  def F3(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  # F4 has no gradient function
  @function.Defun(dtypes.float32)
  def F4(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  # Instantiate all functions
  g = ops.Graph()
  with g.as_default():
    c = constant_op.constant(1.0, dtypes.float32)
    f1 = F1(c)
    f2 = F2(c)
    f3 = F3(c)
    f4 = F4(c)
    gradients_impl.gradients([f1, f2, f3, f4], c)

  library = g.as_graph_def().library
  new_funcs = function._from_library(library)

  def CheckNewFunc(func):
    new_func = [f for f in new_funcs if f.name == func.name]
    self.assertEqual(len(new_func), 1)
    self.expectFunctionsEqual(func, new_func=new_func[0])

  CheckNewFunc(G1)
  CheckNewFunc(G2)
  CheckNewFunc(F1)
  CheckNewFunc(F2)
  CheckNewFunc(F3)
  CheckNewFunc(F4)
Example 4: testFromLibraryEmptyLib
def testFromLibraryEmptyLib(self):
  library = function_pb2.FunctionDefLibrary()
  self.assertEqual(len(function._from_library(library)), 0)
Example 5: import_graph_def
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()
  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      g._add_function(f)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access
#.........the rest of this example's code is omitted here.........
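Example 5 shows the internal implementation. For reference, here is a short usage sketch of the public API it documents (TensorFlow 1.x graph mode); the tensor names 'x:0' and 'y:0' are simply the names chosen in this sketch, not names from the original snippet.

# Usage sketch for the API documented above (TF 1.x graph mode).
import tensorflow as tf

# Build a small graph to export.
with tf.Graph().as_default() as g:
  x = tf.placeholder(tf.float32, name='x')
  tf.square(x, name='y')
  graph_def = g.as_graph_def()

# Import it elsewhere, remapping the input and fetching the output.
with tf.Graph().as_default():
  new_input = tf.constant(3.0, tf.float32)
  y, = tf.import_graph_def(graph_def,
                           input_map={'x:0': new_input},
                           return_elements=['y:0'])
  with tf.Session() as sess:
    print(sess.run(y))  # expected: 9.0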
Example 6: import_graph_def
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

      scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
      options = scoped_options.options
      _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                       return_elements)

      with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
        try:
          with errors.raise_exception_on_not_ok_status() as status:
            results = c_api.TF_GraphImportGraphDefWithResults(
                graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
        except errors.InvalidArgumentError as e:
          # Convert to ValueError for backwards compatibility.
          raise ValueError(str(e))

      _ProcessNewOps(graph)

      # Create _DefinedFunctions for any imported functions.
      #
      # We do this by creating _DefinedFunctions directly from `graph_def`, and
      # adding them to `graph`. Adding an existing function to a TF_Graph is a
      # no-op, so this only has the effect of updating the Python state (usually
      # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
      #
      # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
      # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
      if graph_def.library and graph_def.library.function:
        # pylint: disable=protected-access
        functions = function._from_library(graph_def.library)
        for f in functions:
          f.add_to_graph(graph)
        # pylint: enable=protected-access

      # Treat input mappings that don't appear in the graph as an error, because
#.........the rest of this example's code is omitted here.........
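As the docstring notes, the name argument only prefixes imported op names; it does not apply to imported function names. A brief sketch of that prefix behavior (TF 1.x; the op name 'c' is just a placeholder chosen for this sketch):

# Sketch of the name-prefix behavior documented above (TF 1.x).
import tensorflow as tf

with tf.Graph().as_default() as g:
  tf.constant(1.0, name='c')
  graph_def = g.as_graph_def()

with tf.Graph().as_default() as g2:
  tf.import_graph_def(graph_def)               # default prefix "import"
  tf.import_graph_def(graph_def, name='copy')  # explicit prefix
  print([op.name for op in g2.get_operations()])
  # expected: ['import/c', 'copy/c']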
Example 7: import_graph_def
#.........part of this example's code is omitted here.........
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
    # Session.run call cannot occur between creating the TF_Operations in the
    # TF_GraphImportGraphDefWithResults call and mutating them in
    # _ProcessNewOps.
    with graph._mutation_lock():  # pylint: disable=protected-access
      with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
        try:
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options)  # pylint: disable=protected-access
          results = c_api_util.ScopedTFImportGraphDefResults(results)
        except errors.InvalidArgumentError as e:
          # Convert to ValueError for backwards compatibility.
          raise ValueError(str(e))

      # Create _DefinedFunctions for any imported functions.
      #
      # We do this by creating _DefinedFunctions directly from `graph_def`, and
      # adding them to `graph`. Adding an existing function to a TF_Graph is a
      # no-op, so this only has the effect of updating the Python state (usually
      # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
      #
      # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
      # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
      # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
      # _USE_C_SHAPES is removed.
      if graph_def.library and graph_def.library.function:
        # pylint: disable=protected-access
        functions = function._from_library(graph_def.library)
        for f in functions:
          f.add_to_graph(graph)
        # pylint: enable=protected-access

      _ProcessNewOps(graph)

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results.results))
    if missing_unused_input_keys:
      missing_unused_input_keys = [
          compat.as_str(s) for s in missing_unused_input_keys
      ]
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]' %
          ', '.join(missing_unused_input_keys))

    if return_elements is None:
      return None
    else:
      return _GatherReturnElements(return_elements, graph, results.results)
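The input-mapping check at the end of Example 7 is what produces the ValueError a caller sees when a mapped name does not exist in graph_def. A small sketch of that error path (TF 1.x; the names used here are placeholders):

# Sketch of the error path implemented above: mapping an input name that is
# not present in graph_def raises ValueError.
import tensorflow as tf

with tf.Graph().as_default() as g:
  tf.constant(1.0, name='c')
  graph_def = g.as_graph_def()

with tf.Graph().as_default():
  try:
    tf.import_graph_def(graph_def,
                        input_map={'does_not_exist:0': tf.constant(0.0)})
  except ValueError as e:
    print(e)  # "Attempted to map inputs that were not found in graph_def: ..."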