沿批处理维度的分片computation
用于并行执行。
用法
tf.compat.v1.tpu.batch_parallel(
computation:Callable[..., Any],
inputs:Optional[List[List[Optional[core_types.Tensor]]]] = None,
num_shards:int = 1,
infeed_queue:Optional[tpu_feed.InfeedQueue] = None,
device_assignment:Optional[tf.tpu.experimental.DeviceAssignment] = None,
name:Optional[Text] = None,
xla_options:Optional[tf.tpu.XLAOptions] = None
)
参数
-
computation
一个 Python 函数,它构建一个计算以应用于输入的每个分片。 -
inputs
输入张量列表或无(相当于一个空列表)。每个张量的第 0 维必须具有可被num_shards
整除的大小。 -
num_shards
分片的数量。 -
infeed_queue
如果不是None
,则从InfeedQueue
将参数元组作为输入附加到computation
。 -
device_assignment
如果不是None
,则使用DeviceAssignment
说明计算中的逻辑核心与 TPU 拓扑中的物理核心之间的映射。如果None
,则使用默认设备分配。DeviceAssignment
如果计算的每个分片只使用一个核,或者只有一个分片,或者分片数等于 TPU 系统中的核数,则可以省略DeviceAssignment
。 -
name
(已弃用)什么都不做。 -
xla_options
tpu.XLAOptions
的实例,指示传递给 XLA 编译器的选项。使用None
作为默认选项。
返回
- 输出张量列表。
抛出
-
ValueError
如果num_shards <= 0
shard() 周围的便利包装。
inputs
必须是张量列表或无(相当于一个空列表)。每个输入沿第 0 维拆分为 num_shards
片段,计算并行应用于每个分片。
如果张量被 computation
词汇捕获,则张量将广播到所有分片。例如:,
x = tf.constant(7) def computation(): return x + 3 ... = shard(computation, ...)
所有分片的输出沿它们的第 0 维连接在一起。
计算的输入和输出必须至少是 rank-1 张量。
相关用法
- Python tf.compat.v1.tpu.bfloat16_scope用法及代码示例
- Python tf.compat.v1.tpu.experimental.AdamParameters用法及代码示例
- Python tf.compat.v1.tpu.experimental.embedding_column用法及代码示例
- Python tf.compat.v1.tpu.experimental.FtrlParameters用法及代码示例
- Python tf.compat.v1.tpu.rewrite用法及代码示例
- Python tf.compat.v1.tpu.shutdown_system用法及代码示例
- Python tf.compat.v1.tpu.experimental.shared_embedding_columns用法及代码示例
- Python tf.compat.v1.tpu.outside_compilation用法及代码示例
- Python tf.compat.v1.tpu.experimental.StochasticGradientDescentParameters用法及代码示例
- Python tf.compat.v1.tpu.shard用法及代码示例
- Python tf.compat.v1.tpu.replicate用法及代码示例
- Python tf.compat.v1.tpu.experimental.AdagradParameters用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.get_or_create_global_step用法及代码示例
- Python tf.compat.v1.train.cosine_decay_restarts用法及代码示例
- Python tf.compat.v1.train.Optimizer用法及代码示例
- Python tf.compat.v1.truncated_normal_initializer.from_config用法及代码示例
- Python tf.compat.v1.train.AdagradOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.init_from_checkpoint用法及代码示例
- Python tf.compat.v1.truncated_normal_initializer用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.tpu.batch_parallel。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。