当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.distribute.experimental.CentralStorageStrategy.experimental_distribute_dataset用法及代码示例

用法

experimental_distribute_dataset(
    dataset, options=None
)

参数

返回

tf.data.Dataset 创建 tf.distribute.DistributedDataset

返回的tf.distribute.DistributedDataset 可以类似于常规数据集进行迭代。注意:用户不能向 tf.distribute.DistributedDataset 添加任何更多的转换。您只能创建一个迭代器或检查它生成的数据的tf.TypeSpec。请参阅tf.distribute.DistributedDataset 的 API 文档以了解更多信息。

下面是一个例子:

global_batch_size = 2
# Passing the devices is optional.
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
# Create a dataset
dataset = tf.data.Dataset.range(4).batch(global_batch_size)
# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)
@tf.function
def replica_fn(input):
  return input*2
result = []
# Iterate over the `tf.distribute.DistributedDataset`
for x in dist_dataset:
  # process dataset elements
  result.append(strategy.run(replica_fn, args=(x,)))
print(result)
[PerReplica:{
  0:<tf.Tensor:shape=(1,), dtype=int64, numpy=array([0])>,
  1:<tf.Tensor:shape=(1,), dtype=int64, numpy=array([2])>
}, PerReplica:{
  0:<tf.Tensor:shape=(1,), dtype=int64, numpy=array([4])>,
  1:<tf.Tensor:shape=(1,), dtype=int64, numpy=array([6])>
}]

此方法背后发生的三个关键操作是批处理、分片和预取。

在上面的代码片段中,datasetglobal_batch_size 批处理,并在其上调用 experimental_distribute_datasetdataset 重新批处理为等于全局批处理大小除以同步副本数的新批处理大小。我们使用 Pythonic for 循环遍历它。 x 是一个 tf.distribute.DistributedValues 包含所有副本的数据,每个副本获取新批量大小的数据。 tf.distribute.Strategy.run 将负责将 x 中的正确 per-replica 数据馈送到在每个副本上执行的右侧 replica_fn

分片包含跨多个工作人员和每个工作人员内部的自动分片。首先,在multi-worker 分布式训练中(即当您使用tf.distribute.experimental.MultiWorkerMirroredStrategytf.distribute.TPUStrategy 时),在一组workers 上自动分片数据集意味着每个worker 被分配整个数据集的一个子集(如果正确的tf.data.experimental.AutoShardPolicy 是放)。这是为了确保在每个步骤中,每个工作人员将处理非重叠数据集元素的全局批量大小。自动分片有几个不同的选项,可以使用 tf.data.experimental.DistributeOptions 指定。然后,在每个工作人员内进行分片意味着该方法将在所有工作人员设备之间拆分数据(如果存在多个)。无论multi-worker 自动分片如何,都会发生这种情况。

注意:对于跨多个工作人员的自动分片,默认模式是 tf.data.experimental.AutoShardPolicy.AUTO 。如果数据集是从读取器数据集(例如 tf.data.TFRecordDatasettf.data.TextLineDataset 等)创建的,则此模式将尝试按文件对输入数据集进行分片,或者按数据对数据集进行分片,其中每个工作人员将读取整个数据集,并且只处理分配给它的分片。但是,如果每个工作人员的输入文件少于一个,我们建议您通过将 tf.data.experimental.DistributeOptions.auto_shard_policy 设置为 tf.data.experimental.AutoShardPolicy.OFF 来禁用跨工作人员的数据集自动分片。

默认情况下,此方法在用户提供的tf.data.Dataset 实例的末尾添加预取转换。预取转换的参数 buffer_size 等于同步的副本数。

如果上述批量拆分和数据集分片逻辑不可取,请使用tf.distribute.Strategy.distribute_datasets_from_function相反,它不会为您执行任何自动批处理或分片。

注意:如果您使用的是 TPUStrategy,则使用时工作人员处理数据的顺序tf.distribute.Strategy.experimental_distribute_dataset或者tf.distribute.Strategy.distribute_datasets_from_function不保证。如果您正在使用,这通常是必需的tf.distribute规模预测。但是,您可以为批处理中的每个元素插入索引并相应地对输出进行排序。参考这个片段有关如何排序输出的示例。

注意:tf.distribute.experimental_distribute_datasettf.distribute.distribute_datasets_from_function 目前不支持有状态数据集转换。当前忽略数据集可能具有的任何有状态操作。例如,如果您的数据集有一个使用tf.random.uniform 旋转图像的map_fn,那么您有一个依赖于执行python 进程的本地机器上的状态(即随机种子)的数据集图。

有关此方法的更多用法和属性的教程,请参阅分布式输入教程。如果您对最后部分批处理感兴趣,请阅读本节。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.distribute.experimental.CentralStorageStrategy.experimental_distribute_dataset。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。