This article collects typical usage examples of the Python function tensorflow.python.framework.op_def_registry.get_registered_ops. If you have been wondering what get_registered_ops does, how to call it, or where it is used in practice, the curated code examples below should help.
The sections that follow present 15 code examples of get_registered_ops, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code samples.
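As a quick orientation before the examples, here is a minimal sketch of the function itself (assuming TensorFlow 1.x, where get_registered_ops returns a dict mapping op names to op_def_pb2.OpDef protos; the printed attr names are illustrative):

from tensorflow.python.framework import op_def_registry

# Dict of op name -> op_def_pb2.OpDef for every op known to the Python registry.
registered_ops = op_def_registry.get_registered_ops()
print(len(registered_ops))             # number of registered ops (build-dependent)
matmul_def = registered_ops["MatMul"]  # an op_def_pb2.OpDef proto
print([attr.name for attr in matmul_def.attr])  # e.g. ['transpose_a', 'transpose_b', 'T']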
Example 1: stripped_op_list_for_graph
def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos. The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.

  Raises:
    ValueError: If an unregistered op is used.
  """
  # This is the Python equivalent of StrippedOpListForGraph in C++.
  # Unfortunately, since the Python op registry can differ from that in C++, we
  # can't remove the duplication using swig (at least naively).
  # TODO(irving): Support taking graphs directly.
  used_ops = ops_used_by_graph_def(graph_def)

  # Verify that all used ops are registered.
  registered_ops = op_def_registry.get_registered_ops()
  # These internal ops used by functions are not registered, so we need to
  # whitelist them.  # TODO(irving): Do something better here.
  op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
  for op in used_ops:
    if op not in registered_ops and op not in op_whitelist:
      raise ValueError("Op %s is used by the graph, but is not registered" % op)

  # Build the stripped op list in sorted order
  return op_def_pb2.OpList(op=[
      registered_ops[op] for op in sorted(used_ops) if op in registered_ops])
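A minimal usage sketch (assuming TensorFlow 1.x, where this same helper is exposed as tensorflow.python.framework.meta_graph.stripped_op_list_for_graph):

import tensorflow as tf
from tensorflow.python.framework import meta_graph

with tf.Graph().as_default() as g:
  a = tf.constant(1.0, name="a")
  b = tf.constant(2.0, name="b")
  tf.add(a, b, name="sum")

# OpList holding the OpDefs of the ops actually used by the graph.
op_list = meta_graph.stripped_op_list_for_graph(g.as_graph_def())
print([op.name for op in op_list.op])  # e.g. ['Add', 'Const']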
Example 2: testStripDefaultAttrsInconsistentConsumerDefaults
def testStripDefaultAttrsInconsistentConsumerDefaults(self):
  if ops._USE_C_API: return  # TODO(skyewm): get this working

  export_dir = self._get_export_dir(
      "test_strip_default_attrs_no_consumer_defaults")
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  # Add a graph with two float32 variables and a Complex Op composing them
  # with strip_default_attrs enabled. This must remove the following
  # defaults for the "Complex" Op:
  #   o "T"    : float32.   (input type)
  #   o "Tout" : complex64. (output type)
  with session.Session(graph=ops.Graph()) as sess:
    real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
    imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
    math_ops.complex(real_num, imag_num, name="complex")
    sess.run(variables.global_variables_initializer())
    builder.add_meta_graph_and_variables(
        sess, ["foo"], strip_default_attrs=True)

  # Save the SavedModel to disk in text format.
  builder.save(as_text=True)

  # Update the Op registry to remove defaults for all attrs ("T", "Tout") from
  # the "Complex" OpDef.
  complex_op_def = op_def_registry.get_registered_ops()["Complex"]
  original_complex_op_def = op_def_pb2.OpDef()
  original_complex_op_def.CopyFrom(complex_op_def)
  for attr_def in complex_op_def.attr:
    attr_def.ClearField("default_value")

  # Loading the SavedModel via the loader must fail because the SavedModel
  # does not have any attr values for the "Complex" node and the current
  # op registry does not have any default values for the "Complex" op.
  sess = session.Session(graph=ops.Graph())
  with self.assertRaisesRegexp(
      ValueError,
      "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
    loader.load(sess, ["foo"], export_dir)

  # Update the Op registry to change the defaults for attr "Tout"
  # (complex64 -> complex128).
  complex_op_def.CopyFrom(original_complex_op_def)
  for attr_def in complex_op_def.attr:
    if attr_def.name == "Tout":
      attr_def.default_value.type = types_pb2.DT_COMPLEX128

  # Loading the SavedModel via the loader must set the "Tout" attr_value for
  # the "Complex" node according to the latest defaults (complex128). This is
  # expected to fail the model import as there is no OpKernel registered to
  # handle attrs "T" (float32) and "Tout" (complex128).
  sess = session.Session(graph=ops.Graph())
  with self.assertRaisesRegexp(
      errors.InvalidArgumentError,
      ".*No OpKernel was registered to support Op \'Complex\' with these "
      "attrs..*"):
    loader.load(sess, ["foo"], export_dir)
Example 3: _is_array_type_input
def _is_array_type_input(op, i):
  registered_ops = op_def_registry.get_registered_ops()
  if op not in registered_ops:
    return False
  op_def = registered_ops[op]
  if i not in xrange(len(op_def.input_arg)):
    raise TypeError("Expected arg index "
                    "to be in [0, %d)" % len(op_def.input_arg))
  input_arg = op_def.input_arg[i]
  return True if input_arg.number_attr else False
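For instance (hypothetical calls, assuming the helper above is in scope): AddN declares its inputs argument with a number_attr, so index 0 is an array-type input, whereas MatMul takes plain single-tensor inputs:

print(_is_array_type_input("AddN", 0))     # True  -- 'inputs' is declared with number_attr "N"
print(_is_array_type_input("MatMul", 0))   # False -- 'a' is a single tensor argument
print(_is_array_type_input("NotAnOp", 0))  # False -- unregistered op names return False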
Example 4: _register_function_ops
def _register_function_ops(func_list):
  """Registers custom ops in the default graph.

  This is needed because our checkpoint is saved with ops that are not part of
  TensorFlow.
  """
  op_dict = op_def_registry.get_registered_ops()
  for func in func_list:
    # pylint: disable=W0212
    func._create_definition_if_needed()
    op_def = func._definition.signature
    op_dict[op_def.name] = op_def
    RegisterShape(op_def.name)(common_shapes.unknown_shape)
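A hedged sketch of how such a helper might be driven (hypothetical function name MyCustomSquare; assumes TF 1.x, where function.Defun produces _DefinedFunction objects carrying the private attributes the helper touches):

import tensorflow as tf
from tensorflow.python.framework import function

# Hypothetical custom function used only for illustration.
@function.Defun(tf.float32, func_name="MyCustomSquare")
def my_custom_square(x):
  return x * x

# Make the function signature visible to the Python op registry, so graphs
# that reference the "MyCustomSquare" op can be loaded/imported later.
_register_function_ops([my_custom_square])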
Example 5: _add_op_node
def _add_op_node(op, func):
  """Converts an op to a function def node and adds it to `func`."""
  node = function_pb2.FunctionDef.Node()
  node.op = op.type
  # pylint: disable=protected-access
  if hasattr(op, "_sig"):
    op_def = getattr(op, "_sig")
  else:
    op_def = op_def_registry.get_registered_ops()[op.type]
  # pylint: enable=protected-access
  attrs = _get_node_def_attr(op)
  if not op_def.output_arg:
    node.ret.append(_make_argname_from_tensor_name(op.name))
  else:
    out_index = 0
    for arg_def in op_def.output_arg:
      if arg_def.number_attr:
        dtype = arg_def.type or attrs[arg_def.type_attr].type
        num = attrs[arg_def.number_attr].i
        node.ret.append(
            _add_output_array(op, out_index, out_index + num, dtype, func))
        out_index += num
      elif arg_def.type_list_attr:
        dtype_lst = attrs[arg_def.type_list_attr].list.type
        num = len(dtype_lst)
        node.ret.append(
            _add_output_list(op, out_index, out_index + num, dtype_lst, func))
        out_index += num
      else:
        node.ret.append(
            _make_argname_from_tensor_name(op.outputs[out_index].name))
        out_index += 1
  inp_index = 0
  for arg_def in op_def.input_arg:
    if arg_def.number_attr:
      dtype = arg_def.type or attrs[arg_def.type_attr].type
      num = attrs[arg_def.number_attr].i
      node.arg.append(
          _add_input_array(op, inp_index, inp_index + num, dtype, func))
      inp_index += num
    elif arg_def.type_list_attr:
      num = len(attrs[arg_def.type_list_attr].list.type)
      node.arg.extend([
          _make_argname_from_tensor_name(op.inputs[i].name)
          for i in range(inp_index, inp_index + num)
      ])
      inp_index += num
    else:
      node.arg.append(_make_argname_from_tensor_name(op.inputs[inp_index].name))
      inp_index += 1
  node.dep.extend(
      [_make_argname_from_tensor_name(x.name) for x in op.control_inputs])
  for k, v in _get_node_def_attr(op).items():
    node.attr[k].CopyFrom(v)
  func.node.extend([node])
Example 6: _stripped_op_list_for_graph
def _stripped_op_list_for_graph(graph_def):
  """Returns OpDefs of ops used in graph_def."""
  op_set = set()
  registered_ops = op_def_registry.get_registered_ops()
  for n in graph_def.node:
    if n.op in registered_ops:
      op_set.add(n.op)
  for func in graph_def.library.function:
    for n in func.node:
      if n.op in registered_ops:
        op_set.add(n.op)
  return op_def_pb2.OpList(op=[registered_ops[x] for x in sorted(op_set)])
Example 7: list_registered_stateful_ops_without_inputs
def list_registered_stateful_ops_without_inputs():
  """Returns set of registered stateful ops that do not expect inputs.

  This list is used to identify the ops to be included in the state-graph and
  that are subsequently fed into the apply-graphs.

  Returns:
    A set of strings.
  """
  return set([
      name
      for name, op in op_def_registry.get_registered_ops().items()
      if op.is_stateful and not op.input_arg
  ])
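For illustration, a hedged sketch of what the returned set contains (assuming the function above is in scope; the exact contents depend on which ops are registered in your build):

stateful_sources = list_registered_stateful_ops_without_inputs()

print("VariableV2" in stateful_sources)  # True  -- variables hold state and take no inputs
print("Add" in stateful_sources)         # False -- Add is stateless and consumes inputs
print(len(stateful_sources))             # registry-dependent count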
Example 8: _get_ref_args
def _get_ref_args(self, node):
  """Determine whether an input of an op is ref-type.

  Args:
    node: A `NodeDef`.

  Returns:
    A list of the arg names (as strs) that are ref-type.
  """
  op_def = op_def_registry.get_registered_ops().get(node.op)
  ref_args = []
  if op_def:
    for i, output_arg in enumerate(op_def.output_arg):
      if output_arg.is_ref:
        arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
        ref_args.append(arg_name)
  return ref_args
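A hypothetical, self-contained illustration (assuming `wrapper` is an instance of the class defining _get_ref_args): a VariableV2 node exposes a single ref-typed output, while a Const node exposes none:

from tensorflow.core.framework import node_def_pb2

# `wrapper`: an instance of the class that defines _get_ref_args (assumed).
var_node = node_def_pb2.NodeDef(name="my_var", op="VariableV2")
const_node = node_def_pb2.NodeDef(name="my_const", op="Const")

print(wrapper._get_ref_args(var_node))    # ['my_var']  -- VariableV2's output_arg has is_ref=True
print(wrapper._get_ref_args(const_node))  # []          -- Const has no ref outputs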
Example 9: _strip_graph_default_valued_attrs
def _strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  This method also sets `meta_info_def.stripped_default_attrs` in the given
  `MetaGraphDef` proto to True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer

  Returns:
    None.
  """
  # Map function op names to their function definitions.
  op_name_to_function = {}
  for function_def in meta_graph_def.graph_def.library.function:
    op_name_to_function[function_def.signature.name] = function_def

  # Get all registered ops.
  registered_ops = op_def_registry.get_registered_ops()

  def _strip_node_default_valued_attrs(node_def):
    """Removes default valued attributes from a single node def."""
    if node_def.op in op_name_to_function or node_def.op not in registered_ops:
      return

    op_def = registered_ops[node_def.op]
    attrs_to_strip = set()
    for attr_name, attr_value in node_def.attr.items():
      if _is_default_attr_value(op_def, attr_name, attr_value):
        attrs_to_strip.add(attr_name)

    for attr in attrs_to_strip:
      del node_def.attr[attr]

  # Process all NodeDef instances in graph_def.
  for node_def in meta_graph_def.graph_def.node:
    _strip_node_default_valued_attrs(node_def)

  # Process all NodeDef instances in graph_def.library.function.
  for function_def in meta_graph_def.graph_def.library.function:
    for function_node_def in function_def.node_def:
      _strip_node_default_valued_attrs(function_node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
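A hedged usage sketch (assuming TF 1.x, and that the helper above, together with its _is_default_attr_value dependency, is in scope; tf.train.export_meta_graph does not strip defaults here on its own):

import tensorflow as tf

with tf.Graph().as_default():
  tf.matmul(tf.ones([2, 2]), tf.ones([2, 2]), name="mm")
  meta_graph_def = tf.train.export_meta_graph()

_strip_graph_default_valued_attrs(meta_graph_def)

# MatMul's default-valued attrs (transpose_a/transpose_b = False) are removed,
# and the MetaGraphDef records that defaults were stripped.
mm_node = [n for n in meta_graph_def.graph_def.node if n.name == "mm"][0]
print("transpose_a" in mm_node.attr)                        # False
print(meta_graph_def.meta_info_def.stripped_default_attrs)  # True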
Example 10: register_ops_if_needed
def register_ops_if_needed(graph_ops):
  """Register graph ops absent in op_def_registry, if present in c++ registry.

  Args:
    graph_ops: set with graph op names to register.

  Raises:
    RuntimeError: if `graph_ops` contains ops that are not in either python or
      c++ registry.
  """
  missing_ops = graph_ops - set(op_def_registry.get_registered_ops().keys())

  if not missing_ops:
    return

  p_buffer = c_api.TF_GetAllOpList()
  cpp_op_list = op_def_pb2.OpList()
  cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
  cpp_registry_ops = {op.name: op for op in cpp_op_list.op}

  missing_op_list = op_def_pb2.OpList()
  for missing_op in missing_ops:
    if missing_op not in cpp_registry_ops:
      tf.logging.info(
          "Op %s is missing from both the python and C++ registry.",
          missing_op)
    else:
      missing_op_list.op.extend([cpp_registry_ops[missing_op]])
      tf.logging.info(
          "Adding op %s from c++ registry to python registry.",
          missing_op)

  op_def_registry.register_op_list(missing_op_list)

  # Note: Only raise missing op ValueError after trying to load ops.
  # This allows the test to exercise all the calls into TensorFlow
  # without having to write a C + python test.
  if not missing_ops <= set(cpp_registry_ops.keys()):
    raise RuntimeError(
        "Graph ops missing from the python registry (%s) are also absent from "
        "the c++ registry."
        % missing_ops.difference(set(cpp_registry_ops.keys())))
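A hedged sketch of a typical call site (this helper comes from TensorFlow Hub's module-loading path; `graph_def` below stands for a GraphDef deserialized elsewhere and is an assumption of the example):

# `graph_def`: a GraphDef you have deserialized from a saved module (assumed).
# Gather every op name referenced by the graph and its function library,
# then make sure the Python registry knows all of them before importing.
graph_ops = {node.op for node in graph_def.node}
for func in graph_def.library.function:
  graph_ops.update(node.op for node in func.node_def)

register_ops_if_needed(graph_ops)
tf.import_graph_def(graph_def, name="module")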
Example 11: _stripped_op_list_for_graph
def _stripped_op_list_for_graph(graph_def):
  registered_ops = op_def_registry.get_registered_ops()
  used_ops = {n.op for n in graph_def.node}
  op_list = [registered_ops[op_name] for op_name in sorted(used_ops)]
  return op_def_pb2.OpList(op=op_list)
Example 12: import_graph_def
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly
      stripped) list of `OpDef`s used by the producer of the graph. If
      provided, attrs for ops in `graph_def` that are not in `op_dict` that
      have their default value according to `producer_op_list` will be
      removed. This will allow some more `GraphDef`s produced by later
      binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()
  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning
    # is that function names are similar to op definition names, which
    # currently do not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      g._add_function(f)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access
#......... the rest of this example's code is omitted here .........
Example 13: _get_op_def
def _get_op_def(op):
  return op.op_def or op_def_registry.get_registered_ops()[op.type]
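For illustration (a hypothetical call in TF 1.x graph mode, assuming the helper above is in scope):

import tensorflow as tf

x = tf.placeholder(tf.float32, name="x")
y = tf.add(x, 1.0, name="y")

op_def = _get_op_def(y.op)
print(op_def.name)                             # 'Add'
print([arg.name for arg in op_def.input_arg])  # ['x', 'y']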
Example 14: _get_op_def
def _get_op_def(op):
  # pylint: disable=protected-access
  if hasattr(op, "_sig"):
    return getattr(op, "_sig")
  else:
    return op_def_registry.get_registered_ops()[op.type]
Example 15: import_graph_def
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly
      stripped) list of `OpDef`s used by the producer of the graph. If
      provided, unrecognized attrs for ops in `graph_def` that have their
      default value according to `producer_op_list` will be removed. This will
      allow some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
#......... the rest of this example's code is omitted here .........