通過將 fn
應用於軸 0 上未堆疊的每個元素來轉換 elems
。(不推薦使用的參數)
用法
tf.map_fn(
fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
swap_memory=False, infer_shape=True, name=None, fn_output_signature=None
)
參數
-
fn
要執行的可調用對象。它接受一個參數,該參數將具有與elems
相同的(可能是嵌套的)結構。其輸出必須具有與fn_output_signature
相同的結構(如果提供);否則它必須具有與elems
相同的結構。 -
elems
張量或(可能是嵌套的)張量序列,每個張量都將沿其第一維展開。fn
將應用於結果切片的嵌套序列。elems
可能包括不規則和稀疏的張量。elems
必須至少包含一個張量。 -
dtype
已棄用:等效於fn_output_signature
。 -
parallel_iterations
(可選)允許並行運行的迭代次數。構建圖時,默認值為 10。即刻執行時,默認值設置為 1。 -
back_prop
(可選)不推薦使用:更喜歡使用tf.stop_gradient
。 False 禁用對反向傳播的支持。 -
swap_memory
(可選)True 啟用GPU-CPU 內存交換。 -
infer_shape
(可選)False 禁用一致輸出形狀的測試。 -
name
(可選)返回張量的名稱前綴。 -
fn_output_signature
的輸出簽名fn
.必須指定如果fn
的輸入和輸出簽名不同(即,如果它們的結構、數據類型或張量類型不匹配)。fn_output_signature
可以使用以下任何一種方式指定:- A
tf.DType
或tf.TensorSpec
(用於說明tf.Tensor
) - A
tf.RaggedTensorSpec
(說明tf.RaggedTensor
) - A
tf.SparseTensorSpec
(說明tf.sparse.SparseTensor
) - 包含上述類型的(可能是嵌套的)元組、列表或字典。
- A
返回
-
張量或(可能是嵌套的)張量序列。每個張量將應用
fn
的結果沿第一維從第一維到最後一個從elems
未堆疊的張量堆疊。結果可能包括參差不齊和稀疏的張量。
拋出
-
TypeError
如果fn
不可調用或fn
和fn_output_signature
的輸出結構不匹配。 -
ValueError
如果fn
和fn_output_signature
的輸出長度不匹配,或者elems
不包含任何張量。
警告:不推薦使用某些參數:(dtype)
。它們將在未來的版本中被刪除。更新說明:改用fn_output_signature
另見tf.scan
。
map_fn
在軸0上解棧elems
以獲得元素序列;調用 fn
來轉換每個元素;然後將轉換後的值重新堆疊在一起。
使用 single-Tensor 輸入和輸出映射函數
如果 elems
是單個張量並且 fn
的簽名是 tf.Tensor->tf.Tensor
,那麽 map_fn(fn, elems)
等價於 tf.stack([fn(elem) for elem in tf.unstack(elems)])
。例如:
tf.map_fn(fn=lambda t:tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
<tf.Tensor:shape=(3, 3), dtype=int32, numpy=
array([[3, 4, 5],
[5, 6, 7],
[2, 3, 4]], dtype=int32)>
map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape
.
使用 multi-arity 輸入和輸出映射函數
map_fn
還支持帶有 multi-arity 輸入和輸出的函數:
如果
elems
是張量的元組(或嵌套結構),那麽這些張量必須都具有相同的outer-dimension 大小(num_elems
);fn
用於從elems
轉換相應切片的每個元組(或結構)。例如,如果elems
是元組(t1, t2, t3)
,則fn
用於轉換每個切片元組(t1[i], t2[i], t3[i])
(其中0 <= i < num_elems
)。如果
fn
返回張量的元組(或嵌套結構),則結果是通過堆疊這些結構中的相應元素形成的。
指定 fn
的輸出簽名
如果 fn
的輸入和輸出簽名不同,則必須使用 fn_output_signature
指定輸出簽名。 (如果結構、數據類型或張量類型不匹配,則輸入和輸出簽名不同)。例如:
tf.map_fn(fn=tf.strings.length, # input & output have different dtypes
elems=tf.constant(["hello", "moon"]),
fn_output_signature=tf.int32)
<tf.Tensor:shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
tf.map_fn(fn=tf.strings.join, # input & output have different structures
elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
fn_output_signature=tf.string)
<tf.Tensor:shape=(2,), dtype=string,
numpy=array([b'TheDog', b'ACat'], dtype=object)>
fn_output_signature
可以使用以下任何一種方式指定:
- A
tf.DType
或tf.TensorSpec
(用於說明tf.Tensor
) - A
tf.RaggedTensorSpec
(說明tf.RaggedTensor
) - A
tf.SparseTensorSpec
(說明tf.sparse.SparseTensor
) - 包含上述類型的(可能是嵌套的)元組、列表或字典。
不規則張量
map_fn
支持 tf.RaggedTensor
輸入和輸出。特別是:
如果
elems
是RaggedTensor
,那麽fn
將使用該參差不齊的張量的每一行調用。- 如果
elems
隻有一個參差不齊的維度,則傳遞給fn
的值將是tf.Tensor
s。 - 如果
elems
有多個不規則維度,則傳遞給fn
的值將是少一個不規則維度的tf.RaggedTensor
。
- 如果
如果
map_fn
的結果應該是RaggedTensor
,則使用tf.RaggedTensorSpec
來指定fn_output_signature
。- 如果
fn
返回具有不同大小的tf.Tensor
s,則使用tf.RaggedTensorSpec
和ragged_rank=0
將它們組合成單個不規則張量(將具有 ragged_rank=1)。 - 如果
fn
返回tf.RaggedTensor
s,則使用具有相同ragged_rank
的tf.RaggedTensorSpec
。
- 如果
# Example:RaggedTensor input
rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
<tf.Tensor:shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>
# Example:RaggedTensor output
elems = tf.constant([3, 5, 0, 2])
tf.map_fn(tf.range, elems,
fn_output_signature=tf.RaggedTensorSpec(shape=[None],
dtype=tf.int32))
<tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>
注意: map_fn
僅當您需要將函數映射到行一個RaggedTensor
.如果您希望將函數映射到各個值,那麽您應該使用:
tf.ragged.map_flat_values(fn, rt)
(如果 fn 可以表示為 TensorFlow ops)rt.with_flat_values(map_fn(fn, rt.flat_values))
(否則)
例如:
rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
tf.ragged.map_flat_values(lambda x:x + 2, rt)
<tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>
稀疏張量
map_fn
支持 tf.sparse.SparseTensor
輸入和輸出。特別是:
如果
elems
是SparseTensor
,則將使用該稀疏張量的每一行調用fn
。特別是,傳遞給fn
的值將是一個比elems
少一維的tf.sparse.SparseTensor
。如果
map_fn
的結果應該是SparseTensor
,則使用tf.SparseTensorSpec
來指定fn_output_signature
。fn
返回的單個SparseTensor
將堆疊成一個具有多維的單個SparseTensor
。
# Example:SparseTensor input
st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
<tf.Tensor:shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>
# Example:SparseTensor output
tf.sparse.to_dense(
tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
<tf.Tensor:shape=(2, 3, 3), dtype=float32, numpy=
array([[[1., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]],
[[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]]], dtype=float32)>
注意: map_fn
僅當您需要將函數映射到行一個SparseTensor
.如果你想在非零值上映射一個函數,那麽你應該使用:
如果函數可以表示為 TensorFlow ops,請使用:
tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape)
否則,使用:
tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values), st.dense_shape)
map_fn
與矢量化操作
map_fn
會將 fn
使用的操作應用於 elems
的每個元素,從而產生 O(elems.shape[0])
總操作數。 map_fn
可以並行處理元素這一事實在一定程度上緩解了這種情況。但是,使用map_fn
表示的變換通常仍然比使用矢量化操作表示的等效變換效率低。
map_fn
通常應僅在以下情況之一為真時使用:
- 用矢量化操作表達所需的變換是困難的或昂貴的。
fn
創建較大的中間值,因此等效矢量化變換會占用太多內存。- 並行處理元素比等效矢量化變換更有效。
- 轉換的效率並不重要,使用
map_fn
更具可讀性。
例如,上麵給出的將 fn=lambda t:tf.range(t, t + 3)
映射到 elems
的示例可以使用矢量化操作更有效地重寫:
elems = tf.constant([3, 5, 2])
tf.range(3) + tf.expand_dims(elems, 1)
<tf.Tensor:shape=(3, 3), dtype=int32, numpy=
array([[3, 4, 5],
[5, 6, 7],
[2, 3, 4]], dtype=int32)>
在某些情況下,tf.vectorized_map
可用於自動將函數轉換為矢量化等效函數。
即刻執行
當即刻執行時,即使 parallel_iterations
設置為 > 1 的值,map_fn
也不會並行執行。您仍然可以獲得使用 tf.function
裝飾器並行運行函數的性能優勢:
fn=lambda t:tf.range(t, t + 3)
@tf.function
def func(elems):
return tf.map_fn(fn, elems, parallel_iterations=3)
func(tf.constant([3, 5, 2]))
<tf.Tensor:shape=(3, 3), dtype=int32, numpy=
array([[3, 4, 5],
[5, 6, 7],
[2, 3, 4]], dtype=int32)>
注意:如果您使用 tf.function
裝飾器,您在函數中編寫的任何非 TensorFlow Python 代碼都不會被執行。有關詳細信息,請參閱tf.function
。建議在不使用 tf.function
的情況下進行調試,但切換到它以獲得並行運行 map_fn
的性能優勢。
例子:
elems = np.array([1, 2, 3, 4, 5, 6])
tf.map_fn(lambda x:x * x, elems)
<tf.Tensor:shape=(6,), dtype=int64, numpy=array([ 1, 4, 9, 16, 25, 36])>
elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
tf.map_fn(lambda x:x[0] * x[1], elems, fn_output_signature=tf.int64)
<tf.Tensor:shape=(3,), dtype=int64, numpy=array([-1, 2, -3])>
elems = np.array([1, 2, 3])
tf.map_fn(lambda x:(x, -x), elems,
fn_output_signature=(tf.int64, tf.int64))
(<tf.Tensor:shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
<tf.Tensor:shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
相關用法
- Python tf.math.special.fresnel_cos用法及代碼示例
- Python tf.math.polyval用法及代碼示例
- Python tf.math.is_finite用法及代碼示例
- Python tf.math.special.bessel_k0e用法及代碼示例
- Python tf.math.acosh用法及代碼示例
- Python tf.math.invert_permutation用法及代碼示例
- Python tf.math.segment_prod用法及代碼示例
- Python tf.math.bincount用法及代碼示例
- Python tf.math.bessel_i0e用法及代碼示例
- Python tf.math.unsorted_segment_min用法及代碼示例
- Python tf.math.conj用法及代碼示例
- Python tf.math.scalar_mul用法及代碼示例
- Python tf.math.zero_fraction用法及代碼示例
- Python tf.math.reduce_max用法及代碼示例
- Python tf.math.special.fresnel_sin用法及代碼示例
- Python tf.math.segment_mean用法及代碼示例
- Python tf.math.xlog1py用法及代碼示例
- Python tf.math.less_equal用法及代碼示例
- Python tf.math.reduce_min用法及代碼示例
- Python tf.math.log_sigmoid用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.map_fn。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。