本文整理汇总了Python中tensorflow.python.eager.function.make_defun_op函数的典型用法代码示例。如果您正苦于以下问题：Python make_defun_op函数的具体用法？Python make_defun_op怎么用？Python make_defun_op使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了make_defun_op函数的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testDefunOpGraphModeNoneOutput
def testDefunOpGraphModeNoneOutput(self):
    """A defun op built from a function returning None has no outputs.

    Both the output metadata (dtypes, shapes) and the result of calling the
    op are expected to be exactly None.
    """

    def fn(unused_a, unused_b):
        return None

    x = constant_op.constant(1)
    fn_op = function.make_defun_op(fn, x, x)
    # assertIsNone is the idiomatic check for None and produces a clearer
    # failure message than assertEqual/assertAllEqual against None.
    self.assertIsNone(fn_op.output_dtypes)
    self.assertIsNone(fn_op.output_shapes)
    self.assertIsNone(fn_op(x, x))
示例2: validate
def validate(indexed_slice):
    """Check that `indexed_slice` round-trips through `function.defun`.

    NOTE(review): `self` is a free variable here, presumably the enclosing
    test case whose method defines this helper (not visible in this snippet).

    Args:
      indexed_slice: an `ops.IndexedSlices` value returned unchanged by the
        wrapped function.
    """

    def f():
        return indexed_slice

    output = function.defun(f)()
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(output, ops.IndexedSlices)
    self.assertAllEqual(indexed_slice.values, output.values)
    self.assertAllEqual(indexed_slice.indices, output.indices)
    self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)
    # The test expects make_defun_op to report the shape of the slices'
    # `values` tensor as the function's output shape.
    self.assertEqual(
        function.make_defun_op(f).output_shapes, indexed_slice.values.shape)
示例3: testBasicDefunOpGraphMode
def testBasicDefunOpGraphMode(self):
    """Squaring a 2x2 matrix via a defun op matches a direct matmul."""
    defun_matmul = function.defun(math_ops.matmul)

    def sq(a):
        return defun_matmul(a, a)

    mat = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    square_op = function.make_defun_op(sq, mat)
    expected_shape = tensor_shape.TensorShape([2, 2])
    self.assertEqual(square_op.output_shapes, expected_shape)
    result = square_op(mat)
    self.assertAllEqual(result, math_ops.matmul(mat, mat).numpy())
示例4: testDefunOpGraphModeWithGradients
def testDefunOpGraphModeWithGradients(self):
    """A defun op can wrap a gradient computation over a captured variable."""
    var = resource_variable_ops.ResourceVariable(1.0, name='v')

    def step():
        def inner():
            return var * var

        grads_and_vars = backprop.implicit_grad(inner)()
        # Presumably a (gradient, variable) pair; the test expects
        # d(v*v)/dv == 2.0 at v == 1.0 from its first element.
        first_pair = grads_and_vars[0]
        return first_pair[0]

    grad_op = function.make_defun_op(step)
    self.assertEqual(grad_op.output_dtypes, dtypes.float32)
    scalar_shape = tensor_shape.TensorShape([])
    self.assertEqual(grad_op.output_shapes, scalar_shape)
    self.assertAllEqual(grad_op(), 2.0)
示例5: testNestedInputsDefunOpGraphMode
def testNestedInputsDefunOpGraphMode(self):
    """A defun op accepts a nested input (a namedtuple of dicts)."""
    defun_matmul = function.defun(math_ops.matmul)
    pair = collections.namedtuple('pair', ['a', 'b'])

    def a_times_b(inputs):
        lhs = inputs.a['a']
        rhs = inputs.b['b']
        return defun_matmul(lhs, rhs)

    mat = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    nested = pair({'a': mat}, {'b': mat})
    product_op = function.make_defun_op(a_times_b, nested)
    self.assertEqual(product_op.output_shapes,
                     tensor_shape.TensorShape([2, 2]))
    result = product_op(nested)
    self.assertAllEqual(result, math_ops.matmul(mat, mat).numpy())
示例6: testNestedOutputDefunOpGraphMode
def testNestedOutputDefunOpGraphMode(self):
    """A defun op preserves a nested (tuple + dict) output structure."""
    defun_matmul = function.defun(math_ops.matmul)

    def sq(a):
        product = defun_matmul(a, a)
        extra = {'b': constant_op.constant(1.0)}
        return (product, extra)

    mat = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq_op = function.make_defun_op(sq, mat)

    expected_shapes = (tensor_shape.TensorShape([2, 2]),
                       {'b': tensor_shape.TensorShape([])})
    self.assertEqual(sq_op.output_shapes, expected_shapes)
    expected_dtypes = (dtypes.float32, {'b': dtypes.float32})
    self.assertEqual(sq_op.output_dtypes, expected_dtypes)

    matrix_out, dict_out = sq_op(mat)
    self.assertAllEqual(matrix_out, math_ops.matmul(mat, mat).numpy())
    self.assertAllEqual(dict_out['b'].numpy(), 1.0)
示例7: execute
def execute(self, fn, *args, **kwargs):
    """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

    Args:
      fn: The function to execute. Must return at least one tensor.
      *args: Additional positional arguments to `fn`. Each leaf is converted
        with `ops.convert_to_tensor` before use.
      **kwargs: Additional keyword arguments to `fn`.
        Several keywords are reserved for `execute`. These are:

        - name; The name to use when creating the execute operation.
        - exclusive_resource_access; Whether the resources required by
          `fn` should be exclusive to this `CriticalSection`. Default: `True`.
          You may want to set this to `False` if you will be accessing a
          resource in read-only mode in two different CriticalSections.

    Returns:
      The tensors returned from `fn(*args, **kwargs)`, packed into the same
      nested structure as `fn`'s outputs. If execution yields a single
      `Operation`, that `Operation` is returned directly.

    Raises:
      ValueError: If `fn` attempts to use this `CriticalSection` in any nested
        way (i.e. this section's handle appears among its inputs).
      ValueError: (graph mode only) If an execution in a *different*
        `CriticalSection` uses a resource that also appears among the
        flattened `*args` or the inputs captured by `fn`, and at least one of
        the two executions requested `exclusive_resource_access`. The error
        is skipped only when both executions were created with
        `exclusive_resource_access=False`.
    """
    name = kwargs.pop("name", None)
    exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)

    # Only positional arguments are converted here; the remaining kwargs are
    # forwarded to make_defun_op as-is.
    args = nest.map_structure(ops.convert_to_tensor, args)
    with ops.name_scope(name, "critical_section_execute", []):
        fn_op = function.make_defun_op(fn, *args, **kwargs)
        flat_dtypes = nest.flatten(fn_op.output_dtypes)
        flat_shapes = nest.flatten(fn_op.output_shapes)
        # Everything the op will read: explicit arguments plus the inputs
        # fn_op reports as captured.
        all_inputs = nest.flatten(args) + fn_op.captured_inputs
        if self._handle in all_inputs:
            raise ValueError("The function fn attempts to access the "
                             "CriticalSection in which it would be running. This "
                             "is illegal and would cause deadlocks. "
                             "CriticalSection: %s." % self._handle)

        if context.in_graph_mode():
            # Collections and op introspection does not work in eager
            # mode. This is generally ok; since eager mode (as of
            # writing) executes sequentially anyway.
            all_input_resources = [
                x for x in all_inputs if x.dtype == dtypes.resource]
            for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
                # inputs[0] appears to be the critical-section handle and
                # inputs[1:] the arguments (matching the op call below).
                if sg.op.inputs[0].name == self._handle.name:
                    # Other executions in the same critical section are allowed.
                    continue
                if not (exclusive_resource_access or sg.exclusive_resource_access):
                    # Neither execution requested exclusive access.
                    continue
                # A set gives O(1) membership tests in the loop below.
                sg_input_names = {y.name for y in sg.op.inputs[1:]}
                for res in all_input_resources:
                    if res.name in sg_input_names:
                        raise ValueError(
                            "This execution would access resource %s; but either this "
                            "execution (CriticalSection: %s) or Execution '%s' "
                            "(CriticalSection: %s) requested exclusive resource access "
                            "of this resource for their critical section. Did you mean "
                            "to call execute with keyword argument "
                            "exclusive_resource_access=False?"
                            % (res.name,
                               self.name,
                               sg.op.name,
                               sg.op.inputs[0].op.name))

        flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
            critical_section=self._handle,
            arguments=all_inputs,
            f=fn_op,
            output_types=flat_dtypes,
            output_shapes=flat_shapes)

        if context.in_graph_mode():
            # The generated op wrapper may return a bare Operation (presumably
            # when there are no tensor outputs); normalize to a list.
            if isinstance(flat_outputs, ops.Operation):
                flat_outputs = [flat_outputs]
            op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor)
                  else flat_outputs[0])
            # Record this execution so later calls to execute() can detect
            # conflicting resource access.
            signature = _ExecutionSignature(
                op=op,
                exclusive_resource_access=exclusive_resource_access)
            ops.add_to_collections(
                CRITICAL_SECTION_EXECUTIONS, signature)

        # A lone Operation is returned unwrapped; otherwise restore the
        # nested structure of fn's outputs.
        return (flat_outputs[0]
                if (len(flat_outputs) == 1
                    and isinstance(flat_outputs[0], ops.Operation))
                else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))