當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.unstack用法及代碼示例

將秩-R張量的給定維度解包到秩-(R-1)張量中。

用法

tf.unstack(
    value, num=None, axis=0, name='unstack'
)

參數

  • value A rank R > 0 Tensor 被取消堆疊。
  • num 一個 int 。維度 axis 的長度。自動推斷是否為None(默認值)。
  • axis 一個 int 。要拆開的軸。默認為第一個維度。負值環繞,因此有效範圍是 [-R, R)
  • name 操作的名稱(可選)。

返回

  • Tensor 對象列表從 value 取消堆疊。

拋出

  • ValueError 如果 axis 超出範圍 [-R, R)
  • ValueError 如果 num 未指定且無法推斷。
  • InvalidArgumentError 如果 numvalue 的形狀不匹配。

通過沿 axis 維度將張量從 value 解包。

x = tf.reshape(tf.range(12), (3,4))

p, q, r = tf.unstack(x)
p.shape.as_list()
[4]
i, j, k, l = tf.unstack(x, axis=1)
i.shape.as_list()
[3]

這與堆棧相反。

x = tf.stack([i, j, k, l], axis=1)

更一般地說,如果你有一個形狀為 (A, B, C, D) 的張量:

A, B, C, D = [2, 3, 4, 5]
t = tf.random.normal(shape=[A, B, C, D])

返回的張量數等於目標 axis 的長度:

axis = 2
items = tf.unstack(t, axis=axis)
len(items) == t.shape[axis]
True

每個結果張量的形狀等於輸入張量的形狀,目標 axis 被移除。

items[0].shape.as_list()  # [A, B, D]
[2, 3, 5]

每個張量 items[i] 的值等於 input 在索引 i 處跨越 axis 的切片:

for i in range(len(items)):
  slice = t[:,:,i,:]
  assert tf.reduce_all(slice == items[i])

Python 可迭代解包

通過即刻執行,您可以使用 python 的可迭代解包來解開張量的第 0 軸:

t = tf.constant([1,2,3])
a,b,c = t

unstack 仍然是必要的,因為可迭代解包在 @tf.function 中不起作用:符號張量不可迭代。

你需要在這裏使用tf.unstack

@tf.function
def bad(t):
  a,b,c = t
  return a

bad(t)
Traceback (most recent call last):

OperatorNotAllowedInGraphError:...
@tf.function
def good(t):
  a,b,c = tf.unstack(t)
  return a

good(t).numpy()
1

未知形狀

渴望張量具有具體的值,因此它們的形狀總是已知的。在 tf.function 中,符號張量可能具有未知的形狀。如果 axis 的長度未知,則 tf.unstack 將失敗,因為它無法處理未知數量的張量:

@tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
def bad(t):
  tensors = tf.unstack(t)
  return tensors[0]

bad(tf.constant([1,2,3]))
Traceback (most recent call last):

ValueError:Cannot infer argument `num` from shape (None,)

如果您知道axis 長度,您可以將其作為num 參數傳遞。但這必須是一個常數值。

如果您實際上需要在單個 tf.function 跟蹤中使用可變數量的張量,則需要使用顯式循環和 tf.TensorArray

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.unstack。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。