本文整理汇总了Python中tensorflow.python.framework.graph_to_function_def.graph_to_function_def函数的典型用法代码示例。如果您正苦于以下问题：Python graph_to_function_def函数的具体用法？Python graph_to_function_def怎么用？Python graph_to_function_def使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了graph_to_function_def函数的9个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _compute_backprop
def _compute_backprop(self):
    """Build and register the forward and backward function definitions.

    Sets ``self._has_backprop``, differentiates the non-None entries of
    ``self._returns`` with respect to ``self._input_placeholders`` inside a
    ``_CapturingContext``, converts the forward and backward subgraphs to
    function defs, registers both under names derived from
    ``self._func_name``, and stores the resulting ``_GraphModeFunction`` in
    ``self._backward_function``.
    """
    self._has_backprop = True
    with self._graph.as_default(), context.graph_mode():
        capture_ctx = _CapturingContext()
        with capture_ctx:
            # Only non-None returns take part in differentiation.
            live_outputs = [out for out in self._returns if out is not None]
            # One placeholder per live output (same dtype and shape); these
            # are passed to the gradient computation as grad_ys.
            self._out_grad_placeholders = [
                graph_placeholder(out.dtype, out.shape)
                for out in live_outputs
            ]
            in_gradients = gradients_impl.gradients(
                live_outputs,
                self._input_placeholders,
                grad_ys=self._out_grad_placeholders)
            # Gradients that came back as None are dropped from the backward
            # function's outputs and from the recorded shapes.
            backward_outputs = [g for g in in_gradients if g is not None]
            shapes = [g.shape for g in backward_outputs]

    # Captured tensors, ordered by tensor name.
    captures = sorted(capture_ctx.captured_tensors, key=lambda t: t.name)

    # The forward function additionally returns the captured tensors so the
    # backward function can take them as inputs.
    forward_function_def = graph_to_function_def.graph_to_function_def(
        self._graph,
        self._ops,
        self._input_placeholders,
        live_outputs + captures)
    self._forward_fdef = _DefinedFunction(forward_function_def)
    _register_with_name(_forward_name(self._func_name), forward_function_def)

    all_inputs = self._out_grad_placeholders + captures
    # Backward body: the gradient placeholders' ops followed by the ops seen
    # by the capturing context, ordered by op name.
    backward_ops = [p.op for p in self._out_grad_placeholders]
    backward_ops.extend(sorted(capture_ctx.known_ops, key=lambda op: op.name))
    backward_function_def = graph_to_function_def.graph_to_function_def(
        self._graph, backward_ops, all_inputs, backward_outputs)
    _register_with_name(_backward_name(self._func_name), backward_function_def)

    self._backward_function = _GraphModeFunction(
        all_inputs,
        [],
        backward_function_def,
        self._graph,
        capture_ctx.known_ops,
        in_gradients,
        _map_sequence_obj_to_idx(backward_outputs),
        shapes)
示例2: _create_definition_if_needed
def _create_definition_if_needed(self):
    """Create ``self._definition`` from ``self._func`` unless it already exists.

    Traces ``self._func`` in a temporary ``_FuncGraph`` using one placeholder
    per entry of ``self._args``, converts the traced graph to a function
    definition, copies the extra keyword arguments onto it as attrs, and
    fills in the hash string, name and description.

    Raises:
        ValueError: if any value returned by ``self._func`` is None.
    """
    if self._definition is not None:
        return

    # Trace the Python function inside a scratch graph.
    temp_graph = _FuncGraph()
    with temp_graph.as_default():
        # One placeholder per declared (name, dtype) argument.
        inputs = [
            array_ops.placeholder(argtype, name=argname)
            for (argname, argtype) in self._args
        ]
        # Call func and gather the output tensors.
        with vs.variable_scope("", custom_getter=temp_graph.getvar):
            outputs = self._func(*inputs)
        # A single return value is wrapped so the code below always sees a
        # sequence.
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs,)
        if any(out is None for out in outputs):
            raise ValueError("Function can not return None.")
        # Ensure every output is a Tensor.
        outputs = [ops.convert_to_tensor(out) for out in outputs]

    self._extra_inputs = temp_graph.extra_inputs
    inputs.extend(temp_graph.extra_args)
    self._sub_functions = temp_graph._functions  # pylint: disable=protected-access

    # Build the FunctionDef.
    self._definition = graph_to_function_def.graph_to_function_def(
        temp_graph,
        temp_graph.get_operations(),
        inputs,
        outputs,
        out_names=self._out_names)

    # Extra kwargs are treated as attrs on the function def.
    sig_pre_func_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(
        sig_pre_func_name, **self._extra_kwargs)
    for attr_name in kwargs_attr:
        self._definition.attr[attr_name].CopyFrom(kwargs_attr[attr_name])

    # Hash the definition and its dependencies.
    signature = self._definition.signature
    self._hash_str = self._create_hash_str(
        signature.input_arg, signature.output_arg, self._definition.node_def)

    # Pick the final function name: when none was given, derive one from the
    # Python function's name plus the hash computed above.
    if not self._func_name:
        self._func_name = "_".join(
            [_get_func_name(self._func), self._hash_str])
    self._definition.signature.name = self._func_name
    if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__
示例3: testTwoInputsSameOp
def testTwoInputsSameOp(self):
    """Three outputs of a single op used as inputs give three input args."""
    with ops.Graph().as_default() as g:
        m = array_ops.placeholder(dtypes.float32)
        s, u, v = linalg_ops.svd(m)
        # Reduce each SVD output in turn, then add the three scalars.
        totals = [math_ops.reduce_sum(part) for part in (s, u, v)]
        result = totals[0] + totals[1] + totals[2]

    # The first operation is the placeholder; it is left out of the body.
    body_ops = g.get_operations()[1:]
    f = graph_to_function_def.graph_to_function_def(
        g, body_ops, [s, u, v], [result])
    self.assertEqual(len(f.signature.input_arg), 3)
示例4: _defun_internal
def _defun_internal(name, func, args, kwds):
    """Trace ``func`` into a fresh graph and wrap it as a graph-mode function.

    Args:
        name: base name; the inference function def is registered under
            ``_inference_name(name)``.
        func: Python callable to trace.
        args: positional arguments, turned into graph inputs by
            ``_get_defun_inputs``.
        kwds: keyword arguments passed through to ``func`` unchanged.

    Returns:
        A ``_GraphModeFunction`` built from the traced graph.
    """
    with context.graph_mode():
        tmp_graph = ops.Graph()
        # Copy the contents of every collection of the current default graph
        # into the temporary graph, so that things like summaries and the
        # global step are visible while tracing. Slice assignment copies the
        # items into the temporary graph's own list.
        outer_graph = ops.get_default_graph()
        for key in outer_graph.collections:
            tmp_graph.get_collection_ref(key)[:] = outer_graph.get_collection(
                key)

        with tmp_graph.as_default():
            func_inputs = _get_defun_inputs(args)

            captures = {}
            with capture_tensors(captures):
                func_outputs = func(*func_inputs, **kwds)

            # Walk captures in sorted key order; each value is unpacked as an
            # (outer input, placeholder) pair.
            capture_ids = sorted(captures)
            if capture_ids:
                extra_inputs, extra_placeholders = zip(
                    *[captures[cid] for cid in capture_ids])
            else:
                extra_inputs, extra_placeholders = [], []

            outputs_list = nest.flatten(func_outputs)
            # None outputs are excluded from the function def and its shapes.
            func_def_outputs = [t for t in outputs_list if t is not None]
            output_shapes = [t.shape for t in func_def_outputs]

    # Function inputs: tensor-valued arguments, then capture placeholders.
    all_inputs = [
        t for t in nest.flatten(func_inputs) if isinstance(t, ops.Tensor)
    ]
    all_inputs = all_inputs + list(extra_placeholders)

    inference_function_def = graph_to_function_def.graph_to_function_def(
        tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)

    # Register any other functions defined in the graph.
    # TODO(ashankar): Oh lord, forgive me for this lint travesty.
    # TODO(ashankar): What about the gradient registry?
    for nested in tmp_graph._functions.values():  # pylint: disable=protected-access
        _register_with_name(nested.name, nested.definition)
    _register_with_name(_inference_name(name), inference_function_def)

    return _GraphModeFunction(
        all_inputs,
        extra_inputs,
        inference_function_def,
        tmp_graph,
        tmp_graph.get_operations(),
        func_outputs,
        _map_sequence_obj_to_idx(func_def_outputs),
        output_shapes)
示例5: _build_function_def
def _build_function_def(self):
    """Return a function def computing x^2 + y^2 and x^3 + y^3.

    The graph has two float32 placeholders named "x" and "y" and two add_n
    outputs named "sum_squares" and "sum_cubes". The returned definition's
    signature name is set to "_whats_in_a_name".
    """
    with ops.Graph().as_default() as g:
        # Inputs
        x = array_ops.placeholder(dtypes.float32, name="x")
        y = array_ops.placeholder(dtypes.float32, name="y")

        # Outputs (ops are created in the same order as a nested call would).
        squares = [math_ops.pow(x, 2), math_ops.pow(y, 2)]
        sum_squares = math_ops.add_n(squares, name="sum_squares")
        cubes = [math_ops.pow(x, 3), math_ops.pow(y, 3)]
        sum_cubes = math_ops.add_n(cubes, name="sum_cubes")

    fn_inputs = [x, y]
    fn_outputs = [sum_squares, sum_cubes]
    fdef = graph_to_function_def.graph_to_function_def(
        g, g.get_operations(), fn_inputs, fn_outputs)
    fdef.signature.name = "_whats_in_a_name"
    return fdef
示例6: make_function_def
def make_function_def(graph, operations, inputs, outputs):
    """Build a function def in which uses of each resource tensor are chained.

    Walks ``operations`` in order. Whenever an op consumes a tensor of dtype
    ``dtypes.resource`` that an earlier op in the walk also consumed, a
    control input on that earlier op is added to the current op. Note that
    this modifies the ops in ``operations``.

    Args:
        graph: graph passed through to ``graph_to_function_def``.
        operations: ops forming the function body, in the order to serialize.
        inputs: input tensors of the function.
        outputs: output tensors of the function.

    Returns:
        The result of ``graph_to_function_def.graph_to_function_def``.
    """
    # Most recent op seen consuming each resource tensor, keyed by its name.
    latest_user = {}

    # TODO(apassos) probably control flow has to be handled delicately here as
    # in if a resource is accessed inside a control flow context we need the
    # control dependency to point to something outside the context which is
    # guaranteed to happen after the access.
    #
    # TODO(apassos) this should do some form of alias analysis as ops which
    # forward the resources such as Identity and Switch can cause
    # serialization to fail.
    for op in operations:
        for tensor in op.inputs:
            if tensor.dtype == dtypes.resource:
                key = tensor.name
                earlier_op = latest_user.get(key)
                if earlier_op is not None:
                    op._add_control_input(earlier_op)  # pylint: disable=protected-access
                latest_user[key] = op

    return graph_to_function_def.graph_to_function_def(
        graph, operations, inputs, outputs)
示例7: _create_definition_if_needed_impl
def _create_definition_if_needed_impl(self):
    """Build this function's definition from ``self._func``, at most once.

    Returns immediately when ``self._definition`` or ``self._c_func`` is
    already set. Otherwise traces ``self._func`` through
    ``func_graph_from_py_func`` and converts the traced graph either into a
    C function handle (``self._c_func``) when the graph has a ``_c_graph``,
    or into a Python function definition (``self._definition``) when it
    does not. Also sets ``_extra_inputs``, ``_sub_functions``, ``_op_def``,
    ``_func_name`` (when it was unset) and ``_stateful_ops``.
    """
    if not (self._definition is None and self._c_func is None):
        return

    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device)

    self._extra_inputs = temp_graph.extra_inputs
    self._sub_functions = temp_graph._functions  # pylint: disable=protected-access

    # Extra kwargs are treated as attrs on the function def. The base name is
    # the explicit name if given; otherwise it is derived from the Python
    # function, with the gradient function's name appended when there is one.
    base_func_name = self._func_name
    if not base_func_name:
        base_func_name = _get_func_name(self._func)
        if self._grad_func:
            base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    c_graph = temp_graph._c_graph  # pylint: disable=protected-access
    if c_graph:
        # C API is enabled.
        if self._out_names:
            output_names = [compat.as_bytes(n) for n in self._out_names]
        else:
            output_names = []
        description = self._func.__doc__ or None
        # pylint: disable=protected-access
        c_func = c_api.TF_GraphToFunction_wrapper(
            c_graph,
            base_func_name,
            self._func_name is None,  # append_hash_to_fn_name
            None,  # opers
            [t._as_tf_output() for t in temp_graph.inputs],
            [t._as_tf_output() for t in temp_graph.outputs],
            output_names,
            None,  # opts
            description)
        # pylint: enable=protected-access
        self._c_func = c_api_util.ScopedTFFunction(c_func)
        self._set_c_attrs(kwargs_attr)

        # Set cached fields: _op_def and _func_name (if not already set).
        self._op_def = self.definition.signature
        if self._func_name:
            assert self._func_name == self._op_def.name
        else:
            self._func_name = compat.as_str(self._op_def.name)
    else:
        # Build the FunctionDef in Python.
        self._definition = graph_to_function_def.graph_to_function_def(
            temp_graph,
            temp_graph.get_operations(),
            temp_graph.inputs,
            temp_graph.outputs,
            out_names=self._out_names)

        for attr_name in kwargs_attr:
            self._definition.attr[attr_name].CopyFrom(kwargs_attr[attr_name])

        # Hash the definition and its dependencies.
        signature = self._definition.signature
        self._hash_str = self._create_hash_str(
            signature.input_arg,
            signature.output_arg,
            self._definition.node_def)

        # Pick the final function name: when none was given, join the base
        # name with the hash computed above.
        if not self._func_name:
            self._func_name = "_".join([base_func_name, self._hash_str])
        self._definition.signature.name = self._func_name
        if self._func.__doc__:
            self._definition.signature.description = self._func.__doc__

        self._op_def = self._definition.signature

    # Record (name, type) for every op whose op_def is marked stateful.
    self._stateful_ops = [
        (op.name, op.type)
        for op in temp_graph.get_operations()
        if op.op_def.is_stateful
    ]
示例8: _create_definition_if_needed_impl
def _create_definition_if_needed_impl(self):
    """Trace ``self._func`` and build its function definition, at most once.

    Returns immediately when ``self._definition`` or ``self._c_func`` is
    already set. Otherwise traces ``self._func`` in a temporary
    ``_FuncGraph`` with one placeholder per entry of ``self._args`` and
    converts the result into a Python function definition, or into a C
    function when the temporary graph has a ``_c_graph``.

    A return value of None from ``self._func`` is treated as "no outputs".

    Raises:
        ValueError: if ``self._func`` returns a sequence containing None.
    """
    already_built = self._definition is not None or self._c_func is not None
    if already_built:
        return

    # Trace the Python function inside a scratch graph.
    temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
    with temp_graph.as_default():
        # One placeholder per declared (name, dtype) argument.
        inputs = [
            array_ops.placeholder(argtype, name=argname)
            for (argname, argtype) in self._args
        ]
        # Call func and gather the output tensors.
        with vs.variable_scope("", custom_getter=temp_graph.getvar):
            outputs = self._func(*inputs)

        # Python cannot distinguish a function that returns nothing from one
        # that explicitly returns None. The former must be allowed; the latter
        # is most likely a user error, but for now a bare None is accepted
        # and interpreted as a function with no output.
        # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
        # allow users to explicitly mark the function as not returning
        # anything.
        if outputs is None:
            outputs = []
        elif not isinstance(outputs, (list, tuple)):
            # A single return value is wrapped into a one-element tuple.
            outputs = (outputs,)
        if any(out is None for out in outputs):
            raise ValueError("Function can not return None.")
        # Ensure every output is a Tensor.
        outputs = [ops.convert_to_tensor(out) for out in outputs]

    self._extra_inputs = temp_graph.extra_inputs
    inputs.extend(temp_graph.extra_args)
    self._sub_functions = temp_graph._functions  # pylint: disable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    base_func_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
        # Build the FunctionDef in Python.
        self._definition = graph_to_function_def.graph_to_function_def(
            temp_graph,
            temp_graph.get_operations(),
            inputs,
            outputs,
            out_names=self._out_names)

        for attr_name in kwargs_attr:
            self._definition.attr[attr_name].CopyFrom(kwargs_attr[attr_name])

        # Hash the definition and its dependencies.
        signature = self._definition.signature
        self._hash_str = self._create_hash_str(
            signature.input_arg,
            signature.output_arg,
            self._definition.node_def)

        # Pick the final function name: when none was given, join the base
        # name with the hash computed above.
        if not self._func_name:
            self._func_name = "_".join([base_func_name, self._hash_str])
        self._definition.signature.name = self._func_name
        if self._func.__doc__:
            self._definition.signature.description = self._func.__doc__

        self._op_def = self._definition.signature
    else:
        # C API is enabled.
        if self._out_names:
            output_names = [compat.as_bytes(n) for n in self._out_names]
        else:
            output_names = []
        description = self._func.__doc__ or None
        # pylint: disable=protected-access
        with errors.raise_exception_on_not_ok_status() as status:
            self._c_func = c_api.TF_GraphToFunction_wrapper(
                temp_graph._c_graph,
                base_func_name,
                self._func_name is None,  # append_hash_to_fn_name
                None,  # opers
                [t._as_tf_output() for t in inputs],
                [t._as_tf_output() for t in outputs],
                output_names,
                None,  # opts
                description,
                status)
        # pylint: enable=protected-access
        self._set_c_attrs(kwargs_attr)

        # Set cached fields: _op_def and _func_name (if not already set).
        self._op_def = self.definition.signature
        if self._func_name:
            assert self._func_name == self._op_def.name
        else:
            self._func_name = compat.as_str(self._op_def.name)
示例9: _create_definition_if_needed_impl
def _create_definition_if_needed_impl(self):
"""This is not what you want, see _create_definition_if_needed."""
if self._definition is not None or self._c_func is not None:
return
# Copy variable collections (by reference) from the parent graph such that
# name based variable sharing (e.g. via tf.make_template) works between the
# func graph and parent graph.
variable_keys = []
variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS) # pylint: disable=protected-access
variable_keys.append(vs._VARSTORE_KEY) # pylint: disable=protected-access
collections_ref = {}
parent_collections_ref = ops.get_default_graph()._collections # pylint: disable=protected-access
for key in variable_keys:
if key not in parent_collections_ref:
parent_collections_ref[key] = collections_ref[key] = []
else:
collections_ref[key] = parent_collections_ref[key]
temp_graph = func_graph_from_py_func(
self._func,
self._arg_names,
self._arg_types,
self._func_name,
self._capture_by_value,
self._caller_device,
collections_ref=collections_ref,
whitelisted_stateful_ops=self._whitelisted_stateful_ops,
capture_resource_var_by_value=self._capture_resource_var_by_value)
self._extra_inputs = temp_graph.extra_inputs
# pylint: disable=protected-access
self._sub_functions = temp_graph._functions
# pylint: enable=protected-access
# Extra kwargs are treated as attrs on the function def.
if self._func_name:
base_func_name = self._func_name
else:
base_func_name = function_utils.get_func_name(self._func)
if self._grad_func:
base_func_name += ("_%s" % self._grad_func.name)
kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)
if not temp_graph._c_graph: # pylint: disable=protected-access
# Build the FunctionDef
self._definition = graph_to_function_def.graph_to_function_def(
temp_graph,
temp_graph.get_operations(),
temp_graph.inputs,
temp_graph.outputs,
out_names=self._out_names)
for k in kwargs_attr:
self._definition.attr[k].CopyFrom(kwargs_attr[k])
# Hash the definition and its dependencies.
self._hash_str = self._create_hash_str(
self._definition.signature.input_arg,
self._definition.signature.output_arg, self._definition.node_def)
# Finally, we decide the function name to use. If not specified,
# make up something which is almost certainly unique (but deterministic).
if not self._func_name:
self._func_name = "_".join([base_func_name, self._hash_str])
self._definition.signature.name = self._func_name
if self._func.__doc__:
self._definition.signature.description = self._func.__doc__
self._op_def = self._definition.signature
else: # C API is enabled
output_names = ([compat.as_bytes(x) for x in self._out_names]
if self._out_names else [])
description = self._func.__doc__ or None
# pylint: disable=protected-access
c_func = c_api.TF_GraphToFunction_wrapper(
temp_graph._c_graph,
base_func_name,
self._func_name is None, # append_hash_to_fn_name
None, # opers
[t._as_tf_output() for t in temp_graph.inputs],
[t._as_tf_output() for t in temp_graph.outputs],
output_names,
[], # control_outputs
[], # control_output_names
None, # opts
description)
self._c_func = c_api_util.ScopedTFFunction(c_func)
# pylint: enable=protected-access
self._set_c_attrs(kwargs_attr)
# Set cached fields: _op_def and _func_name (if not already set)
self._op_def = self.definition.signature
if self._func_name:
assert self._func_name == self._op_def.name
else:
self._func_name = compat.as_str(self._op_def.name)
self._stateful_ops = [(op.name, op.type)
#.........这里部分代码省略.........