當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.scan用法及代碼示例


掃描從維度 0 的 elems 解壓縮的張量列表。

用法

tf.compat.v1.scan(
    fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
    swap_memory=False, infer_shape=True, reverse=False, name=None
)

參數

  • fn 要執行的可調用對象。它接受兩個參數。如果提供了一個,第一個將具有與 initializer 相同的結構,否則它將具有與 elems 相同的結構。第二個將具有與 elems 相同的(可能是嵌套的)結構。如果提供了一個,它的輸出必須與 initializer 具有相同的結構,否則它必須與 elems 具有相同的結構。
  • elems 張量或(可能是嵌套的)張量序列,每個張量都將沿其第一維展開。結果切片的嵌套序列將是 fn 的第一個參數。
  • initializer (可選)張量或(可能是嵌套的)張量序列、累加器的初始值以及 fn 的預期輸出類型。
  • parallel_iterations (可選)允許並行運行的迭代次數。
  • back_prop (可選)True 啟用對反向傳播的支持。
  • swap_memory (可選)True 啟用GPU-CPU 內存交換。
  • infer_shape (可選)False 禁用一致輸出形狀的測試。
  • reverse (可選)True 從後到前掃描張量(而不是從前到後)。
  • name (可選)返回張量的名稱前綴。

返回

  • 張量或(可能是嵌套的)張量序列。每個張量將應用 fn 的結果沿第一個維度從 elems 解壓縮的張量打包,以及前一個累加器值,從第一個到最後一個(或從最後一個到第一個,如果 reverse=True )。

拋出

  • TypeError 如果fn 不可調用或fninitializer 的輸出結構不匹配。
  • ValueError 如果fninitializer 的輸出長度不匹配。

另見tf.map_fn

scan 的最簡單版本重複地將可調用的fn 應用於從第一個到最後一個元素的序列。元素由從維度 0 的 elems 解壓縮的張量組成。可調用的 fn 將兩個張量作為參數。第一個參數是從前麵的 fn 調用計算的累加值,第二個參數是 elems 的當前位置的值。如果initializer 為None,則elems 必須至少包含一個元素,並且它的第一個元素用作初始值設定項。

假設 elems 被解壓縮到 values 中,這是一個張量列表。結果張量的形狀是 [len(values)] + fn(initializer, values[0]).shape 。如果 reverse=True,則為 fn(initializer, values[-1]).shape。

此方法還允許multi-arity elems 和累加器。如果elems 是一個(可能是嵌套的)張量列表或元組,那麽這些張量中的每一個都必須具有匹配的第一個(解包)維度。 fn 的第二個參數必須與 elems 的結構匹配。

如果沒有提供initializer,則假設fn的輸出結構和dtypes與其輸入相同;在這種情況下,fn 的第一個參數必須與 elems 的結構匹配。

如果提供了 initializer,則 fn 的輸出必須與 initializer 具有相同的結構;並且fn 的第一個參數必須匹配這個結構。

例如,如果 elems(t1, [t2, t3]) 並且 initializer[i1, i2] 那麽 python2fn 的適當簽名是:fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):fn 必須返回一個列表 [acc_n1, acc_n2]fn 的另一種正確簽名以及在 python3 中工作的簽名是: fn = lambda a, t: ,其中 at 對應於輸入元組。

例子:

elems = np.array([1, 2, 3, 4, 5, 6])
sum = scan(lambda a, x:a + x, elems)
# sum == [1, 3, 6, 10, 15, 21]
sum = scan(lambda a, x:a + x, elems, reverse=True)
# sum == [21, 20, 18, 15, 11, 6]
elems = np.array([1, 2, 3, 4, 5, 6])
initializer = np.array(0)
sum_one = scan(
    lambda a, x:x[0] - x[1] + a, (elems + 1, elems), initializer)
# sum_one == [1, 2, 3, 4, 5, 6]
elems = np.array([1, 0, 0, 0, 0, 0])
initializer = (np.array(0), np.array(1))
fibonaccis = scan(lambda a, _:(a[1], a[0] + a[1]), elems, initializer)
# fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.scan。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。