

Python control_flow_ops.merge Function Code Examples

This article collects typical usage examples of the Python function tensorflow.python.ops.control_flow_ops.merge. If you are wondering what the merge function does, how to call it, or what real uses of it look like, the curated code examples below should help.


The following presents 15 code examples of the merge function, sorted by popularity by default.
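All of the examples below pair merge with switch: switch routes a tensor to one of two outputs depending on a boolean predicate, and merge forwards whichever of its inputs is actually produced, together with the index of that input. As a quick orientation, here is a minimal sketch of that pattern, written against the same TF 1.x-era graph API the examples use (an illustrative sketch, not one of the collected examples):

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

data = tf.constant([1, 2, 3], name="data")
pred = tf.constant(True, name="pred")
# switch returns (output_false, output_true); only the taken branch carries data.
false_out, true_out = control_flow_ops.switch(data, pred)
# merge forwards whichever input is available and reports its index.
merged, index = control_flow_ops.merge([false_out, true_out])

with tf.Session() as sess:
    merged_val, index_val = sess.run([merged, index])
    # pred is True, so the data comes out of the true output (index 1).
    print(merged_val, index_val)  # [1 2 3] 1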

Example 1: testMergeShapes

  def testMergeShapes(self):
    # All inputs unknown.
    p1 = tf.placeholder(tf.float32)
    p2 = tf.placeholder(tf.float32)
    p3 = tf.placeholder(tf.float32)
    m, index = control_flow_ops.merge([p1, p2, p3])
    self.assertIs(None, m.get_shape().ndims)
    self.assertEqual([], index.get_shape())

    # All inputs known but different.
    p1 = tf.placeholder(tf.float32, shape=[1, 2])
    p2 = tf.placeholder(tf.float32, shape=[2, 1])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertIs(None, m.get_shape().ndims)
    self.assertEqual([], index.get_shape())

    # All inputs known but same.
    p1 = tf.placeholder(tf.float32, shape=[1, 2])
    p2 = tf.placeholder(tf.float32, shape=[1, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([1, 2], m.get_shape())
    self.assertEqual([], index.get_shape())

    # Possibly the same but not guaranteed.
    p1 = tf.placeholder(tf.float32, shape=[1, 2])
    p2 = tf.placeholder(tf.float32)
    p2.set_shape([None, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertIs(None, m.get_shape().ndims)
    self.assertEqual([], index.get_shape())
Developer ID: hypatiad, Project: tensorflow, Lines of code: 30, Source: control_flow_ops_py_test.py

Example 2: testLoop_1

    def testLoop_1(self):
        with self.test_session():
            zero = tf.convert_to_tensor(0)
            one = tf.convert_to_tensor(1)
            n = tf.constant(10)

            enter_zero = control_flow_ops.enter(zero, "foo_1", False)
            enter_one = control_flow_ops.enter(one, "foo_1", False)
            enter_n = control_flow_ops.enter(n, "foo_1", False)
            merge_zero = control_flow_ops.merge([enter_zero, enter_zero], name="merge_zero")[0]
            merge_one = control_flow_ops.merge([enter_one, enter_one], name="merge_one")[0]
            merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
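            # merge_n < merge_n is always False, so the loop body never runs and exit returns n unchanged.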
            less_op = tf.less(merge_n, merge_n)
            cond_op = control_flow_ops.loop_cond(less_op)
            switch_zero = control_flow_ops.switch(merge_zero, cond_op)
            switch_one = control_flow_ops.switch(merge_one, cond_op)
            switch_n = control_flow_ops.switch(merge_n, cond_op)
            next_zero = control_flow_ops.next_iteration(switch_zero[1])
            next_one = control_flow_ops.next_iteration(switch_one[1])
            next_n = control_flow_ops.next_iteration(switch_n[1])
            merge_zero.op._update_input(1, next_zero)
            merge_one.op._update_input(1, next_one)
            merge_n.op._update_input(1, next_n)
            exit_n = control_flow_ops.exit(switch_n[0])

            result = exit_n.eval()
        self.assertAllEqual(10, result)
Developer ID: peace195, Project: tensorflow, Lines of code: 27, Source: control_flow_ops_py_test.py

Example 3: testLoop_2

  def testLoop_2(self):
    with self.test_session():
      zero = tf.constant(0)
      one = tf.constant(1)
      n = tf.constant(10)

      enter_i = control_flow_ops.enter(zero, "foo", False)
      enter_one = control_flow_ops.enter(one, "foo", True)
      enter_n = control_flow_ops.enter(n, "foo", True)

      merge_i = control_flow_ops.merge([enter_i, enter_i])[0]

      less_op = tf.less(merge_i, enter_n)
      cond_op = control_flow_ops.loop_cond(less_op)
      switch_i = control_flow_ops.switch(merge_i, cond_op)

      add_i = tf.add(switch_i[1], enter_one)

      with tf.device("/gpu:0"):
        next_i = control_flow_ops.next_iteration(add_i)
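      # Patch the merge op's second input to form the back edge of the loop.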
      merge_i.op._update_input(1, next_i)

      exit_i = control_flow_ops.exit(switch_i[0])
      result = exit_i.eval()
    self.assertAllEqual(10, result)
Developer ID: hypatiad, Project: tensorflow, Lines of code: 25, Source: control_flow_ops_py_test.py

Example 4: apply_with_random_selector

def apply_with_random_selector(image, func, num_cases):
    """Randomly selects a case in [0, num_cases) and returns func(image, case)."""
    # Randomly select which case to apply.
    sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
    return control_flow_ops.merge([
        func(control_flow_ops.switch(image, tf.equal(case, sel))[1], case)
        for case in range(num_cases)])[0]
Developer ID: beacandler, Project: tf-slim-demo, Lines of code: 7, Source: inception_preprocessing.py
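A hypothetical call site for the helper above (the image tensor and the _distort helper are assumed for illustration and are not part of the original project). The selected case reaches func as a plain Python int, so ordinary Python branching can choose the distortion at graph-construction time:

def _distort(img, case):
    # case is a Python int at graph-construction time, not a tensor.
    if case == 0:
        return tf.image.adjust_brightness(img, delta=0.1)
    return tf.image.flip_left_right(img)

distorted = apply_with_random_selector(image, _distort, num_cases=2)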

Example 5: _process_switch

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_op_using_resource_tensor, merge_for_resource):
    """Processes a switch node for a resource input.

    When TensorFlow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor gets created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_op_using_resource_tensor: map from resource tensor to last op using
        it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.
    """
    inp = switch_op.inputs[0]
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run,
                           last_op_using_resource_tensor, merge_for_resource)
    if switch_op.outputs[0] in merge_for_resource:
      return
    new_merge = control_flow_ops.merge(switch_op.outputs,
                                       name="artificial_merge")
    new_merge[0].op._control_flow_context = (  # pylint: disable=protected-access
        switch_op._control_flow_context.outer_context)  # pylint: disable=protected-access
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if inp in last_op_using_resource_tensor:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_op_using_resource_tensor[inp])  # pylint: disable=protected-access
    # Ensure the next op outside the cond happens after the merge.
    last_op_using_resource_tensor[inp] = new_merge[0].op
    if inp in merge_for_resource:
      merge_for_resource[inp]._add_control_input(new_merge[0].op)  # pylint: disable=protected-access
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[o] = new_merge[0].op
Developer ID: Jackiefan, Project: tensorflow, Lines of code: 58, Source: function.py

Example 6: _testSwitchMerge_1

  def _testSwitchMerge_1(self, use_gpu):
    with self.test_session(use_gpu=use_gpu):
      data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = tf.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      merge_op = control_flow_ops.merge(switch_op)[0]

      result = merge_op.eval()
    self.assertAllEqual(np.arange(1, 7), result)
Developer ID: hypatiad, Project: tensorflow, Lines of code: 9, Source: control_flow_ops_py_test.py

Example 7: testSwitchMergeIdentity_1

  def testSwitchMergeIdentity_1(self):
    with self.test_session():
      data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = tf.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      merge_op = control_flow_ops.merge(switch_op)[0]
      id_op = tf.identity(merge_op)

      result = id_op.eval()
    self.assertAllEqual(np.arange(1, 7), result)
Developer ID: hypatiad, Project: tensorflow, Lines of code: 10, Source: control_flow_ops_py_test.py

Example 8: testSwitchMergeLess_1

  def testSwitchMergeLess_1(self):
    with self.test_session():
      data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
      zero = tf.convert_to_tensor(0)
      one = tf.convert_to_tensor(1)
      less_op = tf.less(zero, one)
      switch_op = control_flow_ops.switch(data, less_op)
      merge_op = control_flow_ops.merge(switch_op)[0]

      result = merge_op.eval()
    self.assertAllEqual(np.arange(1, 7), result)
Developer ID: hypatiad, Project: tensorflow, Lines of code: 11, Source: control_flow_ops_py_test.py

Example 9: testSwitchMergeAddIdentity_1

  def testSwitchMergeAddIdentity_1(self):
    with self.test_session():
      data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = tf.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      one = tf.constant(1)
      add_op = tf.add(switch_op[0], one)
      id_op = tf.identity(switch_op[1])
      merge_op = control_flow_ops.merge([add_op, id_op])[0]

      result = merge_op.eval()
    self.assertAllEqual(np.arange(1, 7), result)
Developer ID: hypatiad, Project: tensorflow, Lines of code: 12, Source: control_flow_ops_py_test.py

Example 10: testSwitchMergeAddMul_1

  def testSwitchMergeAddMul_1(self):
    with self.test_session():
      data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = tf.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      one = tf.constant(1)
      add_op = tf.add(switch_op[0], one)
      five = tf.constant(5)
      mul_op = tf.mul(switch_op[1], five)
      merge_op = control_flow_ops.merge([add_op, mul_op])[0]

      result = merge_op.eval()
    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
Developer ID: hypatiad, Project: tensorflow, Lines of code: 13, Source: control_flow_ops_py_test.py

Example 11: testSwitchMergeIndexedSlices

  def testSwitchMergeIndexedSlices(self):
    with self.test_session():
      values = tf.constant([1, 2, 3, 4, 5, 6])
      indices = tf.constant([0, 2, 4, 6, 8, 10])
      data = tf.IndexedSlices(values, indices)
      pred = tf.convert_to_tensor(True)
      switch_op = control_flow_ops.switch(data, pred)
      merge_op = control_flow_ops.merge(switch_op)[0]

      val = merge_op.values.eval()
      ind = merge_op.indices.eval()
    self.assertAllEqual(np.arange(1, 7), val)
    self.assertAllEqual(np.arange(0, 12, 2), ind)
Developer ID: hypatiad, Project: tensorflow, Lines of code: 13, Source: control_flow_ops_py_test.py

Example 12: testLoop_false

    def testLoop_false(self):
        with self.test_session():
            false = tf.convert_to_tensor(False)
            n = tf.constant(10)

            enter_false = control_flow_ops.enter(false, "foo_1", False)
            enter_n = control_flow_ops.enter(n, "foo_1", False)

            merge_n = control_flow_ops.merge([enter_n], name="merge_n")[0]
            switch_n = control_flow_ops.switch(merge_n, enter_false)
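            # enter_false is always False, so the value flows out of switch output 0 (the false branch) straight to exit.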
            exit_n = control_flow_ops.exit(switch_n[0])

            result = exit_n.eval()
        self.assertAllEqual(10, result)
Developer ID: peace195, Project: tensorflow, Lines of code: 14, Source: control_flow_ops_py_test.py

Example 13: apply_with_random_selector

def apply_with_random_selector(x, func, num_cases):
  """Computes func(x, sel), with sel sampled from [0...num_cases-test].
      Args:
        x: input Tensor.
        func: Python function to apply.
        num_cases: Python int32, number of cases to sample sel from.
      Returns:
        The result of func(x, sel), where func receives the value of the
        selector as a python integer, but sel is sampled dynamically.
      """
  sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
  # Pass the real x only to one of the func calls.
  return control_flow_ops.merge([
      func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
      for case in range(num_cases)
  ])[0]
Developer ID: veyvin, Project: tensorflow-learn, Lines of code: 16, Source: image_pre_test.py

Example 14: apply_with_random_selector

def apply_with_random_selector(x, func, num_cases):
  """Computes func(x, sel), with sel sampled from [0...num_cases-1].

  TODO(coreylynch): add as a dependency, when slim or tensorflow/models are
  pipfied.
  Source:
  https://raw.githubusercontent.com/tensorflow/models/a9d0e6e8923a4/slim/preprocessing/inception_preprocessing.py

  Args:
    x: input Tensor.
    func: Python function to apply.
    num_cases: Python int32, number of cases to sample sel from.
  Returns:
    The result of func(x, sel), where func receives the value of the
    selector as a python integer, but sel is sampled dynamically.
  """
  sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
  # Pass the real x only to one of the func calls.
  return control_flow_ops.merge([
      func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
      for case in range(num_cases)])[0]
Developer ID: NoPointExc, Project: models, Lines of code: 21, Source: preprocessing.py

Example 15: create_op


#......... part of the code is omitted here .........
    if self._return_as_is or op_type in _PASS_THROUGH_OPS:
      return self._wrap(super(ImperativeGraph, self).create_op(*args, **kwargs))

    if not output_dtypes:
      return self._wrap(
          super(ImperativeGraph, self).create_op(*args, **kwargs))

    output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes])  # pylint: disable=protected-access

    if output_has_ref:
      if op_type not in _REF_OPS_WHITELIST:
        raise errors.UnimplementedError(None, None,
                                        op_type + ' op not supported in '
                                        'imperative graph')

      ret = super(ImperativeGraph, self).create_op(*args, **kwargs)

      if self._in_variable_creation:
        if op_type == 'Assign':
          self.add_pending_init(ret)

      return self._wrap(ret)

    with self.return_as_is():
      # Declares the variables to hold the output values of this op.
      op_output_var = [state_ops.variable_op_v2(
          tensor_shape.TensorShape(None), dtype, container=self._name)
                       for dtype in output_dtypes]
      # Ops to free the resources used by the temporary cache variables.
      # The following two ops are created for each cache variable,
      # having no control dependencies on any other ops :
      # var_handle_op ----> destroy_resource_op
      for dtype, v in zip(output_dtypes, op_output_var):
        with ops.control_dependencies(None):
          self._variable_cleanup_ops += [
              gen_resource_variable_ops.destroy_resource_op(
                  gen_resource_variable_ops.var_handle_op(
                      dtype, tensor_shape.TensorShape(None),
                      container=self._name, shared_name=v.op.name),
                  ignore_lookup_error=True)]

      # Create the conditional to run the original op only when the variable
      # corresponding to the first output is not initialized.
      inited = state_ops.is_variable_initialized(op_output_var[0])
      v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
      # pylint: disable=protected-access
      v_f_op = gen_array_ops._ref_identity(v_f)
      v_t_op = gen_array_ops._ref_identity(v_t)
      # pylint: enable=protected-access

      with ops.control_dependencies([v_f_op.op]):
        # Create the original op
        orig_op = self._wrap(
            super(ImperativeGraph, self).create_op(*args, **kwargs))
      shapes = [val.get_shape() for val in orig_op.outputs]

      controls = []
      for var, val in zip(op_output_var, orig_op.outputs):
        if (not val.get_shape().is_fully_defined() or
            val.get_shape().num_elements() > 0):
          assign_op = state_ops.assign(var, val, validate_shape=False)
          assign_op.set_shape(val.get_shape())
          controls.append(assign_op)

      values = []
      if len(controls) > 1:
        if control_flow_ops.IsSwitch(orig_op):
          # pylint: disable=protected-access
          controls = gen_control_flow_ops._ref_merge(controls)
          # pylint: enable=protected-access
        else:
          controls = control_flow_ops.tuple(controls)

      for var, val in zip(op_output_var, orig_op.outputs):
        with ops.control_dependencies(controls):
          with self.colocate_with(v_f_op):
            real_val = array_ops.identity(val)
        with ops.control_dependencies([v_t_op.op]):
          with self.colocate_with(v_t_op):
            stored_val = array_ops.identity(var)
          stored_val.set_shape(val.get_shape())
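          # merge forwards whichever branch actually ran: the freshly computed value or the cached variable.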
          real_val, _ = control_flow_ops.merge([real_val, stored_val])
        real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
            attr_value_pb2.AttrValue(s=compat.as_bytes(self._merge_op_type)))
        values.append(real_val)

      for i, _ in enumerate(shapes):
        values[i].set_shape(shapes[i])
      self._outputs_map[orig_op.name] = values
      try:
        self._gradient_function_map[orig_op.name] = ops.get_gradient_function(
            orig_op)
      except (KeyError, LookupError):
        pass
      else:
        orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
            attr_value_pb2.AttrValue(
                s=compat.as_bytes(self._imperative_op_type)))

      return MultiOutputOperation(values, orig_op)
Developer ID: chdinh, Project: tensorflow, Lines of code: 101, Source: imperative_graph.py


Note: The tensorflow.python.ops.control_flow_ops.merge examples in this article were compiled by 纯净天空 from GitHub/MSDocs and other open-source code and documentation platforms. The snippets were selected from open-source projects contributed by various developers; copyright of the source code belongs to the original authors. For distribution and use, please refer to the corresponding project's license; do not reproduce without permission.