当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch DistributedSampler用法及代码示例


本文简要介绍python语言中 torch.utils.data.distributed.DistributedSampler 的用法。

用法:

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)

参数

  • dataset-用于采样的数据集。

  • num_replicas(int,可选的) -参与分布式训练的进程数。默认情况下,从当前分布式组中检索world_size

  • rank(int,可选的) -num_replicas 中当前进程的排名。默认情况下,从当前分布式组中检索rank

  • shuffle(bool,可选的) -如果True(默认),采样器将打乱索引。

  • seed(int,可选的) -如果 shuffle=True ,则用于洗牌采样器的随机种子。这个数字在分布式组中的所有进程中应该是相同的。默认值:0

  • drop_last(bool,可选的) -如果 True ,则采样器将丢弃数据的尾部,以使其在副本数量上均匀整除。如果 False ,采样器将添加额外的索引以使数据在副本中均匀分割。默认值:False

将数据加载到数据集子集的采样器。

torch.nn.parallel.DistributedDataParallel 结合使用特别有用。在这种情况下,每个进程都可以传递 DistributedSampler 实例作为 DataLoader 采样器,并加载其独有的原始数据集的子集。

注意

假设数据集大小不变。

警告

在分布式模式下,调用set_epoch()每个时期开始时的方法创建DataLoader迭代器是使混洗跨多个时期正常工作所必需的。否则,将始终使用相同的顺序。

例子:

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.utils.data.distributed.DistributedSampler。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。