当前位置: 首页>>代码示例>>Python>>正文


Python control_flow_ops.while_loop方法代码示例

本文整理汇总了Python中tensorflow.python.ops.control_flow_ops.while_loop方法的典型用法代码示例。如果您正苦于以下问题:Python control_flow_ops.while_loop方法的具体用法?Python control_flow_ops.while_loop怎么用?Python control_flow_ops.while_loop使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.ops.control_flow_ops的用法示例。


在下文中一共展示了control_flow_ops.while_loop方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testDebugWhileLoopWatchingWholeGraphWorks

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def testDebugWhileLoopWatchingWholeGraphWorks(self):
    """Watch the whole graph of a small while loop and check the dumps.

    The loop counter starts at 10 and is increased by 2 for as long as it is
    below 16, so the session run is expected to produce 16.
    """
    with session.Session() as sess:
      def _below_limit(counter):
        return math_ops.less(counter, 16)

      def _advance(counter):
        return math_ops.add(counter, 2)

      start = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(_below_limit, _advance, [start])

      # Partition graphs are requested because the dump reader below takes
      # them from the run metadata.
      options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          options, sess.graph, debug_urls=self._debug_urls())
      metadata = config_pb2.RunMetadata()
      self.assertEqual(
          16, sess.run(loop, options=options, run_metadata=metadata))

      dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=metadata.partition_graphs)

      # "while/Enter" is expected to hold only the initial value, while
      # "while/NextIteration" holds one value per completed iteration.
      entered = dump.get_tensors("while/Enter", 0, "DebugIdentity")
      self.assertEqual([[10]], entered)
      iterated = dump.get_tensors("while/NextIteration", 0, "DebugIdentity")
      self.assertEqual([[12], [14], [16]], iterated)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:26,代码来源:session_debug_testlib.py

示例2: generate_infeed_enqueue_ops_and_dequeue_fn

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def generate_infeed_enqueue_ops_and_dequeue_fn(self):
    """Builds the infeed enqueue ops and the matching dequeue function.

    Returns:
      A 4-tuple `(enqueue_ops, dequeue_fn, all_hooks,
      run_infeed_loop_on_coordinator)`. All items except `dequeue_fn` are
      passed through from `_invoke_input_fn_and_record_structure`.
    """
    # The input_fn is invoked here, while the graph is being constructed, so
    # that its structure is recorded before `dequeue_fn` needs it.
    recorded = self._invoke_input_fn_and_record_structure()
    enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = recorded

    self._validate_input_pipeline()

    def dequeue_fn():
      """Dequeues the infeed tensors and restores their recorded structure."""
      # Host-side and device-side computations have to agree on the core used
      # for infeed in the model-parallel case; logical core 0 of each replica
      # is the one chosen here.
      flat_values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
      recorder = self._inputs_structure_recorder
      return recorder.unflatten_features_and_labels(flat_values)

    return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)
开发者ID:ymcui,项目名称:Chinese-XLNet,代码行数:23,代码来源:tpu_estimator.py

示例3: _wrap_computation_in_while_loop

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def _wrap_computation_in_while_loop(device, op_fn):
  """Repeats the ops produced by `op_fn` inside a tf.while_loop.

  Args:
    device: device specification handed to `ops.device` for the loop.
    op_fn: zero-argument callable; its result is used as the control
      dependencies of every counter increment.

  Returns:
    Whatever `control_flow_ops.while_loop` returns for the single counter
    loop variable.
  """

  def _step(counter):
    # The increment only runs after the ops from `op_fn`, which is what makes
    # those ops execute once per iteration.
    with ops.control_dependencies(op_fn()):
      return counter + 1

  def _keep_going(counter):
    return counter < iterations

  iterations_per_loop_var = _create_or_get_iterations_per_loop()
  with ops.device(device):
    iterations = array_ops.identity(iterations_per_loop_var)
    initial_counter = constant_op.constant(0)
    # parallel_iterations=1 basically disables parallel execution of
    # iterations in the while_loop.
    return control_flow_ops.while_loop(
        _keep_going, _step, [initial_counter], parallel_iterations=1)
开发者ID:ymcui,项目名称:Chinese-XLNet,代码行数:18,代码来源:tpu_estimator.py

示例4: _wrap_computation_in_while_loop_with_stopping_signals

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  """Repeats the ops from `op_fn` in a tf.while_loop until told to stop.

  Args:
    device: device specification handed to `ops.device` for the loop.
    op_fn: zero-argument callable returning a mapping with the keys 'ops'
      (used as control dependencies) and 'signals' (converted into the next
      scalar stopping signal).

  Returns:
    Whatever `control_flow_ops.while_loop` returns for the single
    stopping-signal loop variable.
  """

  def _should_continue(scalar_stopping_signal):
    stop_requested = _StopSignals.should_stop(scalar_stopping_signal)
    return math_ops.logical_not(stop_requested)

  def _step(unused_scalar_stopping_signal):
    result = op_fn()
    execute_ops, signals = result['ops'], result['signals']
    # The next signal is only produced once this iteration's ops have run.
    with ops.control_dependencies(execute_ops):
      return _StopSignals.as_scalar_stopping_signal(signals)

  with ops.device(device):
    # parallel_iterations=1 basically disables parallel execution of
    # iterations in the while_loop.
    return control_flow_ops.while_loop(
        cond=_should_continue,
        body=_step,
        loop_vars=[_StopSignals.NON_STOPPING_SIGNAL],
        parallel_iterations=1)
开发者ID:ymcui,项目名称:Chinese-XLNet,代码行数:23,代码来源:tpu_estimator.py

示例5: testIndexedSlicesGradient

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def testIndexedSlicesGradient(self):
    """Minimizes a cost accumulated by embedding lookups in a while loop.

    The test has no explicit assertion; it passes when ten training steps
    run without raising.
    """
    with ops.Graph().as_default():
      embedding_matrix = tf.get_variable(
          "embedding_matrix", [5, 5],
          initializer=tf.random_normal_initializer())

      def _not_done(step, _):
        return step < 5

      def _accumulate(step, total):
        looked_up = embedding_ops.embedding_lookup(
            embedding_matrix + 0.0, [0])
        total = total + tf.reduce_sum(looked_up)
        return step + 1, total

      initial_vars = [tf.constant(0), tf.constant(0.0)]
      _, cost = control_flow_ops.while_loop(
          _not_done, _accumulate, initial_vars)
      train_op = momentum.MomentumOptimizer(0.1, 0.9).minimize(cost)
      with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(10):
          sess.run([train_op])
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:21,代码来源:control_flow_ops_test.py

示例6: testIndexedSlicesWithDynamicShapeGradientInWhileLoop

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
    """Checks gradients through a while loop over a shapeless placeholder.

    For float32 and float64, every element of `inputs` is gathered inside the
    loop and written to a dynamically sized TensorArray. The summed result
    must equal the sum of the fed values, and its gradient with respect to
    `inputs` must be 1 for each element.
    """
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        # No shape is given, so the number of iterations is only known once
        # a value is fed.
        inputs = tf.placeholder(dtype=dtype)
        initial_outputs = tf.TensorArray(dtype=dtype, dynamic_size=True,
                                         size=1)
        initial_i = tf.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          return i < tf.size(inputs)  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = tf.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = tf.while_loop(Cond, Body, [initial_i, initial_outputs])

        outputs = tf.reduce_sum(outputs.pack())
        r = tf.gradients([outputs], [inputs])[0]
        # convert_to_tensor makes the gradient fetchable as a plain tensor
        # whatever type tf.gradients returned.
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [1, 3, 2]})
        # assertEqual replaces the deprecated alias assertEquals, which was
        # removed from unittest in Python 3.12.
        self.assertEqual(o, 6)
        self.assertAllEqual(grad, [1] * 3)
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:27,代码来源:control_flow_ops_test.py

示例7: testDebugWhileLoopWatchingWholeGraphWorks

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def testDebugWhileLoopWatchingWholeGraphWorks(self):
    """Runs a small while loop under the debugger and checks the dumps.

    The loop counter starts at 10 and is increased by 2 for as long as it is
    below 16, so the debug run is expected to produce 16.
    """
    with session.Session() as sess:
      def _below_limit(counter):
        return math_ops.less(counter, 16)

      def _advance(counter):
        return math_ops.add(counter, 2)

      start = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(_below_limit, _advance, [start])

      loop_result, dump = self._debug_run_and_get_dump(sess, loop)
      self.assertEqual(16, loop_result)

      # "while/Enter" is expected to hold only the initial value, while
      # "while/NextIteration" holds one value per completed iteration.
      entered = dump.get_tensors("while/Enter", 0, "DebugIdentity")
      self.assertEqual([[10]], entered)
      iterated = dump.get_tensors("while/NextIteration", 0, "DebugIdentity")
      self.assertEqual([[12], [14], [16]], iterated)
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:18,代码来源:session_debug_testlib.py

示例8: predict

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def predict(self, input_x, h_0=None):
        """Run the recurrent unit over `input_x` and return the last output.

        Args:
            input_x: tensor read one step per iteration via `tf.slice` with
                begin `[0, i, 0]` and size
                `[batch_size_scale, 1, num_vocabulary]`.
            h_0: optional initial hidden state for the recurrence. When it is
                omitted, `self.h0` is used.

        Returns:
            The output loop variable once the loop has finished: the result of
            `self.g_output_unit` from the final iteration. It is seeded with
            float32 zeros of shape `[batch_size_scale, num_classes]`.
        """
        if h_0 is None:
            # `self.h0` is the state the loop was always seeded with before,
            # so calls that rely on the default behave as they did.
            h_0 = self.h0

        def _g_recurrence(i, x_t, h_tm1, o_t):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            x_tp1 = tf.squeeze(tf.slice(input_x, begin=[0, i, 0], size=[self.batch_size_scale, 1, self.num_vocabulary]))
            return i + 1, x_tp1, h_t, o_t

        o_0 = tf.constant(np.zeros(shape=[self.batch_size_scale, self.num_classes]))
        o_0 = tf.cast(o_0, dtype=tf.float32)
        # Seed the loop with the resolved `h_0`; previously a caller-supplied
        # state was silently ignored in favour of `self.h0`.
        _, _, _, output = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.one_hot, self.start_token), h_0, o_0))

        return output
开发者ID:geek-ai,项目名称:Texygen,代码行数:20,代码来源:GsganDiscriminator.py

示例9: _repeat_range

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def _repeat_range(counts, name=None):
  """Emit each index `i` of `counts` exactly `counts[i]` times, in order.

  For example, counts of [0, 1, 2, 3] produce [1, 2, 2, 3, 3, 3].

  Args:
    counts: 1D tensor with dtype=int32.
    name: optional name for operation.

  Returns:
    1D tensor with dtype=int32 and dynamic length holding the repeated
    indices.
  """
  with ops.name_scope(name, 'repeat_range', [counts]) as scope:
    counts = ops.convert_to_tensor(counts, name='counts')
    size = array_ops.shape(counts)[0]

    def _has_more(unused_output, index):
      return index < size

    def _write_run(output, index):
      # counts[index:index+1] is a one-element shape vector, so `fill`
      # yields `index` repeated counts[index] times.
      run = array_ops.fill(counts[index:index+1], index)
      return (output.write(index, run), index + 1)

    def _empty_result():
      return array_ops.zeros(shape=[0], dtype=dtypes.int32)

    init_output_array = tensor_array_ops.TensorArray(
        dtype=dtypes.int32, size=size, infer_shape=False)
    output_array, num_writes = control_flow_ops.while_loop(
        _has_more, _write_run, loop_vars=[init_output_array, 0])

    # Concatenate only when something was written; otherwise return an empty
    # int32 vector.
    return control_flow_ops.cond(
        num_writes > 0, output_array.concat, _empty_result, name=scope)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:36,代码来源:resample.py

示例10: setUpClass

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def setUpClass(cls):
    """Dumps a watched while loop once and builds the analyzer fixtures.

    A loop counting from 0 while the value is below 10 is run with a
    DebugIdentity watch on "while/Identity". The resulting dump directory,
    analyzer and command registry are stored on the class.
    """
    cls._dump_root = tempfile.mkdtemp()

    with session.Session() as sess:
      loop_var = constant_op.constant(0, name="while_loop_test/loop_var")

      def _below_ten(value):
        return math_ops.less(value, 10)

      def _increment(value):
        return math_ops.add(value, 1)

      while_loop = control_flow_ops.while_loop(
          _below_ten, _increment, [loop_var], parallel_iterations=1)

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_url = "file://%s" % cls._dump_root

      # Register a single debug watch on output slot 0 of "while/Identity".
      watch = run_options.debug_options.debug_tensor_watch_opts.add()
      watch.node_name = "while/Identity"
      watch.output_slot = 0
      watch.debug_ops.append("DebugIdentity")
      watch.debug_urls.append(debug_url)

      # Run once so that the dump directory gets populated.
      run_metadata = config_pb2.RunMetadata()
      sess.run(while_loop, options=run_options, run_metadata=run_metadata)

    cls._debug_dump = debug_data.DebugDumpDir(
        cls._dump_root, partition_graphs=run_metadata.partition_graphs)

    cls._analyzer = analyzer_cli.DebugAnalyzer(cls._debug_dump)
    cls._registry = debugger_cli_common.CommandHandlerRegistry()
    # Each command is handled by the analyzer method of the same name.
    for command, alias in (("list_tensors", "lt"), ("print_tensor", "pt")):
      cls._registry.register_command_handler(
          command,
          getattr(cls._analyzer, command),
          cls._analyzer.get_help(command),
          prefix_aliases=[alias])
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:43,代码来源:analyzer_cli_test.py

示例11: _predict_on_tpu_system

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Runs the predict step of `model_fn_wrapper` in a loop on all TPU shards.

  Args:
    ctx: object supplying `num_replicas` and `device_assignment`.
    model_fn_wrapper: object whose `convert_to_single_tpu_predict_step`
      provides the per-step function plus host calls, scaffold and hooks.
    dequeue_fn: forwarded to `convert_to_single_tpu_predict_step`.

  Returns:
    A tuple `(compile_op, dummy_predict_op, host_calls, scaffold,
    predict_hooks)`.
  """
  step_parts = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)
  (single_tpu_predict_step, host_calls, captured_scaffold_fn,
   captured_predict_hooks) = step_parts

  def _keep_predicting(scalar_stopping_signal):
    stop_requested = _StopSignals.should_stop(scalar_stopping_signal)
    return math_ops.logical_not(stop_requested)

  def multi_tpu_predict_steps_on_single_shard():
    # The loop starts from the non-stopping signal and runs until the predict
    # step returns a signal for which `should_stop` is true.
    return training_loop.while_loop(
        _keep_predicting,
        single_tpu_predict_step,
        inputs=[_StopSignals.NON_STOPPING_SIGNAL],
        name=b'loop')

  compile_op, sharded_outputs = tpu.split_compile_and_shard(
      multi_tpu_predict_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)

  dummy_predict_op = sharded_outputs[0]
  scaffold = _get_scaffold(captured_scaffold_fn)
  predict_hooks = captured_predict_hooks.get()
  return (compile_op, dummy_predict_op, host_calls, scaffold, predict_hooks)
开发者ID:ymcui,项目名称:Chinese-XLNet,代码行数:30,代码来源:tpu_estimator.py

示例12: compute_gt_cluster_score

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def compute_gt_cluster_score(pairwise_distances, labels):
  """Computes the facility location score of the ground truth clustering.

  Each unique label is visited once. For every label the pairwise distances
  are restricted to that cluster's members, summed along axis 0, and the
  negated minimum of those sums is added to the running score.

  Args:
    pairwise_distances: 2-D Tensor of pairwise distances.
    labels: 1-D Tensor of ground truth cluster assignment.

  Returns:
    gt_cluster_score: dtypes.float32 score.
  """
  unique_class_ids = array_ops.unique(labels)[0]
  num_classes = array_ops.size(unique_class_ids)
  iteration = array_ops.constant(0)
  gt_cluster_score = array_ops.constant(0.0, dtype=dtypes.float32)

  def _has_more_classes(step, unused_score):
    return step < num_classes

  def _add_cluster_score(step, running_score):
    """Adds the score of the cluster with index `step` to the total."""
    in_cluster = math_ops.equal(labels, unique_class_ids[step])
    member_ids = array_ops.where(in_cluster)
    # Gather along the first axis, transpose, gather again and transpose
    # back, so that both axes end up restricted to the cluster members.
    member_rows = array_ops.gather(pairwise_distances, member_ids)
    member_block = array_ops.gather(
        array_ops.transpose(member_rows), member_ids)
    distances_subset = array_ops.transpose(member_block)
    travel_distances = math_ops.reduce_sum(distances_subset, axis=0)
    cluster_score = -1.0 * math_ops.reduce_min(travel_distances)
    return step + 1, running_score + cluster_score

  _, gt_cluster_score = control_flow_ops.while_loop(
      _has_more_classes, _add_cluster_score, [iteration, gt_cluster_score])
  return gt_cluster_score
开发者ID:google-research,项目名称:tf-slim,代码行数:40,代码来源:metric_learning.py

示例13: setUpClass

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def setUpClass(cls):
    """Dumps a watched while loop once and builds the analyzer fixtures.

    A loop counting from 0 while the value is below 10 is run with a
    DebugIdentity watch on "while/Identity". The resulting dump directory,
    analyzer and command registry are stored on the class.
    """
    cls._dump_root = tempfile.mkdtemp()

    with session.Session() as sess:
      loop_var = constant_op.constant(0, name="while_loop_test/loop_var")

      def _below_ten(value):
        return math_ops.less(value, 10)

      def _increment(value):
        return math_ops.add(value, 1)

      while_loop = control_flow_ops.while_loop(
          _below_ten, _increment, [loop_var], parallel_iterations=1)

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_url = "file://%s" % cls._dump_root

      # Register a single debug watch on output slot 0 of "while/Identity".
      watch = run_options.debug_tensor_watch_opts.add()
      watch.node_name = "while/Identity"
      watch.output_slot = 0
      watch.debug_ops.append("DebugIdentity")
      watch.debug_urls.append(debug_url)

      # Run once so that the dump directory gets populated.
      run_metadata = config_pb2.RunMetadata()
      sess.run(while_loop, options=run_options, run_metadata=run_metadata)

    cls._debug_dump = debug_data.DebugDumpDir(
        cls._dump_root, partition_graphs=run_metadata.partition_graphs)

    cls._analyzer = analyzer_cli.DebugAnalyzer(cls._debug_dump)
    cls._registry = debugger_cli_common.CommandHandlerRegistry()
    # Each command is handled by the analyzer method of the same name.
    for command, alias in (("list_tensors", "lt"), ("print_tensor", "pt")):
      cls._registry.register_command_handler(
          command,
          getattr(cls._analyzer, command),
          cls._analyzer.get_help(command),
          prefix_aliases=[alias])
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:43,代码来源:analyzer_cli_test.py

示例14: testIndexedSlicesGradientInCondInWhileLoop

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def testIndexedSlicesGradientInCondInWhileLoop(self):
    """Compares gradients of a cond-in-while-loop cost with an unrolled one.

    The loop runs for it = 0..4: the cost is squared when it == 3 and
    otherwise increased by the summed embedding. The statically built
    expression `(3 * s) ** 2 + s` describes the same computation, so both
    gradients with respect to the embedding matrix must be equal.
    """
    with ops.Graph().as_default():
      embedding_matrix = tf.get_variable(
          "embedding_matrix", [5, 5],
          initializer=tf.random_normal_initializer())

      def _not_done(step, _):
        return step < 5

      def _update_cost(step, running_cost):
        looked_up = embedding_ops.embedding_lookup(embedding_matrix, [0])

        def _squared():
          return tf.square(running_cost)

        def _incremented():
          return running_cost + tf.reduce_sum(looked_up)

        new_cost = tf.cond(tf.equal(step, 3), _squared, _incremented)
        return step + 1, new_cost

      _, cost = control_flow_ops.while_loop(
          _not_done, _update_cost, [tf.constant(0), tf.constant(0.0)])

      loop_grad = tf.gradients(cost, [embedding_matrix])[0]
      dynamic_grads = tf.segment_sum(loop_grad.values, loop_grad.indices)

      # Unrolled equivalent of the loop: three additions, one square, then
      # one more addition.
      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      before_square = tf.reduce_sum(embedding) + tf.reduce_sum(embedding)
      before_square = before_square + tf.reduce_sum(embedding)
      static = tf.square(before_square) + tf.reduce_sum(embedding)
      unrolled_grad = tf.gradients(static, [embedding_matrix])[0]
      static_grads = tf.segment_sum(
          unrolled_grad.values, unrolled_grad.indices)

      with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        static_value, dynamic_value = sess.run([static_grads, dynamic_grads])
        self.assertAllEqual(static_value, dynamic_value)
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:34,代码来源:control_flow_ops_test.py

示例15: testIndexedSlicesWithShapeGradientInWhileLoop

# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import while_loop [as 别名]
def testIndexedSlicesWithShapeGradientInWhileLoop(self):
    """Checks gradients through a while loop over a fixed-shape placeholder.

    For float32 and float64, each of the `num_steps` elements of `inputs` is
    gathered inside the loop and written to a TensorArray. The summed result
    must equal the sum of the fed values, and its gradient with respect to
    `inputs` must be 1 for each element.
    """
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        num_steps = 9

        inputs = tf.placeholder(dtype=dtype, shape=[num_steps])
        initial_outputs = tf.TensorArray(dtype=dtype, size=num_steps)
        initial_i = tf.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          return i < num_steps  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = tf.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = tf.while_loop(Cond, Body, [initial_i, initial_outputs])

        outputs = tf.reduce_sum(outputs.pack())
        r = tf.gradients([outputs], [inputs])[0]
        # convert_to_tensor makes the gradient fetchable as a plain tensor
        # whatever type tf.gradients returned.
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
        # assertEqual replaces the deprecated alias assertEquals, which was
        # removed from unittest in Python 3.12.
        self.assertEqual(o, 20)
        self.assertAllEqual(grad, [1] * num_steps)
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:28,代码来源:control_flow_ops_test.py


注:本文中的tensorflow.python.ops.control_flow_ops.while_loop方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。