將秩-R
張量的給定維度解包到秩-(R-1)
張量中。
用法
tf.unstack(
value, num=None, axis=0, name='unstack'
)
參數
-
value
A rankR > 0
Tensor
被取消堆疊。 -
num
一個int
。維度axis
的長度。自動推斷是否為None
(默認值)。 -
axis
一個int
。要拆開的軸。默認為第一個維度。負值環繞,因此有效範圍是[-R, R)
。 -
name
操作的名稱(可選)。
返回
-
Tensor
對象列表從value
取消堆疊。
拋出
-
ValueError
如果axis
超出範圍[-R, R)
。 -
ValueError
如果num
未指定且無法推斷。 -
InvalidArgumentError
如果num
與value
的形狀不匹配。
通過沿 axis
維度將張量從 value
解包。
x = tf.reshape(tf.range(12), (3,4))
p, q, r = tf.unstack(x)
p.shape.as_list()
[4]
i, j, k, l = tf.unstack(x, axis=1)
i.shape.as_list()
[3]
這與堆棧相反。
x = tf.stack([i, j, k, l], axis=1)
更一般地說,如果你有一個形狀為 (A, B, C, D)
的張量:
A, B, C, D = [2, 3, 4, 5]
t = tf.random.normal(shape=[A, B, C, D])
返回的張量數等於目標 axis
的長度:
axis = 2
items = tf.unstack(t, axis=axis)
len(items) == t.shape[axis]
True
每個結果張量的形狀等於輸入張量的形狀,目標 axis
被移除。
items[0].shape.as_list() # [A, B, D]
[2, 3, 5]
每個張量 items[i]
的值等於 input
在索引 i
處跨越 axis
的切片:
for i in range(len(items)):
slice = t[:,:,i,:]
assert tf.reduce_all(slice == items[i])
Python 可迭代解包
通過即刻執行,您可以使用 python 的可迭代解包來解開張量的第 0 軸:
t = tf.constant([1,2,3])
a,b,c = t
unstack
仍然是必要的,因為可迭代解包在 @tf.function
中不起作用:符號張量不可迭代。
你需要在這裏使用tf.unstack
:
@tf.function
def bad(t):
a,b,c = t
return a
bad(t)
Traceback (most recent call last):
OperatorNotAllowedInGraphError:...
@tf.function
def good(t):
a,b,c = tf.unstack(t)
return a
good(t).numpy()
1
未知形狀
渴望張量具有具體的值,因此它們的形狀總是已知的。在 tf.function
中,符號張量可能具有未知的形狀。如果 axis
的長度未知,則 tf.unstack
將失敗,因為它無法處理未知數量的張量:
@tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
def bad(t):
tensors = tf.unstack(t)
return tensors[0]
bad(tf.constant([1,2,3]))
Traceback (most recent call last):
ValueError:Cannot infer argument `num` from shape (None,)
如果您知道axis
長度,您可以將其作為num
參數傳遞。但這必須是一個常數值。
如果您實際上需要在單個 tf.function
跟蹤中使用可變數量的張量,則需要使用顯式循環和 tf.TensorArray
。
相關用法
- Python tf.unique用法及代碼示例
- Python tf.unique_with_counts用法及代碼示例
- Python tf.unravel_index用法及代碼示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代碼示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代碼示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代碼示例
- Python tf.summary.scalar用法及代碼示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代碼示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代碼示例
- Python tf.raw_ops.TPUReplicatedInput用法及代碼示例
- Python tf.raw_ops.Bitcast用法及代碼示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代碼示例
- Python tf.compat.v1.Variable.eval用法及代碼示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代碼示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_values_from_function用法及代碼示例
- Python tf.math.special.fresnel_cos用法及代碼示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代碼示例
- Python tf.compat.v1.layers.conv3d用法及代碼示例
- Python tf.Variable.__lt__用法及代碼示例
- Python tf.keras.metrics.Mean.merge_state用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.unstack。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。