當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.split用法及代碼示例

將張量 value 拆分為子張量列表。

用法

tf.split(
    value, num_or_size_splits, axis=0, num=None, name='split'
)

參數

  • value 要拆分的Tensor
  • num_or_size_splits int 表示沿 axis 的分割數,或一維整數 Tensor 或 Python 列表,其中包含沿 axis 的每個輸出張量的大小。如果是 int ,那麽它必須平分 value.shape[axis] ;否則沿分割軸的大小總和必須與 value 的大小相匹配。
  • axis 一個 int 或標量 int32 Tensor 。要拆分的維度。必須在 [-rank(value), rank(value)) 範圍內。默認為 0。
  • num 可選的 int ,用於在無法從 size_splits 的形狀推斷時指定輸出的數量。
  • name 操作的名稱(可選)。

返回

  • 如果num_or_size_splitsint,則返回num_or_size_splits Tensor 對象的列表;如果 num_or_size_splits 是一維列表或一維列表,則 Tensor 返回 num_or_size_splits.get_shape[0] Tensor 拆分產生的對象 value

拋出

  • ValueError 如果 num 未指定且無法推斷。
  • ValueError 如果 num_or_size_splits 是標量 Tensor

另見tf.unstack

如果 num_or_size_splitsint ,則它將 value 沿維度 axis 拆分為 num_or_size_splits 較小的張量。這要求 value.shape[axis] 可以被 num_or_size_splits 整除。

如果 num_or_size_splits 是一維張量(或列表),則 value 被拆分為 len(num_or_size_splits) 元素。 i -th 元素的形狀與 value 的大小相同,但沿維度 axis 的大小為 num_or_size_splits[i]

例如:

x = tf.Variable(tf.random.uniform([5, 30], -1, 1))

# Split `x` into 3 tensors along dimension 1
s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
tf.shape(s0).numpy()
array([ 5, 10], dtype=int32)

# Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
split0, split1, split2 = tf.split(x, [4, 15, 11], 1)
tf.shape(split0).numpy()
array([5, 4], dtype=int32)
tf.shape(split1).numpy()
array([ 5, 15], dtype=int32)
tf.shape(split2).numpy()
array([ 5, 11], dtype=int32)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.split。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。