当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.tpu.batch_parallel用法及代码示例


沿批处理维度的分片computation 用于并行执行。

用法

tf.compat.v1.tpu.batch_parallel(
    computation:Callable[..., Any],
    inputs:Optional[List[List[Optional[core_types.Tensor]]]] = None,
    num_shards:int = 1,
    infeed_queue:Optional[tpu_feed.InfeedQueue] = None,
    device_assignment:Optional[tf.tpu.experimental.DeviceAssignment] = None,
    name:Optional[Text] = None,
    xla_options:Optional[tf.tpu.XLAOptions] = None
)

参数

  • computation 一个 Python 函数,它构建一个计算以应用于输入的每个分片。
  • inputs 输入张量列表或无(相当于一个空列表)。每个张量的第 0 维必须具有可被 num_shards 整除的大小。
  • num_shards 分片的数量。
  • infeed_queue 如果不是 None ,则从 InfeedQueue 将参数元组作为输入附加到 computation
  • device_assignment 如果不是 None ,则使用 DeviceAssignment 说明计算中的逻辑核心与 TPU 拓扑中的物理核心之间的映射。如果 None ,则使用默认设备分配。 DeviceAssignment 如果计算的每个分片只使用一个核,或者只有一个分片,或者分片数等于 TPU 系统中的核数,则可以省略DeviceAssignment
  • name (已弃用)什么都不做。
  • xla_options tpu.XLAOptions 的实例,指示传递给 XLA 编译器的选项。使用 None 作为默认选项。

返回

  • 输出张量列表。

抛出

  • ValueError 如果num_shards <= 0

shard() 周围的便利包装。

inputs 必须是张量列表或无(相当于一个空列表)。每个输入沿第 0 维拆分为 num_shards 片段,计算并行应用于每个分片。

如果张量被 computation 词汇捕获,则张量将广播到所有分片。例如:,

x = tf.constant(7) def computation(): return x + 3 ... = shard(computation, ...)

所有分片的输出沿它们的第 0 维连接在一起。

计算的输入和输出必须至少是 rank-1 张量。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.tpu.batch_parallel。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。