当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.dynamic_partition用法及代码示例


使用来自 partitions 的索引将 data 划分为 num_partitions 张量。

用法

tf.dynamic_partition(
    data, partitions, num_partitions, name=None
)

参数

  • data 一个Tensor
  • partitions Tensor 类型为 int32 。任何形状。 [0, num_partitions) 范围内的索引。
  • num_partitions int>= 1 。要输出的分区数。
  • name 操作的名称(可选)。

返回

  • data 具有相同类型的 num_partitions Tensor 对象的列表。

对于大小为 partitions.ndim 的每个索引元组 js ,切片 data[js, ...] 成为 outputs[partitions[js]] 的一部分。带有partitions[js] = i的切片按照js的字典顺序放置在outputs[i]中,outputs[i]的第一个维度是partitions中的条目数等于i。详细地,

outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:]

    outputs[i] = pack([data[js, ...] for js if partitions[js] == i])

data.shape 必须以 partitions.shape 开头。

例如:

# Scalar partitions.
    partitions = 1
    num_partitions = 2
    data = [10, 20]
    outputs[0] = []  # Empty with shape [0, 2]
    outputs[1] = [[10, 20]]

    # Vector partitions.
    partitions = [0, 0, 1, 1, 0]
    num_partitions = 2
    data = [10, 20, 30, 40, 50]
    outputs[0] = [10, 20, 50]
    outputs[1] = [30, 40]

有关如何合并分区的示例,请参阅dynamic_stitch

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.dynamic_partition。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。