当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.while_loop用法及代码示例


当条件 cond 为真时重复 body

用法

tf.compat.v1.while_loop(
    cond, body, loop_vars, shape_invariants=None, parallel_iterations=10,
    back_prop=True, swap_memory=False, name=None, maximum_iterations=None,
    return_same_structure=False
)

参数

  • cond 表示循环终止条件的可调用对象。
  • body 表示循环体的可调用对象。
  • loop_vars 一个(可能是嵌套的)元组、namedtuple 或 numpy 数组、TensorTensorArray 对象的列表。
  • shape_invariants 循环变量的形状不变量。
  • parallel_iterations 允许并行运行的迭代次数。它必须是一个正整数。
  • back_prop 是否为此 while 循环启用反向传播。
  • swap_memory 是否为此循环启用GPU-CPU 内存交换。
  • name 返回的张量的可选名称前缀。
  • maximum_iterations 要运行的 while 循环的可选最大迭代次数。如果提供,cond 输出是 AND-ed 附加条件确保执行的迭代次数不大于 maximum_iterations
  • return_same_structure 如果为 True,则输出具有与 loop_vars 相同的结构。如果启用了即刻执行,则将被忽略(并始终视为 True)。

返回

  • 循环后循环变量的输出张量。如果 return_same_structure 为 True,则返回值与 loop_vars 具有相同的结构。如果 return_same_structure 为 False,如果 loop_vars 的长度为 1,则返回值为 Tensor、TensorArray 或 IndexedSlice,否则为列表。

抛出

  • TypeError 如果 condbody 不可调用。
  • ValueError 如果loop_vars 为空。

cond 是一个可调用的返回布尔标量张量。 body 是一个可调用的,返回一个(可能是嵌套的)元组、命名元组或与 loop_vars 具有相同数量(长度和结构)和类型的张量列表。 loop_vars 是传递给 condbody 的(可能是嵌套的)元组、命名元组或张量列表。 condbody 都采用与 loop_vars 一样多的参数。

除了常规的 Tensor 或 IndexedSlices,body 还可以接受和返回 TensorArray 对象。 TensorArray 对象的流将在循环之间和梯度计算期间适当地转发。

注意while_loop调用condbody 恰好一次(在调用内while_loop, 并且在Session.run())。while_loop将在创建过程中创建的图形片段缝合在一起condbody调用一些额外的图节点来创建重复的图流body直到cond返回假。

为了正确起见,tf.while_loop() 严格执行循环变量的形状不变量。形状不变量是一个(可能是部分的)形状,它在循环的迭代中保持不变。如果确定迭代后循环变量的形状比其形状不变量更一般或不兼容,则会引发错误。例如,[11, None] 的形状比 [11, 17] 的形状更通用,[11, 21] 与 [11, 17] 不兼容。默认情况下(如果没有指定参数shape_invariants),假设loop_vars中每个张量的初始形状在每次迭代中都是相同的。 shape_invariants 参数允许调用者为每个循环变量指定一个不太具体的形状不变量,如果形状在迭代之间变化,则需要这样做。 tf.Tensor.set_shape 函数也可以在body 函数中使用,以指示输出循环变量具有特定形状。 SparseTensor 和 IndexedSlices 的形状不变量被特殊处理如下:

a) 如果循环变量是 SparseTensor,则形状不变量必须是 TensorShape([r]) 其中 r 是稀疏张量表示的密集张量的秩。这意味着 SparseTensor 的三个张量的形状是 ([None], [None, r], [r])。注意:这里的形状不变量是 SparseTensor.dense_shape 属性的形状。它必须是矢量的形状。

b) 如果循环变量是 IndexedSlices,则形状不变量必须是 IndexedSlices 的值张量的形状不变量。这意味着 IndexedSlices 的三个张量的形状是 (shape, [shape[0]], [shape.ndims])。

while_loop 实现非严格语义,允许多个迭代并行运行。 parallel_iterations 可以控制最大并行迭代次数,这使用户可以控制内存消耗和执行顺序。对于正确的程序,对于任何 parallel_iterations > 0,while_loop 应该返回相同的结果。

对于训练,TensorFlow 存储在前向推理中产生并在反向传播中需要的张量。这些张量是内存消耗的主要来源,并且在 GPU 上训练时经常会导致 OOM 错误。当标志 swap_memory 为真时,我们将这些张量从 GPU 交换到 CPU。例如,这允许我们训练具有非常长序列和大批量的 RNN 模型。

例子:

i = tf.constant(0)
c = lambda i:tf.less(i, 10)
b = lambda i:tf.add(i, 1)
r = tf.while_loop(c, b, [i])

嵌套和命名元组的示例:

import collections
Pair = collections.namedtuple('Pair', 'j, k')
ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
c = lambda i, p:i < 10
b = lambda i, p:(i + 1, Pair((p.j + p.k), (p.j - p.k)))
ijk_final = tf.while_loop(c, b, ijk_0)

使用 shape_invariants 的示例:

i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m:i < 10
b = lambda i, m:[i+1, tf.concat([m, m], axis=0)]
tf.while_loop(
    c, b, loop_vars=[i0, m0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])

演示非严格语义的示例:在以下示例中,计数器 i 的最终值不依赖于 x 。所以 while_loop 可以增加计数器与 x 的更新并行。但是,由于一次循环迭代中的循环计数器取决于前一次迭代的值,因此循环计数器本身不能并行递增。因此,如果我们只想要计数器的最终值(我们在 print(sess.run(i)) 行上打印),那么 x 将永远不会增加,但计数器将在单个线程上更新。相反,如果我们想要输出的值(我们在 print(sess.run(out).shape) 行上打印),那么计数器可以在它自己的线程上递增,而 x 可以在单独的线程上并行递增。在极端情况下,可以想象,递增计数器的线程会在x 递增一次之前一直运行直到完成。唯一永远不会发生的事情是更新x 的线程永远不会领先于计数器线程,因为递增x 的线程取决于计数器的值。

import tensorflow as tf

n = 10000
x = tf.constant(list(range(n)))
c = lambda i, x:i < n
b = lambda i, x:(tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
[i], "x:"))
i, out = tf.while_loop(c, b, (0, x))
with tf.compat.v1.Session() as sess:
    print(sess.run(i))  # prints [0] ... [9999]

    # The following line may increment the counter and x in parallel.
    # The counter thread may get ahead of the other thread, but not the
    # other way around. So you may see things like
    # [9996] x:[9987]
    # meaning that the counter thread is on iteration 9996,
    # while the other thread is on iteration 9987
    print(sess.run(out).shape)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.while_loop。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。