當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.while_loop用法及代碼示例


當條件 cond 為真時重複 body。 (不推薦使用的參數值)

用法

tf.while_loop(
    cond, body, loop_vars, shape_invariants=None, parallel_iterations=10,
    back_prop=True, swap_memory=False, maximum_iterations=None, name=None
)

參數

  • cond 表示循環終止條件的可調用對象。
  • body 表示循環體的可調用對象。
  • loop_vars 一個(可能是嵌套的)元組、namedtuple 或 numpy 數組、TensorTensorArray 對象的列表。
  • shape_invariants 循環變量的形狀不變量。
  • parallel_iterations 允許並行運行的迭代次數。它必須是一個正整數。
  • back_prop (可選)已棄用。 False 禁用對反向傳播的支持。更喜歡使用tf.stop_gradient
  • swap_memory 是否為此循環啟用GPU-CPU 內存交換。
  • maximum_iterations 要運行的 while 循環的可選最大迭代次數。如果提供,cond 輸出是 AND-ed 附加條件確保執行的迭代次數不大於 maximum_iterations
  • name 返回的張量的可選名稱前綴。

返回

  • 循環後循環變量的輸出張量。返回值的結構與 loop_vars 相同。

拋出

  • TypeError 如果 condbody 不可調用。
  • ValueError 如果loop_vars 為空。

警告:不推薦使用某些參數值:(back_prop=False)。它們將在未來的版本中被刪除。更新說明:back_prop=False 已棄用。考慮改用 tf.stop_gradient。代替:results = tf.while_loop(c, b, vars, back_prop=False) 使用:results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b , 變量))

cond 是一個可調用的返回布爾標量張量。 body 是一個可調用的,返回一個(可能是嵌套的)元組、命名元組或與 loop_vars 具有相同數量(長度和結構)和類型的張量列表。 loop_vars 是傳遞給 condbody 的(可能是嵌套的)元組、命名元組或張量列表。 condbody 都采用與 loop_vars 一樣多的參數。

除了常規的 Tensor 或 IndexedSlices,body 還可以接受和返回 TensorArray 對象。 TensorArray 對象的流將在循環之間和梯度計算期間適當地轉發。

注意while_loop調用condbody 恰好一次(在調用內while_loop, 並且在Session.run())。while_loop將在創建過程中創建的圖形片段縫合在一起condbody調用一些額外的圖節點來創建重複的圖流body直到cond返回假。

為了正確起見,tf.while_loop() 嚴格執行循環變量的形狀不變量。形狀不變量是一個(可能是部分的)形狀,它在循環的迭代中保持不變。如果確定迭代後循環變量的形狀比其形狀不變量更一般或不兼容,則會引發錯誤。例如,[11, None] 的形狀比 [11, 17] 的形狀更通用,[11, 21] 與 [11, 17] 不兼容。默認情況下(如果沒有指定參數shape_invariants),假設loop_vars中每個張量的初始形狀在每次迭代中都是相同的。 shape_invariants 參數允許調用者為每個循環變量指定一個不太具體的形狀不變量,如果形狀在迭代之間變化,則需要這樣做。 tf.Tensor.set_shape 函數也可以在body 函數中使用,以指示輸出循環變量具有特定形狀。 SparseTensor 和 IndexedSlices 的形狀不變量被特殊處理如下:

a) 如果循環變量是 SparseTensor,則形狀不變量必須是 TensorShape([r]) 其中 r 是稀疏張量表示的密集張量的秩。這意味著 SparseTensor 的三個張量的形狀是 ([None], [None, r], [r])。注意:這裏的形狀不變量是 SparseTensor.dense_shape 屬性的形狀。它必須是矢量的形狀。

b) 如果循環變量是 IndexedSlices,則形狀不變量必須是 IndexedSlices 的值張量的形狀不變量。這意味著 IndexedSlices 的三個張量的形狀是 (shape, [shape[0]], [shape.ndims])。

while_loop 實現非嚴格語義,允許多個迭代並行運行。 parallel_iterations 可以控製最大並行迭代次數,這使用戶可以控製內存消耗和執行順序。對於正確的程序,對於任何 parallel_iterations > 0,while_loop 應該返回相同的結果。

對於訓練,TensorFlow 存儲在前向推理中產生並在反向傳播中需要的張量。這些張量是內存消耗的主要來源,並且在 GPU 上訓練時經常會導致 OOM 錯誤。當標誌 swap_memory 為真時,我們將這些張量從 GPU 交換到 CPU。例如,這允許我們訓練具有非常長序列和大批量的 RNN 模型。

例子:

i = tf.constant(0)
c = lambda i:tf.less(i, 10)
b = lambda i:(tf.add(i, 1), )
r = tf.while_loop(c, b, [i])

嵌套和命名元組的示例:

import collections
Pair = collections.namedtuple('Pair', 'j, k')
ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
c = lambda i, p:i < 10
b = lambda i, p:(i + 1, Pair((p.j + p.k), (p.j - p.k)))
ijk_final = tf.while_loop(c, b, ijk_0)

使用 shape_invariants 的示例:

i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m:i < 10
b = lambda i, m:[i+1, tf.concat([m, m], axis=0)]
tf.while_loop(
    c, b, loop_vars=[i0, m0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])

演示非嚴格語義的示例:在以下示例中,計數器 i 的最終值不依賴於 x 。所以 while_loop 可以增加計數器與 x 的更新並行。但是,由於一次循環迭代中的循環計數器取決於前一次迭代的值,因此循環計數器本身不能並行遞增。因此,如果我們隻想要計數器的最終值(我們在 print(sess.run(i)) 行上打印),那麽 x 將永遠不會增加,但計數器將在單個線程上更新。相反,如果我們想要輸出的值(我們在 print(sess.run(out).shape) 行上打印),那麽計數器可以在它自己的線程上遞增,而 x 可以在單獨的線程上並行遞增。在極端情況下,可以想象,遞增計數器的線程會在x 遞增一次之前一直運行直到完成。唯一永遠不會發生的事情是更新x 的線程永遠不會領先於計數器線程,因為遞增x 的線程取決於計數器的值。

import tensorflow as tf

n = 10000
x = tf.constant(list(range(n)))
c = lambda i, x:i < n
b = lambda i, x:(tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
[i], "x:"))
i, out = tf.while_loop(c, b, (0, x))
with tf.compat.v1.Session() as sess:
    print(sess.run(i))  # prints [0] ... [9999]

    # The following line may increment the counter and x in parallel.
    # The counter thread may get ahead of the other thread, but not the
    # other way around. So you may see things like
    # [9996] x:[9987]
    # meaning that the counter thread is on iteration 9996,
    # while the other thread is on iteration 9987
    print(sess.run(out).shape)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.while_loop。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。