当条件 cond
为真时重复 body
。
用法
tf.compat.v1.while_loop(
cond, body, loop_vars, shape_invariants=None, parallel_iterations=10,
back_prop=True, swap_memory=False, name=None, maximum_iterations=None,
return_same_structure=False
)
参数
-
cond
表示循环终止条件的可调用对象。 -
body
表示循环体的可调用对象。 -
loop_vars
一个(可能是嵌套的)元组、namedtuple 或 numpy 数组、Tensor
和TensorArray
对象的列表。 -
shape_invariants
循环变量的形状不变量。 -
parallel_iterations
允许并行运行的迭代次数。它必须是一个正整数。 -
back_prop
是否为此 while 循环启用反向传播。 -
swap_memory
是否为此循环启用GPU-CPU 内存交换。 -
name
返回的张量的可选名称前缀。 -
maximum_iterations
要运行的 while 循环的可选最大迭代次数。如果提供,cond
输出是 AND-ed 附加条件确保执行的迭代次数不大于maximum_iterations
。 -
return_same_structure
如果为 True,则输出具有与loop_vars
相同的结构。如果启用了即刻执行,则将被忽略(并始终视为 True)。
返回
-
循环后循环变量的输出张量。如果
return_same_structure
为 True,则返回值与loop_vars
具有相同的结构。如果return_same_structure
为 False,如果loop_vars
的长度为 1,则返回值为 Tensor、TensorArray 或 IndexedSlice,否则为列表。
抛出
-
TypeError
如果cond
或body
不可调用。 -
ValueError
如果loop_vars
为空。
cond
是一个可调用的返回布尔标量张量。 body
是一个可调用的,返回一个(可能是嵌套的)元组、命名元组或与 loop_vars
具有相同数量(长度和结构)和类型的张量列表。 loop_vars
是传递给 cond
和 body
的(可能是嵌套的)元组、命名元组或张量列表。 cond
和 body
都采用与 loop_vars
一样多的参数。
除了常规的 Tensor 或 IndexedSlices,body 还可以接受和返回 TensorArray 对象。 TensorArray 对象的流将在循环之间和梯度计算期间适当地转发。
注意while_loop
调用cond
和body
恰好一次(在调用内while_loop
, 并且在Session.run()
)。while_loop
将在创建过程中创建的图形片段缝合在一起cond
和body
调用一些额外的图节点来创建重复的图流body
直到cond
返回假。
为了正确起见,tf.while_loop()
严格执行循环变量的形状不变量。形状不变量是一个(可能是部分的)形状,它在循环的迭代中保持不变。如果确定迭代后循环变量的形状比其形状不变量更一般或不兼容,则会引发错误。例如,[11, None] 的形状比 [11, 17] 的形状更通用,[11, 21] 与 [11, 17] 不兼容。默认情况下(如果没有指定参数shape_invariants
),假设loop_vars
中每个张量的初始形状在每次迭代中都是相同的。 shape_invariants
参数允许调用者为每个循环变量指定一个不太具体的形状不变量,如果形状在迭代之间变化,则需要这样做。 tf.Tensor.set_shape
函数也可以在body
函数中使用,以指示输出循环变量具有特定形状。 SparseTensor 和 IndexedSlices 的形状不变量被特殊处理如下:
a) 如果循环变量是 SparseTensor,则形状不变量必须是 TensorShape([r]) 其中 r 是稀疏张量表示的密集张量的秩。这意味着 SparseTensor 的三个张量的形状是 ([None], [None, r], [r])。注意:这里的形状不变量是 SparseTensor.dense_shape 属性的形状。它必须是矢量的形状。
b) 如果循环变量是 IndexedSlices,则形状不变量必须是 IndexedSlices 的值张量的形状不变量。这意味着 IndexedSlices 的三个张量的形状是 (shape, [shape[0]], [shape.ndims])。
while_loop
实现非严格语义,允许多个迭代并行运行。 parallel_iterations
可以控制最大并行迭代次数,这使用户可以控制内存消耗和执行顺序。对于正确的程序,对于任何 parallel_iterations > 0,while_loop
应该返回相同的结果。
对于训练,TensorFlow 存储在前向推理中产生并在反向传播中需要的张量。这些张量是内存消耗的主要来源,并且在 GPU 上训练时经常会导致 OOM 错误。当标志 swap_memory 为真时,我们将这些张量从 GPU 交换到 CPU。例如,这允许我们训练具有非常长序列和大批量的 RNN 模型。
例子:
i = tf.constant(0)
c = lambda i:tf.less(i, 10)
b = lambda i:tf.add(i, 1)
r = tf.while_loop(c, b, [i])
嵌套和命名元组的示例:
import collections
Pair = collections.namedtuple('Pair', 'j, k')
ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
c = lambda i, p:i < 10
b = lambda i, p:(i + 1, Pair((p.j + p.k), (p.j - p.k)))
ijk_final = tf.while_loop(c, b, ijk_0)
使用 shape_invariants 的示例:
i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m:i < 10
b = lambda i, m:[i+1, tf.concat([m, m], axis=0)]
tf.while_loop(
c, b, loop_vars=[i0, m0],
shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
演示非严格语义的示例:在以下示例中,计数器 i
的最终值不依赖于 x
。所以 while_loop
可以增加计数器与 x
的更新并行。但是,由于一次循环迭代中的循环计数器取决于前一次迭代的值,因此循环计数器本身不能并行递增。因此,如果我们只想要计数器的最终值(我们在 print(sess.run(i))
行上打印),那么 x
将永远不会增加,但计数器将在单个线程上更新。相反,如果我们想要输出的值(我们在 print(sess.run(out).shape)
行上打印),那么计数器可以在它自己的线程上递增,而 x
可以在单独的线程上并行递增。在极端情况下,可以想象,递增计数器的线程会在x
递增一次之前一直运行直到完成。唯一永远不会发生的事情是更新x
的线程永远不会领先于计数器线程,因为递增x
的线程取决于计数器的值。
import tensorflow as tf
n = 10000
x = tf.constant(list(range(n)))
c = lambda i, x:i < n
b = lambda i, x:(tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
[i], "x:"))
i, out = tf.while_loop(c, b, (0, x))
with tf.compat.v1.Session() as sess:
print(sess.run(i)) # prints [0] ... [9999]
# The following line may increment the counter and x in parallel.
# The counter thread may get ahead of the other thread, but not the
# other way around. So you may see things like
# [9996] x:[9987]
# meaning that the counter thread is on iteration 9996,
# while the other thread is on iteration 9987
print(sess.run(out).shape)
相关用法
- Python tf.compat.v1.where用法及代码示例
- Python tf.compat.v1.wrap_function用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代码示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代码示例
- Python tf.compat.v1.Variable.eval用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.layers.conv3d用法及代码示例
- Python tf.compat.v1.strings.length用法及代码示例
- Python tf.compat.v1.data.Dataset.snapshot用法及代码示例
- Python tf.compat.v1.data.experimental.SqlDataset.reduce用法及代码示例
- Python tf.compat.v1.feature_column.categorical_column_with_vocabulary_file用法及代码示例
- Python tf.compat.v1.data.TextLineDataset.from_tensors用法及代码示例
- Python tf.compat.v1.variable_scope用法及代码示例
- Python tf.compat.v1.data.experimental.SqlDataset.as_numpy_iterator用法及代码示例
- Python tf.compat.v1.distributions.Bernoulli.covariance用法及代码示例
- Python tf.compat.v1.placeholder用法及代码示例
- Python tf.compat.v1.layers.Conv3D用法及代码示例
- Python tf.compat.v1.train.get_or_create_global_step用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.while_loop。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。