當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.vectorized_map用法及代碼示例


從維度 0 的 elems 解壓縮的張量列表上的並行映射。

用法

tf.vectorized_map(
    fn, elems, fallback_to_while_loop=True
)

參數

  • fn 要執行的可調用對象。它接受一個參數,該參數將具有與 elems 相同的(可能是嵌套的)結構,並返回一個可能嵌套的張量和操作結構,該結構可能與 elems 的結構不同。
  • elems 張量或(可能是嵌套的)張量序列,每個張量都將沿其第一維展開。結果切片的嵌套序列將由 fn 映射。所有元素的第一個維度必須廣播到一致的值;等效地,對於某些常見的批量大小 B >= 1 ,每個元素張量必須具有 B1 的第一維。
  • fallback_to_while_loop 如果為 true,則在未能矢量化操作時,將不支持的操作包裝在 tf.while_loop 中以執行映射迭代。請注意,這種回退隻發生在不受支持的操作上,fn 的其他部分仍然是矢量化的。如果為 false,則在遇到不受支持的操作時,會拋出 ValueError。請注意,回退可能會導致減速,因為矢量化通常會產生一到兩個數量級的加速。

返回

  • 張量或(可能是嵌套的)張量序列。每個張量從第一維到最後一個維度,將 fn 應用於從 elems 解包的張量的結果打包。

    盡管它們作為 user-visible 輸入和輸出不太常見,但請注意,表示張量列表的類型為 tf.variant 的張量(例如來自 tf.raw_ops.TensorListFromTensor )是通過堆疊列表內容而不是變量本身來矢量化的,因此容器張量在返回時將具有標量形狀,而不是通常的堆疊形狀。這提高了控製流梯度矢量化的性能。

拋出

  • ValueError 如果矢量化失敗並且 fallback_to_while_loop 為 False。

此方法的用法方式類似於tf.map_fn,但經過優化以運行得更快,可能具有更大的內存占用。加速是通過矢量化獲得的(參見Auto-Vectorizing TensorFlow Graphs:Jacobians, Auto-Batching and Beyond)。矢量化背後的想法是在語義上並行啟動 fn 的所有調用,並在所有這些調用中融合相應的操作。這種融合是在圖形生成時靜態完成的,生成的代碼在性能上通常與手動融合的版本相似。

因為 tf.vectorized_map 完全並行化批處理,所以此方法通常比使用 tf.map_fn 快得多,尤其是在即刻模式下。然而,這是一個實驗性函數,目前有很多限製:

  • fn 的不同語義調用之間不應存在數據依賴性,即以任何順序映射輸入的元素應該是安全的。
  • 有狀態內核可能大多不受支持,因為它們通常意味著數據依賴。不過,我們確實支持一組有限的此類有狀態內核(如 RandomFoo、讀取等變量操作等)。
  • fn 對控製流操作的支持有限。
  • fn 應該返回張量或操作的嵌套結構。但是,如果返回一個操作,它應該有零輸出。
  • fn 計算中任何中間或輸出張量的形狀和 dtype 不應依賴於 fn 的輸入。

例子:

def outer_product(a):
  return tf.tensordot(a, a, 0)

batch_size = 100
a = tf.ones((batch_size, 32, 32))
c = tf.vectorized_map(outer_product, a)
assert c.shape == (batch_size, 32, 32, 32, 32)
# Computing per-example gradients

batch_size = 10
num_features = 32
layer = tf.keras.layers.Dense(1)

def model_fn(arg):
  with tf.GradientTape() as g:
    inp, label = arg
    inp = tf.expand_dims(inp, 0)
    label = tf.expand_dims(label, 0)
    prediction = layer(inp)
    loss = tf.nn.l2_loss(label - prediction)
  return g.gradient(loss, (layer.kernel, layer.bias))

inputs = tf.random.uniform([batch_size, num_features])
labels = tf.random.uniform([batch_size, 1])
per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))
assert per_example_gradients[0].shape == (batch_size, num_features, 1)
assert per_example_gradients[1].shape == (batch_size, 1)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.vectorized_map。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。