本文整理汇总了Python中tensorflow.python.framework.op_def_registry.get_registered_ops方法的典型用法代码示例。如果您正苦于以下问题：Python op_def_registry.get_registered_ops方法的具体用法？Python op_def_registry.get_registered_ops怎么用？Python op_def_registry.get_registered_ops使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.framework.op_def_registry的用法示例。
在下文中一共展示了op_def_registry.get_registered_ops方法的9个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _get_ref_args
# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def _get_ref_args(self, node):
    """Return the tensor names of `node`'s ref-type outputs.

    Presumably callers match these names against other nodes' inputs to decide
    whether such an input is ref-typed; this method itself only inspects the
    *output* args of `node`'s registered `OpDef`.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of tensor names (strs), one per output arg whose `is_ref` flag is
      set: `node.name` for output arg 0 and `"<node.name>:<i>"` for output
      arg i > 0. Empty when no `OpDef` is found for `node.op` in the Python op
      registry.
    """
    op_def = op_def_registry.get_registered_ops().get(node.op)
    if not op_def:
        return []
    # NOTE(review): `i` enumerates output *args*; it equals the output tensor
    # index only if no earlier output arg expands to several tensors
    # (`number_attr` / `type_list_attr`). Confirm for the ops this sees.
    return [
        node.name if i == 0 else "%s:%d" % (node.name, i)
        for i, output_arg in enumerate(op_def.output_arg)
        if output_arg.is_ref
    ]
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:19,代码来源:debug_graphs.py
示例2: stripped_op_list_for_graph
# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def stripped_op_list_for_graph(graph_def):
    """Return an `OpList` holding the registered `OpDef`s a graph uses.

    The result is the value of the `stripped_op_list` field of `MetaGraphDef`
    and similar protos. A producer can ship it to a consumer, which may then
    call the C++ `RemoveNewDefaultAttrsFromGraphDef` to improve forwards
    compatibility.

    Args:
      graph_def: A `GraphDef` proto, e.g. from `graph.as_graph_def()`.

    Returns:
      An `OpList` with the `OpDef` of every op used by the graph that is found
      in the Python op registry, in sorted op-name order.

    Raises:
      ValueError: If the graph uses an op that is neither registered nor one of
        the internal function ops `_Arg`, `_Retval`, `_ListToArray`,
        `_ArrayToList`.
    """
    # Python counterpart of the C++ StrippedOpListForGraph. The Python and C++
    # op registries can differ, so the logic is duplicated rather than wrapped.
    # TODO(irving): Support taking graphs directly.
    ops_in_graph = ops_used_by_graph_def(graph_def)
    known_ops = op_def_registry.get_registered_ops()

    # Internal ops used inside function bodies are never registered, so they
    # are exempt from the registration check. TODO(irving): Do something better.
    exempt_ops = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    for op_name in ops_in_graph:
        if op_name in known_ops or op_name in exempt_ops:
            continue
        raise ValueError(
            "Op %s is used by the graph, but is not registered" % op_name)

    stripped_defs = [
        known_ops[op_name]
        for op_name in sorted(ops_in_graph)
        if op_name in known_ops
    ]
    return op_def_pb2.OpList(op=stripped_defs)
示例3: _get_op_def
# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def _get_op_def(op):
return op.op_def or op_def_registry.get_registered_ops()[op.type]
示例4: __call__
# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def __call__(self, inputs, state, scope=None):
    """Invoke the wrapped cell inside `jit.experimental_jit_scope`.

    When `self._compile_stateful` is true, `compile_ops=True` is passed to the
    scope. Otherwise a predicate is passed that returns True only for ops whose
    registered `OpDef` is not stateful (presumably so stateful ops are kept out
    of JIT compilation -- confirm against the jit scope documentation).

    Args:
      inputs: Forwarded unchanged to the wrapped cell.
      state: Forwarded unchanged to the wrapped cell.
      scope: Forwarded to the wrapped cell as the `scope` keyword argument.

    Returns:
      Whatever the wrapped cell returns.
    """
    if self._compile_stateful:
        compile_ops = True
    else:
        def compile_ops(node_def):
            # Module-level cache: the registry is fetched on first use only.
            global _REGISTERED_OPS
            if _REGISTERED_OPS is None:
                _REGISTERED_OPS = op_def_registry.get_registered_ops()
            # NOTE(review): raises KeyError for an op missing from the Python
            # registry.
            return not _REGISTERED_OPS[node_def.op].is_stateful

    with jit.experimental_jit_scope(compile_ops=compile_ops):
        # Pass `scope` by keyword so it binds to the cell's `scope` parameter
        # regardless of that parameter's position in the cell's signature.
        return self._cell(inputs, state, scope=scope)
示例5: _get_op_def
# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def _get_op_def(op):
# pylint: disable=protected-access
if hasattr(op, "_sig"):
return getattr(op, "_sig")
else:
return op_def_registry.get_registered_ops()[op.type]
# pylint: enable=protected-access
示例6: __call__
# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell under `jit.experimental_jit_scope`.

    The scope's `compile_ops` argument is `True` when `self._compile_stateful`
    is set; otherwise it is a predicate that is True only for nodes whose
    registered `OpDef` is not stateful.

    Args:
      inputs: Forwarded unchanged to the wrapped cell.
      state: Forwarded unchanged to the wrapped cell.
      scope: Forwarded to the wrapped cell as the `scope` keyword argument.

    Returns:
      Whatever the wrapped cell returns.
    """
    def _is_stateless(node_def):
        # The registry is cached in a module-level global after first use.
        global _REGISTERED_OPS
        if _REGISTERED_OPS is None:
            _REGISTERED_OPS = op_def_registry.get_registered_ops()
        return not _REGISTERED_OPS[node_def.op].is_stateful

    compile_ops = True if self._compile_stateful else _is_stateless
    with jit.experimental_jit_scope(compile_ops=compile_ops):
        return self._cell(inputs, state, scope=scope)
示例7: get
# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def get(name):
    """Look up the registered `OpDef` called `name`.

    Returns whatever the registry's `.get(name)` yields, i.e. `None` for an
    unregistered name (assuming the registry is a plain dict).
    """
    return op_def_registry.get_registered_ops().get(name)
示例8: sync
# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def sync():
    """Copy every `OpDef` from the C++ op list into the Python op registry.

    Each definition is passed through `_remove_non_deprecated_descriptions`
    and then stored under its name, replacing any existing registry entry.
    """
    op_list_buffer = c_api.TF_GetAllOpList()
    serialized_ops = c_api.TF_GetBuffer(op_list_buffer)
    cpp_ops = op_def_pb2.OpList()
    cpp_ops.ParseFromString(serialized_ops)
    # NOTE(review): `op_list_buffer` is never explicitly released here; confirm
    # whether the c_api wrapper owns it or a `TF_DeleteBuffer` call is needed.

    python_registry = op_def_registry.get_registered_ops()
    for cpp_def in cpp_ops.op:
        # An OpList registered from a gen_*_ops.py does not have any
        # descriptions, so strip them here as well to satisfy validation in
        # register_op_list.
        _remove_non_deprecated_descriptions(cpp_def)
        python_registry[cpp_def.name] = cpp_def
示例9: _add_op_node
# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def _add_op_node(op, func):
    """Convert `op` to a function-def node and append it to `func`.

    The new `FunctionDef.Node` gets:
      * `op`: `op.type`;
      * `ret`: one entry per output arg of the op's `OpDef`, or a single name
        derived from `op.name` when the op has no output args;
      * `arg`: entries for the op's inputs, walked per input arg;
      * `dep`: names derived from `op.control_inputs`;
      * `attr`: a copy of every entry returned by `_get_node_def_attr(op)`.

    Args:
      op: The graph op to convert. If it has a `_sig` attribute, that is used
        as its `OpDef`; otherwise the `OpDef` is looked up in the Python op
        registry by `op.type`.
      func: Function proto (presumably a `FunctionDef`) whose `node` list
        receives the new node. It is also handed to the `_add_output_array`,
        `_add_output_list` and `_add_input_array` helpers.
    """
    node = function_pb2.FunctionDef.Node()
    node.op = op.type
    # pylint: disable=protected-access
    if hasattr(op, "_sig"):
        op_def = getattr(op, "_sig")
    else:
        op_def = op_def_registry.get_registered_ops()[op.type]
    # pylint: enable=protected-access
    # Fetched once and reused both for sizing args and for the final attr copy.
    attrs = _get_node_def_attr(op)

    if not op_def.output_arg:
        node.ret.append(_make_argname_from_tensor_name(op.name))
    else:
        out_index = 0
        for arg_def in op_def.output_arg:
            if arg_def.number_attr:
                # `num` tensors sharing one dtype: either the arg's fixed
                # `type` or the value of its `type_attr`.
                dtype = arg_def.type or attrs[arg_def.type_attr].type
                num = attrs[arg_def.number_attr].i
                node.ret.append(
                    _add_output_array(op, out_index, out_index + num, dtype, func))
                out_index += num
            elif arg_def.type_list_attr:
                # One tensor per dtype listed in the `type_list_attr`.
                dtype_lst = attrs[arg_def.type_list_attr].list.type
                num = len(dtype_lst)
                node.ret.append(
                    _add_output_list(op, out_index, out_index + num, dtype_lst, func))
                out_index += num
            else:
                node.ret.append(
                    _make_argname_from_tensor_name(op.outputs[out_index].name))
                out_index += 1

    inp_index = 0
    for arg_def in op_def.input_arg:
        if arg_def.number_attr:
            dtype = arg_def.type or attrs[arg_def.type_attr].type
            num = attrs[arg_def.number_attr].i
            node.arg.append(
                _add_input_array(op, inp_index, inp_index + num, dtype, func))
            inp_index += num
        elif arg_def.type_list_attr:
            num = len(attrs[arg_def.type_list_attr].list.type)
            node.arg.extend([
                _make_argname_from_tensor_name(op.inputs[i].name)
                for i in range(inp_index, inp_index + num)
            ])
            inp_index += num
        else:
            node.arg.append(_make_argname_from_tensor_name(op.inputs[inp_index].name))
            inp_index += 1

    node.dep.extend(
        [_make_argname_from_tensor_name(x.name) for x in op.control_inputs])
    for k, v in attrs.items():
        node.attr[k].CopyFrom(v)
    func.node.extend([node])