将秩-R
张量的给定维度解包到秩-(R-1)
张量中。
用法
tf.unstack(
value, num=None, axis=0, name='unstack'
)
参数
-
value
A rankR > 0
Tensor
被取消堆叠。 -
num
一个int
。维度axis
的长度。自动推断是否为None
(默认值)。 -
axis
一个int
。要拆开的轴。默认为第一个维度。负值环绕,因此有效范围是[-R, R)
。 -
name
操作的名称(可选)。
返回
-
Tensor
对象列表从value
取消堆叠。
抛出
-
ValueError
如果axis
超出范围[-R, R)
。 -
ValueError
如果num
未指定且无法推断。 -
InvalidArgumentError
如果num
与value
的形状不匹配。
通过沿 axis
维度将张量从 value
解包。
x = tf.reshape(tf.range(12), (3,4))
p, q, r = tf.unstack(x)
p.shape.as_list()
[4]
i, j, k, l = tf.unstack(x, axis=1)
i.shape.as_list()
[3]
这与堆栈相反。
x = tf.stack([i, j, k, l], axis=1)
更一般地说,如果你有一个形状为 (A, B, C, D)
的张量:
A, B, C, D = [2, 3, 4, 5]
t = tf.random.normal(shape=[A, B, C, D])
返回的张量数等于目标 axis
的长度:
axis = 2
items = tf.unstack(t, axis=axis)
len(items) == t.shape[axis]
True
每个结果张量的形状等于输入张量的形状,目标 axis
被移除。
items[0].shape.as_list() # [A, B, D]
[2, 3, 5]
每个张量 items[i]
的值等于 input
在索引 i
处跨越 axis
的切片:
for i in range(len(items)):
slice = t[:,:,i,:]
assert tf.reduce_all(slice == items[i])
Python 可迭代解包
通过即刻执行,您可以使用 python 的可迭代解包来解开张量的第 0 轴:
t = tf.constant([1,2,3])
a,b,c = t
unstack
仍然是必要的,因为可迭代解包在 @tf.function
中不起作用:符号张量不可迭代。
你需要在这里使用tf.unstack
:
@tf.function
def bad(t):
a,b,c = t
return a
bad(t)
Traceback (most recent call last):
OperatorNotAllowedInGraphError:...
@tf.function
def good(t):
a,b,c = tf.unstack(t)
return a
good(t).numpy()
1
未知形状
渴望张量具有具体的值,因此它们的形状总是已知的。在 tf.function
中,符号张量可能具有未知的形状。如果 axis
的长度未知,则 tf.unstack
将失败,因为它无法处理未知数量的张量:
@tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
def bad(t):
tensors = tf.unstack(t)
return tensors[0]
bad(tf.constant([1,2,3]))
Traceback (most recent call last):
ValueError:Cannot infer argument `num` from shape (None,)
如果您知道axis
长度,您可以将其作为num
参数传递。但这必须是一个常数值。
如果您实际上需要在单个 tf.function
跟踪中使用可变数量的张量,则需要使用显式循环和 tf.TensorArray
。
相关用法
- Python tf.unique用法及代码示例
- Python tf.unique_with_counts用法及代码示例
- Python tf.unravel_index用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代码示例
- Python tf.summary.scalar用法及代码示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代码示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代码示例
- Python tf.raw_ops.TPUReplicatedInput用法及代码示例
- Python tf.raw_ops.Bitcast用法及代码示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代码示例
- Python tf.compat.v1.Variable.eval用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_values_from_function用法及代码示例
- Python tf.math.special.fresnel_cos用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.compat.v1.layers.conv3d用法及代码示例
- Python tf.Variable.__lt__用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.unstack。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。