本文简要介绍python语言中 torch.utils.data.distributed.DistributedSampler
的用法。
用法:
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)
dataset-用于采样的数据集。
num_replicas(int,可选的) -参与分布式训练的进程数。默认情况下,从当前分布式组中检索
world_size
。rank(int,可选的) -
num_replicas
中当前进程的排名。默认情况下,从当前分布式组中检索rank
。shuffle(bool,可选的) -如果
True
(默认),采样器将打乱索引。seed(int,可选的) -如果
shuffle=True
,则用于洗牌采样器的随机种子。这个数字在分布式组中的所有进程中应该是相同的。默认值:0
。drop_last(bool,可选的) -如果
True
,则采样器将丢弃数据的尾部,以使其在副本数量上均匀整除。如果False
,采样器将添加额外的索引以使数据在副本中均匀分割。默认值:False
。
将数据加载到数据集子集的采样器。
与
torch.nn.parallel.DistributedDataParallel
结合使用特别有用。在这种情况下,每个进程都可以传递DistributedSampler
实例作为DataLoader
采样器,并加载其独有的原始数据集的子集。注意
假设数据集大小不变。
警告
在分布式模式下,调用
set_epoch()
每个时期开始时的方法前创建DataLoader
迭代器是使混洗跨多个时期正常工作所必需的。否则,将始终使用相同的顺序。例子:
>>> sampler = DistributedSampler(dataset) if is_distributed else None >>> loader = DataLoader(dataset, shuffle=(sampler is None), ... sampler=sampler) >>> for epoch in range(start_epoch, n_epochs): ... if is_distributed: ... sampler.set_epoch(epoch) ... train(loader)
参数:
相关用法
- Python PyTorch DistributedModelParallel用法及代码示例
- Python PyTorch DistributedDataParallel用法及代码示例
- Python PyTorch DistributedDataParallel.register_comm_hook用法及代码示例
- Python PyTorch DistributedDataParallel.join用法及代码示例
- Python PyTorch DistributedModelParallel.named_parameters用法及代码示例
- Python PyTorch DistributedModelParallel.state_dict用法及代码示例
- Python PyTorch DistributedDataParallel.no_sync用法及代码示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代码示例
- Python PyTorch DistributedOptimizer用法及代码示例
- Python PyTorch Dirichlet用法及代码示例
- Python PyTorch DeQuantize用法及代码示例
- Python PyTorch DenseArch用法及代码示例
- Python PyTorch DeepFM用法及代码示例
- Python PyTorch DataFrameMaker用法及代码示例
- Python PyTorch DLRM用法及代码示例
- Python PyTorch Dropout用法及代码示例
- Python PyTorch Dropout3d用法及代码示例
- Python PyTorch DataParallel用法及代码示例
- Python PyTorch Decompressor用法及代码示例
- Python PyTorch Dropout2d用法及代码示例
- Python PyTorch DeepFM.forward用法及代码示例
- Python PyTorch Demultiplexer用法及代码示例
- Python PyTorch DatasetFolder.find_classes用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.utils.data.distributed.DistributedSampler。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。