当前位置: 首页>>代码示例>>Python>>正文


Python v1.while_loop方法代码示例

本文整理汇总了Python中tensorflow.compat.v1.while_loop方法的典型用法代码示例。如果您正苦于以下问题：Python v1.while_loop方法的具体用法？Python v1.while_loop怎么用？Python v1.while_loop使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.compat.v1的用法示例。


在下文中一共展示了v1.while_loop方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _build

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def _build(self, inputs, labels):
    """Runs the inner attack repeatedly and keeps per-example successes.

    The inner attack is re-run up to ``self._num_restarts`` times. The loop
    stops early once every entry of the per-example success vector is True.

    Args:
      inputs: Clean input batch. It is also the initial value of the attack
        loop variable.
      labels: Labels forwarded unchanged to ``self._inner_attack``.

    Returns:
      The resulting attacked inputs. They are also stored in
      ``self._attack``, together with ``self._success`` and ``self._logits``.
    """

    def should_continue(step, unused_attack, succeeded):
        # Keep restarting while attempts remain and at least one example has
        # not been attacked successfully yet.
        attempts_left = step < self._num_restarts
        all_done = tf.reduce_all(succeeded)
        return tf.logical_and(attempts_left, tf.logical_not(all_done))

    def restart(step, best_attack, succeeded):
        candidate = self._inner_attack(inputs, labels)
        candidate_success = self._inner_attack.success
        # Step 0 always takes the candidate. Later steps only overwrite the
        # examples for which this restart succeeded.
        take_candidate = tf.logical_or(tf.equal(step, 0), candidate_success)
        next_step = step + 1
        merged_attack = tf.where(take_candidate, candidate, best_attack)
        merged_success = tf.logical_or(succeeded, candidate_success)
        return next_step, merged_attack, merged_success

    initial_step = tf.constant(0, dtype=tf.int32)
    initial_success = tf.zeros([tf.shape(inputs)[0]], dtype=tf.bool)
    _, self._attack, self._success = tf.while_loop(
        should_continue,
        restart,
        back_prop=False,
        parallel_iterations=1,
        loop_vars=[initial_step, inputs, initial_success])
    self._logits = self._eval_fn(self._attack, mode='final')
    return self._attack
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:27,代码来源:attacks.py

示例2: search

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def search(self, initial_ids, initial_cache):
    """Beam search for sequences with highest scores.

    Iterates ``self._search_step`` inside a ``tf.while_loop`` while
    ``self._continue_search`` holds. The loop runs one iteration at a time,
    without backprop, starting from the state built by
    ``self._create_initial_state``.

    Returns:
      A ``(finished_seq, finished_scores)`` pair. For any batch item whose
      finished flags are all False, the alive sequences and alive log
      probabilities are returned in their place.
    """
    initial_state, shape_invariants = self._create_initial_state(
        initial_ids, initial_cache)

    loop_outputs = tf.while_loop(
        self._continue_search,
        self._search_step,
        loop_vars=[initial_state],
        shape_invariants=[shape_invariants],
        parallel_iterations=1,
        back_prop=False)
    # The loop output is indexed like the ``loop_vars`` list; take the state.
    final_state = loop_outputs[0]

    state_keys = (
        _StateKeys.ALIVE_SEQ,
        _StateKeys.ALIVE_LOG_PROBS,
        _StateKeys.FINISHED_SEQ,
        _StateKeys.FINISHED_SCORES,
        _StateKeys.FINISHED_FLAGS,
    )
    (alive_seq, alive_log_probs, finished_seq, finished_scores,
     finished_flags) = [final_state[key] for key in state_keys]

    def prefer_finished(finished_value, alive_value):
        # Corner case: a batch item with no finished sequence falls back to
        # its alive counterpart.
        return tf.where(
            tf.reduce_any(finished_flags, 1), finished_value, alive_value)

    best_seq = prefer_finished(finished_seq, alive_seq)
    best_scores = prefer_finished(finished_scores, alive_log_probs)
    return best_seq, best_scores
开发者ID:tensorflow,项目名称:models,代码行数:25,代码来源:beam_search_v1.py

示例3: test_loop_2_vars

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_loop_2_vars():
    """Loop over an int counter plus a pass-through ``[2, 2]`` tensor.

    Only the counter (offset by 1337 after the loop) is fetched and compared
    with ``check_equal``.
    """
    graph = tf.Graph()
    with graph.as_default():
        counter_init = tf.constant(0)
        ones_init = tf.ones([2, 2])

        def keep_going(counter, ones):
            return counter < 10

        def step(counter, ones):
            # The second loop variable is carried through unchanged.
            return [tf.add(counter, 1), ones]

        counter_out, _ = tf.while_loop(
            keep_going, step, loop_vars=[counter_init, ones_init])
        counter_out += tf.constant(1337)

        with tf.Session() as sess:
            tf_out = sess.run(counter_out)

    check_equal(graph, tf_out)
开发者ID:apache,项目名称:incubator-tvm,代码行数:19,代码来源:test_control_flow.py

示例4: test_loop_3_vars

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_loop_3_vars():
    """Loop with three scalar variables updated together each iteration."""
    graph = tf.Graph()
    with graph.as_default():
        initial_values = [tf.constant(1), tf.constant(2), tf.constant(4)]

        def keep_going(i, j, k):
            return i < 10

        def update(i, j, k):
            return [i + 1, j * k, k + i]

        result = tf.while_loop(keep_going, update, loop_vars=initial_values)

        with tf.Session() as sess:
            tf_out = sess.run(result)

    check_equal(graph, tf_out)
开发者ID:apache,项目名称:incubator-tvm,代码行数:18,代码来源:test_control_flow.py

示例5: test_loop_conditions

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_loop_conditions():
    """Loop whose continuation predicate combines several comparisons.

    The loop continues while ``(i + j < 10) XOR (j * k < 100)`` equals
    ``k >= i + j``.
    """
    graph = tf.Graph()
    with graph.as_default():
        i = tf.constant(1)
        j = tf.constant(1)
        k = tf.constant(5)

        def keep_going(i, j, k):
            sum_small = tf.less(i + j, 10)
            prod_small = tf.less(j * k, 100)
            # not_equal on two booleans acts as XOR.
            exactly_one = tf.not_equal(sum_small, prod_small)
            k_dominates = tf.greater_equal(k, i + j)
            return tf.equal(exactly_one, k_dominates)

        def update(i, j, k):
            return [i + j, j + k, k + 1]

        result = tf.while_loop(keep_going, update, loop_vars=[i, j, k])
        with tf.Session() as sess:
            tf_out = sess.run(result)

    check_equal(graph, tf_out)
开发者ID:apache,项目名称:incubator-tvm,代码行数:20,代码来源:test_control_flow.py

示例6: test_nested_loop

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_nested_loop():
    """Nested ``tf.while_loop``: each outer iteration runs an inner loop.

    The inner loop doubles ``c`` from 2 while ``c < 10`` (2 -> 4 -> 8 -> 16).
    The outer body then returns ``relu(x + 16)``. Starting from ``x = 3``,
    the outer loop runs while ``x < 100``:
    3 -> 19 -> 35 -> 51 -> 67 -> 83 -> 99 -> 115.
    """
    graph = tf.Graph()
    with graph.as_default():

        def body(x):
            def nest_body(c):
                return tf.multiply(c, 2)

            def cd(c):
                return tf.less(c, 10)

            c = tf.constant(2)
            res = tf.while_loop(cd, nest_body, loop_vars=[c])
            return tf.nn.relu(x + res)

        def condition(x):
            # Must be true for the initial x = 3. Otherwise neither the body
            # nor the nested loop ever executes, and the result comparison
            # only checks that 3 passes through unchanged.
            return tf.less(x, 100)

        x = tf.constant(3)
        r = tf.while_loop(condition, body, loop_vars=[x])

        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)
开发者ID:apache,项目名称:incubator-tvm,代码行数:23,代码来源:test_control_flow.py

示例7: test_loop_in_cond

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_loop_in_cond():
    """``tf.while_loop`` nested inside the true branch of ``tf.cond``.

    ``pred`` is fed ``True``, so the branch containing the loop (``fn1``) is
    taken. ``fn1`` and ``fn2`` ignore their arguments, so the fed values for
    ``x``, ``y`` and ``z`` do not affect the fetched result.
    """
    graph = tf.Graph()
    with graph.as_default():
        def fn1(a, b):
            counter = tf.constant(0)

            def below_ten(n):
                return tf.less(n, 10)

            def increment(n):
                return tf.add(n, 1)

            loop_result = tf.while_loop(below_ten, increment, [counter])
            return tf.multiply(tf.add(20, loop_result), 10)

        def fn2(a, b):
            return tf.add(10, 20)

        x, y, z = tf.constant(7), tf.constant(20), tf.constant(10)
        pred = tf.less(x, y)
        r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))

        feeds = {x: 1, y: 2, z: 3, pred: True}
        with tf.Session() as sess:
            tf_out = sess.run(r, feed_dict=feeds)

    check_equal(graph, tf_out)
开发者ID:apache,项目名称:incubator-tvm,代码行数:27,代码来源:test_control_flow.py

示例8: test_cond_in_loop

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_cond_in_loop():
    """``tf.cond`` inside a ``tf.while_loop`` body.

    NOTE(review): the body ignores its incoming loop value and works on a
    constant 7 instead. The constant 20 is created but never used.
    """
    graph = tf.Graph()
    with graph.as_default():
        def loop_body(_incoming):
            seven = tf.constant(7)
            _unused_twenty = tf.constant(20)
            branch_value = tf.cond(
                tf.less(seven, 10),
                lambda: tf.add(10, 20),
                lambda: tf.square(10),
            )
            return tf.multiply(branch_value, seven)

        start = tf.constant(21)

        def below_hundred(value):
            return tf.less(value, 100)

        r = tf.while_loop(below_hundred, loop_body, loop_vars=[start])
        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)
开发者ID:apache,项目名称:incubator-tvm,代码行数:21,代码来源:test_control_flow.py

示例9: should_generate_summaries

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def should_generate_summaries():
  """Is this an appropriate context to generate summaries.

  Summaries are suppressed inside a ``while/`` name scope, because summaries
  don't work well within tf.while_loop(). They are also suppressed when the
  current variable scope is reusing variables, to avoid generating separate
  summaries for different data shards.

  Returns:
    a boolean
  """
  name_scope = contrib.framework().get_name_scope()
  in_while_loop = bool(name_scope) and "while/" in name_scope
  # `or` short-circuits, so the variable scope is only queried when we are
  # not already inside a while loop.
  return not (in_while_loop or tf.get_variable_scope().reuse)
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:16,代码来源:common_layers.py

示例10: _should_cache_variables

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def _should_cache_variables():
  """Returns True if a default caching device should be set, otherwise False.

  Caching is skipped whenever the default graph's current control-flow
  context lies inside a (v1) while-loop context. Train steps could be wrapped
  in a tf.while_loop, and in that scenario caching prevents forward
  computations in later loop iterations from re-reading the updated weights.
  """
  # pylint: disable=protected-access
  flow_context = tf.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access
  enclosing_while = control_flow_util.GetContainingWhileContext(flow_context)
  return enclosing_while is None
开发者ID:magenta,项目名称:magenta,代码行数:13,代码来源:seq2seq.py

示例11: adapt

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def adapt(self, original_inputs, adversarial_inputs, labels):
    """Runs binary search to find the first misclassified input.

    For each example, bisects an interpolation factor ``m`` between
    ``original_inputs`` (m = 0) and ``adversarial_inputs`` (m = 1) for 10
    steps. The bounds start at 0 and 1; the midpoint replaces the upper bound
    when the attack succeeds there, and the lower bound otherwise.

    Args:
      original_inputs: Clean inputs. Their rank must be statically known,
        because ``len(original_inputs.shape)`` is used.
      adversarial_inputs: Adversarial inputs of the same shape.
      labels: Not used by this method.

    Returns:
      The interpolated inputs at the lower bound if the attack succeeds
      there, otherwise at the upper bound.
    """
    num_steps = 10
    batch_size = tf.shape(original_inputs)[0]
    # Reshape target that broadcasts a per-example factor over all other axes.
    mix_shape = [batch_size] + [1] * (len(original_inputs.shape) - 1)

    def interpolate(m):
      m = tf.reshape(m, mix_shape)
      return (adversarial_inputs - original_inputs) * m + original_inputs

    def succeeds(m):
      # Presumably True where the example is misclassified -- depends on
      # self._success_fn; confirm.
      logits = self._eval_fn(interpolate(m))
      return self._success_fn(self._specification.evaluate(logits))

    def keep_searching(step, *_):
      return tf.less(step, num_steps)

    def bisect(step, lo, hi):
      mid = (lo + hi) * .5
      hit = succeeds(mid)
      new_lo = tf.where(hit, lo, mid)
      new_hi = tf.where(hit, mid, hi)
      return step + 1, new_lo, new_hi

    lo = tf.zeros(shape=[batch_size])
    hi = tf.ones(shape=[batch_size])
    # NOTE: the step counter is a float32 tensor; it is compared against the
    # integer num_steps, which TF converts to float32.
    _, lo, hi = tf.while_loop(
        keep_searching,
        bisect,
        loop_vars=[tf.constant(0.), lo, hi],
        parallel_iterations=1,
        back_prop=False)
    # If lower is incorrectly classified, pick lower; otherwise pick upper.
    return interpolate(tf.where(succeeds(lo), lo, hi))
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:36,代码来源:attacks.py

示例12: benchmark_handwritten

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def benchmark_handwritten(self):
    """Times a training loop written directly with ``tf.while_loop``.

    Each iteration pulls a fresh minibatch, builds an optimizer step, and
    prints the step and loss when the step is a multiple of 100. The timed
    target evaluates the loop's final loss and asserts it lies in (0.1, 1).
    """
    with tf.Graph().as_default():
      ds, opt, hp, w, b = get_data_and_params()
      iterator = ds.make_one_shot_iterator()

      def train_step(step, unused_previous_loss_t):
        """Manual implementation of training loop."""
        # get_next() must be called inside the body; otherwise training
        # happens repeatedly on the first minibatch only.
        x, y = iterator.get_next()
        loss_t = loss_fn(x, y, w, b)
        train_op = opt.minimize(loss_t, var_list=(w, b))
        logged_step = tf.cond(
            tf.equal(step % 100, 0),
            lambda: tf.Print(step, [step, loss_t], message='Step, loss: '),
            lambda: step)

        # Tie the counter increment to the optimizer step so it runs.
        with tf.control_dependencies([train_op]):
          return logged_step + 1, loss_t

      def not_done(step, _):
        return step < hp.train_steps

      _, final_loss_t = tf.while_loop(
          not_done, train_step, [tf.constant(0), tf.constant(0.0)])

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        def target():
          loss_val = sess.run(final_loss_t)
          assert 0.1 < loss_val < 1, loss_val

        self.time_execution(
            'Handwritten',
            target,
            iter_volume=hp.train_steps,
            iter_unit='training steps')
开发者ID:tensorflow,项目名称:autograph,代码行数:38,代码来源:mnist_benchmark.py

示例13: test_vanilla_loop

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_vanilla_loop():
    """Counts from 0 up to 10 with a single loop variable.

    The initial constant is deliberately named ``while/constant`` --
    presumably to check that a node outside the loop with a ``while/``
    prefix is handled; confirm intent.
    """
    graph = tf.Graph()
    with graph.as_default():
        start = tf.constant(0, name="while/constant")

        def below_ten(n):
            return tf.less(n, 10)

        def plus_one(n):
            return tf.add(n, 1)

        result = tf.while_loop(below_ten, plus_one, [start])

        with tf.Session() as sess:
            tf_out = sess.run(result)

        check_equal(graph, tf_out)
开发者ID:apache,项目名称:incubator-tvm,代码行数:17,代码来源:test_control_flow.py

示例14: test_callnode_loop_vars

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_callnode_loop_vars():
    """Loop whose initial value is an op output (0 + 1), not a constant."""
    graph = tf.Graph()
    with graph.as_default():
        start = tf.add(tf.constant(0), 1)

        def keep_going(n):
            return tf.less(n, 10)

        def advance(n):
            return tf.add(n, 1)

        result = tf.while_loop(keep_going, advance, [start])

        with tf.Session() as sess:
            tf_out = sess.run(result)

        check_equal(graph, tf_out)
开发者ID:apache,项目名称:incubator-tvm,代码行数:17,代码来源:test_control_flow.py

示例15: test_vanilla_loop_bound

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_vanilla_loop_bound():
    """Loop over a slice of a fed placeholder with conds in the body.

    The body also captures ``outer``, a tensor computed outside the loop.
    It runs 20 iterations (``y`` from 0 to 20), each multiplying the carried
    slice by a cond-selected factor and by ``outer``.
    """
    graph = tf.Graph()
    with graph.as_default():
        dshape, dtype, dname = (2, 10), "float32", "data"
        np_data = np.random.uniform(size=dshape).astype(dtype)
        data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
        x = tf.slice(data, [1, 4], [1, 4])
        # Captured by the loop body from the enclosing graph.
        outer = x + 5.0

        def body(x, y):
            base = tf.cond(tf.less(y, 10),
                           lambda: tf.add(10.0, 20.0),
                           lambda: tf.square(10.0))
            seven = tf.constant(7)
            scaled = tf.cond(tf.less(seven, 10),
                             lambda: base * 5,
                             lambda: base + 10)
            return tf.multiply(scaled, x * outer), y + 1

        y = tf.constant(0)

        def condition(x, y):
            return tf.less(y, 20)

        r = tf.while_loop(condition, body, loop_vars=[x, y])
        feeds = {"%s:0" % dname: np_data}
        with tf.Session() as sess:
            tf_out = sess.run(r, feed_dict=feeds)

    check_equal(graph, tf_out, {dname: np_data})
开发者ID:apache,项目名称:incubator-tvm,代码行数:28,代码来源:test_control_flow.py


注：本文中的tensorflow.compat.v1.while_loop方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。