当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.unstack用法及代码示例


将秩-R张量的给定维度解包到秩-(R-1)张量中。

用法

tf.unstack(
    value, num=None, axis=0, name='unstack'
)

参数

  • value A rank R > 0 Tensor 被取消堆叠。
  • num 一个 int 。维度 axis 的长度。自动推断是否为None(默认值)。
  • axis 一个 int 。要拆开的轴。默认为第一个维度。负值环绕,因此有效范围是 [-R, R)
  • name 操作的名称(可选)。

返回

  • Tensor 对象列表从 value 取消堆叠。

抛出

  • ValueError 如果 axis 超出范围 [-R, R)
  • ValueError 如果 num 未指定且无法推断。
  • InvalidArgumentError 如果 numvalue 的形状不匹配。

通过沿 axis 维度将张量从 value 解包。

x = tf.reshape(tf.range(12), (3,4))

p, q, r = tf.unstack(x)
p.shape.as_list()
[4]
i, j, k, l = tf.unstack(x, axis=1)
i.shape.as_list()
[3]

这与堆栈相反。

x = tf.stack([i, j, k, l], axis=1)

更一般地说,如果你有一个形状为 (A, B, C, D) 的张量:

A, B, C, D = [2, 3, 4, 5]
t = tf.random.normal(shape=[A, B, C, D])

返回的张量数等于目标 axis 的长度:

axis = 2
items = tf.unstack(t, axis=axis)
len(items) == t.shape[axis]
True

每个结果张量的形状等于输入张量的形状,目标 axis 被移除。

items[0].shape.as_list()  # [A, B, D]
[2, 3, 5]

每个张量 items[i] 的值等于 input 在索引 i 处跨越 axis 的切片:

for i in range(len(items)):
  slice = t[:,:,i,:]
  assert tf.reduce_all(slice == items[i])

Python 可迭代解包

通过即刻执行,您可以使用 python 的可迭代解包来解开张量的第 0 轴:

t = tf.constant([1,2,3])
a,b,c = t

unstack 仍然是必要的,因为可迭代解包在 @tf.function 中不起作用:符号张量不可迭代。

你需要在这里使用tf.unstack

@tf.function
def bad(t):
  a,b,c = t
  return a

bad(t)
Traceback (most recent call last):

OperatorNotAllowedInGraphError:...
@tf.function
def good(t):
  a,b,c = tf.unstack(t)
  return a

good(t).numpy()
1

未知形状

渴望张量具有具体的值,因此它们的形状总是已知的。在 tf.function 中,符号张量可能具有未知的形状。如果 axis 的长度未知,则 tf.unstack 将失败,因为它无法处理未知数量的张量:

@tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
def bad(t):
  tensors = tf.unstack(t)
  return tensors[0]

bad(tf.constant([1,2,3]))
Traceback (most recent call last):

ValueError:Cannot infer argument `num` from shape (None,)

如果您知道axis 长度,您可以将其作为num 参数传递。但这必须是一个常数值。

如果您实际上需要在单个 tf.function 跟踪中使用可变数量的张量,则需要使用显式循环和 tf.TensorArray

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.unstack。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。