當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch DistributedSampler用法及代碼示例


本文簡要介紹python語言中 torch.utils.data.distributed.DistributedSampler 的用法。

用法:

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)

參數

  • dataset-用於采樣的數據集。

  • num_replicas(int,可選的) -參與分布式訓練的進程數。默認情況下,從當前分布式組中檢索world_size

  • rank(int,可選的) -num_replicas 中當前進程的排名。默認情況下,從當前分布式組中檢索rank

  • shuffle(bool,可選的) -如果True(默認),采樣器將打亂索引。

  • seed(int,可選的) -如果 shuffle=True ,則用於洗牌采樣器的隨機種子。這個數字在分布式組中的所有進程中應該是相同的。默認值:0

  • drop_last(bool,可選的) -如果 True ,則采樣器將丟棄數據的尾部,以使其在副本數量上均勻整除。如果 False ,采樣器將添加額外的索引以使數據在副本中均勻分割。默認值:False

將數據加載到數據集子集的采樣器。

torch.nn.parallel.DistributedDataParallel 結合使用特別有用。在這種情況下,每個進程都可以傳遞 DistributedSampler 實例作為 DataLoader 采樣器,並加載其獨有的原始數據集的子集。

注意

假設數據集大小不變。

警告

在分布式模式下,調用set_epoch()每個時期開始時的方法創建DataLoader迭代器是使混洗跨多個時期正常工作所必需的。否則,將始終使用相同的順序。

例子:

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.utils.data.distributed.DistributedSampler。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。