当前位置: 首页>>代码示例>>Python>>正文


Python function.make_defun_op函数代码示例

本文整理汇总了Python中tensorflow.python.eager.function.make_defun_op函数的典型用法代码示例。如果您正苦于以下问题：Python make_defun_op函数的具体用法？Python make_defun_op怎么用？Python make_defun_op使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了make_defun_op函数的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testDefunOpGraphModeNoneOutput

  def testDefunOpGraphModeNoneOutput(self):
    """A defun op built from a function that returns None yields None."""
    def fn(unused_a, unused_b):
      return None

    one = constant_op.constant(1)
    defun_op = function.make_defun_op(fn, one, one)

    # Both output descriptors are checked in the same order as before:
    # dtypes first, then shapes.
    for descriptor in ('output_dtypes', 'output_shapes'):
      self.assertEqual(getattr(defun_op, descriptor), None)

    result = defun_op(one, one)
    self.assertAllEqual(result, None)
开发者ID:StephenOman,项目名称:tensorflow,代码行数:10,代码来源:function_test.py

示例2: validate

    def validate(indexed_slice):
      """Check that `indexed_slice` survives a round trip through defun.

      Wraps a closure returning `indexed_slice` with `function.defun`, calls
      it, and asserts that the result is an `ops.IndexedSlices` whose
      `values`, `indices` and `dense_shape` match the input. Also asserts that
      the `output_shapes` of the op built by `function.make_defun_op` equal
      the shape of the input's `values`.

      Args:
        indexed_slice: The object to round-trip; it is read through its
          `values`, `indices` and `dense_shape` attributes.
      """
      def f():
        return indexed_slice

      output = function.defun(f)()
      # assertIsInstance reports the actual type on failure, unlike
      # assertTrue(isinstance(...)), which only says "False is not true".
      self.assertIsInstance(output, ops.IndexedSlices)
      self.assertAllEqual(indexed_slice.values, output.values)
      self.assertAllEqual(indexed_slice.indices, output.indices)
      self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)

      self.assertEqual(
          function.make_defun_op(f).output_shapes, indexed_slice.values.shape)
开发者ID:StephenOman,项目名称:tensorflow,代码行数:12,代码来源:function_test.py

示例3: testBasicDefunOpGraphMode

  def testBasicDefunOpGraphMode(self):
    """A defun op that squares a 2x2 matrix matches a direct matmul."""
    matmul = function.defun(math_ops.matmul)

    def sq(a):
      return matmul(a, a)

    matrix = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    square_op = function.make_defun_op(sq, matrix)

    self.assertEqual(square_op.output_shapes, tensor_shape.TensorShape([2, 2]))

    # Run the defun op first, then compute the reference product.
    actual = square_op(matrix)
    expected = math_ops.matmul(matrix, matrix).numpy()
    self.assertAllEqual(actual, expected)
开发者ID:StephenOman,项目名称:tensorflow,代码行数:13,代码来源:function_test.py

示例4: testDefunOpGraphModeWithGradients

  def testDefunOpGraphModeWithGradients(self):
    """A defun op wrapping an implicit-gradient step returns a scalar 2.0."""
    v = resource_variable_ops.ResourceVariable(1.0, name='v')

    def step():
      def inner():
        return v * v

      grad_fn = backprop.implicit_grad(inner)
      first_entry = grad_fn()[0]
      return first_entry[0]

    step_op = function.make_defun_op(step)

    expected_dtype = dtypes.float32
    self.assertEqual(step_op.output_dtypes, expected_dtype)
    self.assertEqual(step_op.output_shapes, tensor_shape.TensorShape([]))

    result = step_op()
    self.assertAllEqual(result, 2.0)
开发者ID:StephenOman,项目名称:tensorflow,代码行数:14,代码来源:function_test.py

示例5: testNestedInputsDefunOpGraphMode

  def testNestedInputsDefunOpGraphMode(self):
    """A defun op accepts a namedtuple of dicts as its nested input."""
    matmul = function.defun(math_ops.matmul)
    pair = collections.namedtuple('pair', ['a', 'b'])

    def a_times_b(inputs):
      lhs = inputs.a['a']
      rhs = inputs.b['b']
      return matmul(lhs, rhs)

    matrix = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    nested_inputs = pair(a={'a': matrix}, b={'b': matrix})
    product_op = function.make_defun_op(a_times_b, nested_inputs)

    self.assertEqual(
        product_op.output_shapes, tensor_shape.TensorShape([2, 2]))

    # Run the defun op first, then compute the reference product.
    actual = product_op(nested_inputs)
    expected = math_ops.matmul(matrix, matrix).numpy()
    self.assertAllEqual(actual, expected)
开发者ID:dananjayamahesh,项目名称:tensorflow,代码行数:15,代码来源:function_test.py

示例6: testNestedOutputDefunOpGraphMode

  def testNestedOutputDefunOpGraphMode(self):
    """A defun op preserves a (tensor, dict) output structure."""
    matmul = function.defun(math_ops.matmul)

    def sq(a):
      product = matmul(a, a)
      extras = {'b': constant_op.constant(1.0)}
      return (product, extras)

    matrix = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    nested_op = function.make_defun_op(sq, matrix)

    # Shapes and dtypes must mirror the nesting of the function's return.
    self.assertEqual(
        nested_op.output_shapes,
        (tensor_shape.TensorShape([2, 2]), {'b': tensor_shape.TensorShape([])}))
    self.assertEqual(
        nested_op.output_dtypes, (dtypes.float32, {'b': dtypes.float32}))

    squared, extras = nested_op(matrix)
    self.assertAllEqual(squared, math_ops.matmul(matrix, matrix).numpy())
    self.assertAllEqual(extras['b'].numpy(), 1.0)
开发者ID:StephenOman,项目名称:tensorflow,代码行数:18,代码来源:function_test.py

示例7: execute

  def execute(self, fn, *args, **kwargs):
    """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

    Args:
      fn: The function to execute.  Must return at least one tensor.
      *args: Additional positional arguments to `fn`.
      **kwargs: Additional keyword arguments to `fn`.
        Several keywords are reserved for `execute`.  These are:

        - name; The name to use when creating the execute operation.
        - exclusive_resource_access; Whether the resources required by
          `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
          You may want to set this to `False` if you will be accessing a
          resource in read-only mode in two different CriticalSections.

    Returns:
      The tensors returned from `fn(*args, **kwargs)`, packed into the same
      structure as the function op's output dtypes.  If the execution yields
      a single `Operation` rather than tensors, that operation is returned
      as is.

    Raises:
      ValueError: If `fn` attempts to use this `CriticalSection` in any nested
        way.
      ValueError: If `exclusive_resource_access` is not provided (is `True`) and
        another `CriticalSection` has an execution requesting the same
        resources as in `*args`, `**kwargs`, and any additionally captured
        inputs in `fn`.  Note, even if `exclusive_resource_access` is `True`,
        if another execution in another `CriticalSection` was created without
        `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
    # `name` and `exclusive_resource_access` are reserved for `execute`; pop
    # them so that they are not forwarded to `fn`.
    name = kwargs.pop("name", None)
    exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)

    # Only the positional arguments are converted to tensors; the remaining
    # `kwargs` are handed to `make_defun_op` unchanged.
    args = nest.map_structure(ops.convert_to_tensor, args)
    with ops.name_scope(name, "critical_section_execute", []):
      fn_op = function.make_defun_op(fn, *args, **kwargs)
      flat_dtypes = nest.flatten(fn_op.output_dtypes)
      flat_shapes = nest.flatten(fn_op.output_shapes)
      # The execution's inputs are the flattened positional arguments followed
      # by whatever `fn_op` reports as captured inputs.
      # NOTE(review): `kwargs` values are not part of `all_inputs`, so they do
      # not take part in the checks below -- confirm this is intended.
      all_inputs = nest.flatten(args) + fn_op.captured_inputs
      # Refuse to run a function that takes this section's own handle as an
      # input (nested use of the same CriticalSection).
      if self._handle in all_inputs:
        raise ValueError("The function fn attempts to access the "
                         "CriticalSection in which it would be running.  This "
                         "is illegal and would cause deadlocks.  "
                         "CriticalSection: %s." % self._handle)

      if context.in_graph_mode():
        # Collections and op introspection does not work in eager
        # mode.  This is generally ok; since eager mode (as of
        # writing) executes sequentially anyway.
        all_input_resources = [
            x for x in all_inputs if x.dtype == dtypes.resource]
        # Compare against every execution previously recorded in the
        # CRITICAL_SECTION_EXECUTIONS collection (appended to further below).
        for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
          # Input 0 of a recorded execution op is compared by name with this
          # section's handle to detect executions of the same section.
          if sg.op.inputs[0].name == self._handle.name:
            # Other executions in the same critical section are allowed.
            continue
          if not (exclusive_resource_access or sg.exclusive_resource_access):
            # Neither execution requested exclusive access.
            continue
          # The remaining inputs of the recorded op are the names checked for
          # overlap with this execution's resource inputs.
          sg_input_names = [y.name for y in sg.op.inputs[1:]]
          for res in all_input_resources:
            if res.name in sg_input_names:
              raise ValueError(
                  "This execution would access resource %s; but either this "
                  "execution (CriticalSection: %s) or Execution '%s' "
                  "(CriticalSection: %s) requested exclusive resource access "
                  "of this resource for their critical section.  Did you mean "
                  "to call execute with keyword argument "
                  "exclusive_resource_access=False?"
                  % (res.name,
                     self.name,
                     sg.op.name,
                     sg.op.inputs[0].op.name))

      # Build the op that runs `fn_op` on `all_inputs` under this section's
      # handle, using the flattened output dtypes and shapes.
      flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
          critical_section=self._handle,
          arguments=all_inputs,
          f=fn_op,
          output_types=flat_dtypes,
          output_shapes=flat_shapes)

      if context.in_graph_mode():
        # A bare Operation result is wrapped in a list so that the indexing
        # below works uniformly.
        if isinstance(flat_outputs, ops.Operation):
          flat_outputs = [flat_outputs]
        # Resolve the underlying Operation from the first output, whether it
        # is a Tensor or already an Operation.
        op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor)
              else flat_outputs[0])
        # Record this execution so that later `execute` calls can run the
        # conflict check above against it.
        signature = _ExecutionSignature(
            op=op,
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      # A lone Operation is returned directly; anything else is repacked into
      # the structure of `fn_op.output_dtypes`.
      return (flat_outputs[0]
              if (len(flat_outputs) == 1
                  and isinstance(flat_outputs[0], ops.Operation))
              else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:93,代码来源:critical_section_ops.py


注:本文中的tensorflow.python.eager.function.make_defun_op函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。