當條件 cond
為真時重複 body
。
用法
tf.compat.v1.while_loop(
cond, body, loop_vars, shape_invariants=None, parallel_iterations=10,
back_prop=True, swap_memory=False, name=None, maximum_iterations=None,
return_same_structure=False
)
參數
-
cond
表示循環終止條件的可調用對象。 -
body
表示循環體的可調用對象。 -
loop_vars
一個(可能是嵌套的)元組、namedtuple 或 numpy 數組、Tensor
和TensorArray
對象的列表。 -
shape_invariants
循環變量的形狀不變量。 -
parallel_iterations
允許並行運行的迭代次數。它必須是一個正整數。 -
back_prop
是否為此 while 循環啟用反向傳播。 -
swap_memory
是否為此循環啟用GPU-CPU 內存交換。 -
name
返回的張量的可選名稱前綴。 -
maximum_iterations
要運行的 while 循環的可選最大迭代次數。如果提供,cond
輸出是 AND-ed 附加條件確保執行的迭代次數不大於maximum_iterations
。 -
return_same_structure
如果為 True,則輸出具有與loop_vars
相同的結構。如果啟用了即刻執行,則將被忽略(並始終視為 True)。
返回
-
循環後循環變量的輸出張量。如果
return_same_structure
為 True,則返回值與loop_vars
具有相同的結構。如果return_same_structure
為 False,如果loop_vars
的長度為 1,則返回值為 Tensor、TensorArray 或 IndexedSlice,否則為列表。
拋出
-
TypeError
如果cond
或body
不可調用。 -
ValueError
如果loop_vars
為空。
cond
是一個可調用的返回布爾標量張量。 body
是一個可調用的,返回一個(可能是嵌套的)元組、命名元組或與 loop_vars
具有相同數量(長度和結構)和類型的張量列表。 loop_vars
是傳遞給 cond
和 body
的(可能是嵌套的)元組、命名元組或張量列表。 cond
和 body
都采用與 loop_vars
一樣多的參數。
除了常規的 Tensor 或 IndexedSlices,body 還可以接受和返回 TensorArray 對象。 TensorArray 對象的流將在循環之間和梯度計算期間適當地轉發。
注意while_loop
調用cond
和body
恰好一次(在調用內while_loop
, 並且在Session.run()
)。while_loop
將在創建過程中創建的圖形片段縫合在一起cond
和body
調用一些額外的圖節點來創建重複的圖流body
直到cond
返回假。
為了正確起見,tf.while_loop()
嚴格執行循環變量的形狀不變量。形狀不變量是一個(可能是部分的)形狀,它在循環的迭代中保持不變。如果確定迭代後循環變量的形狀比其形狀不變量更一般或不兼容,則會引發錯誤。例如,[11, None] 的形狀比 [11, 17] 的形狀更通用,[11, 21] 與 [11, 17] 不兼容。默認情況下(如果沒有指定參數shape_invariants
),假設loop_vars
中每個張量的初始形狀在每次迭代中都是相同的。 shape_invariants
參數允許調用者為每個循環變量指定一個不太具體的形狀不變量,如果形狀在迭代之間變化,則需要這樣做。 tf.Tensor.set_shape
函數也可以在body
函數中使用,以指示輸出循環變量具有特定形狀。 SparseTensor 和 IndexedSlices 的形狀不變量被特殊處理如下:
a) 如果循環變量是 SparseTensor,則形狀不變量必須是 TensorShape([r]) 其中 r 是稀疏張量表示的密集張量的秩。這意味著 SparseTensor 的三個張量的形狀是 ([None], [None, r], [r])。注意:這裏的形狀不變量是 SparseTensor.dense_shape 屬性的形狀。它必須是矢量的形狀。
b) 如果循環變量是 IndexedSlices,則形狀不變量必須是 IndexedSlices 的值張量的形狀不變量。這意味著 IndexedSlices 的三個張量的形狀是 (shape, [shape[0]], [shape.ndims])。
while_loop
實現非嚴格語義,允許多個迭代並行運行。 parallel_iterations
可以控製最大並行迭代次數,這使用戶可以控製內存消耗和執行順序。對於正確的程序,對於任何 parallel_iterations > 0,while_loop
應該返回相同的結果。
對於訓練,TensorFlow 存儲在前向推理中產生並在反向傳播中需要的張量。這些張量是內存消耗的主要來源,並且在 GPU 上訓練時經常會導致 OOM 錯誤。當標誌 swap_memory 為真時,我們將這些張量從 GPU 交換到 CPU。例如,這允許我們訓練具有非常長序列和大批量的 RNN 模型。
例子:
i = tf.constant(0)
c = lambda i:tf.less(i, 10)
b = lambda i:tf.add(i, 1)
r = tf.while_loop(c, b, [i])
嵌套和命名元組的示例:
import collections
Pair = collections.namedtuple('Pair', 'j, k')
ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
c = lambda i, p:i < 10
b = lambda i, p:(i + 1, Pair((p.j + p.k), (p.j - p.k)))
ijk_final = tf.while_loop(c, b, ijk_0)
使用 shape_invariants 的示例:
i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m:i < 10
b = lambda i, m:[i+1, tf.concat([m, m], axis=0)]
tf.while_loop(
c, b, loop_vars=[i0, m0],
shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
演示非嚴格語義的示例:在以下示例中,計數器 i
的最終值不依賴於 x
。所以 while_loop
可以增加計數器與 x
的更新並行。但是,由於一次循環迭代中的循環計數器取決於前一次迭代的值,因此循環計數器本身不能並行遞增。因此,如果我們隻想要計數器的最終值(我們在 print(sess.run(i))
行上打印),那麽 x
將永遠不會增加,但計數器將在單個線程上更新。相反,如果我們想要輸出的值(我們在 print(sess.run(out).shape)
行上打印),那麽計數器可以在它自己的線程上遞增,而 x
可以在單獨的線程上並行遞增。在極端情況下,可以想象,遞增計數器的線程會在x
遞增一次之前一直運行直到完成。唯一永遠不會發生的事情是更新x
的線程永遠不會領先於計數器線程,因為遞增x
的線程取決於計數器的值。
import tensorflow as tf
n = 10000
x = tf.constant(list(range(n)))
c = lambda i, x:i < n
b = lambda i, x:(tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
[i], "x:"))
i, out = tf.while_loop(c, b, (0, x))
with tf.compat.v1.Session() as sess:
print(sess.run(i)) # prints [0] ... [9999]
# The following line may increment the counter and x in parallel.
# The counter thread may get ahead of the other thread, but not the
# other way around. So you may see things like
# [9996] x:[9987]
# meaning that the counter thread is on iteration 9996,
# while the other thread is on iteration 9987
print(sess.run(out).shape)
相關用法
- Python tf.compat.v1.where用法及代碼示例
- Python tf.compat.v1.wrap_function用法及代碼示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代碼示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代碼示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代碼示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代碼示例
- Python tf.compat.v1.Variable.eval用法及代碼示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代碼示例
- Python tf.compat.v1.layers.conv3d用法及代碼示例
- Python tf.compat.v1.strings.length用法及代碼示例
- Python tf.compat.v1.data.Dataset.snapshot用法及代碼示例
- Python tf.compat.v1.data.experimental.SqlDataset.reduce用法及代碼示例
- Python tf.compat.v1.feature_column.categorical_column_with_vocabulary_file用法及代碼示例
- Python tf.compat.v1.data.TextLineDataset.from_tensors用法及代碼示例
- Python tf.compat.v1.variable_scope用法及代碼示例
- Python tf.compat.v1.data.experimental.SqlDataset.as_numpy_iterator用法及代碼示例
- Python tf.compat.v1.distributions.Bernoulli.covariance用法及代碼示例
- Python tf.compat.v1.placeholder用法及代碼示例
- Python tf.compat.v1.layers.Conv3D用法及代碼示例
- Python tf.compat.v1.train.get_or_create_global_step用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.while_loop。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。