当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.split用法及代码示例


将张量 value 拆分为子张量列表。

用法

tf.split(
    value, num_or_size_splits, axis=0, num=None, name='split'
)

参数

  • value 要拆分的Tensor
  • num_or_size_splits int 表示沿 axis 的分割数,或一维整数 Tensor 或 Python 列表,其中包含沿 axis 的每个输出张量的大小。如果是 int ,那么它必须平分 value.shape[axis] ;否则沿分割轴的大小总和必须与 value 的大小相匹配。
  • axis 一个 int 或标量 int32 Tensor 。要拆分的维度。必须在 [-rank(value), rank(value)) 范围内。默认为 0。
  • num 可选的 int ,用于在无法从 size_splits 的形状推断时指定输出的数量。
  • name 操作的名称(可选)。

返回

  • 如果num_or_size_splitsint,则返回num_or_size_splits Tensor 对象的列表;如果 num_or_size_splits 是一维列表或一维列表,则 Tensor 返回 num_or_size_splits.get_shape[0] Tensor 拆分产生的对象 value

抛出

  • ValueError 如果 num 未指定且无法推断。
  • ValueError 如果 num_or_size_splits 是标量 Tensor

另见tf.unstack

如果 num_or_size_splitsint ,则它将 value 沿维度 axis 拆分为 num_or_size_splits 较小的张量。这要求 value.shape[axis] 可以被 num_or_size_splits 整除。

如果 num_or_size_splits 是一维张量(或列表),则 value 被拆分为 len(num_or_size_splits) 元素。 i -th 元素的形状与 value 的大小相同,但沿维度 axis 的大小为 num_or_size_splits[i]

例如:

x = tf.Variable(tf.random.uniform([5, 30], -1, 1))

# Split `x` into 3 tensors along dimension 1
s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
tf.shape(s0).numpy()
array([ 5, 10], dtype=int32)

# Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
split0, split1, split2 = tf.split(x, [4, 15, 11], 1)
tf.shape(split0).numpy()
array([5, 4], dtype=int32)
tf.shape(split1).numpy()
array([ 5, 15], dtype=int32)
tf.shape(split2).numpy()
array([ 5, 11], dtype=int32)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.split。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。