This article collects typical usage examples of the torch.utils.data.DistributedSampler method in Python. If you are unsure what data.DistributedSampler does or how to call it, the curated code examples below may help. You can also read further about the containing module, torch.utils.data.
The following shows 8 code examples of data.DistributedSampler, sorted by popularity by default.
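Before the examples, here is a minimal sketch of the pattern most of them build on: one DistributedSampler per process, plus a set_epoch call at every epoch boundary so the shuffle order changes between epochs. num_replicas and rank are passed explicitly only so the sketch runs without a process group; in real distributed training they are read from torch.distributed.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))
# explicit num_replicas/rank let this run without init_process_group()
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
# never combine shuffle=True with an explicit sampler
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseeds the shuffle for this epoch
    for (batch,) in loader:
        pass  # training step goes here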
Example 1: _create_data_loader
# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import DistributedSampler [as alias]
def _create_data_loader(self, data_transform, data_partition, sample_rate=None):
    sample_rate = sample_rate or self.hparams.sample_rate
    dataset = SliceData(
        root=self.hparams.data_path / f'{self.hparams.challenge}_{data_partition}',
        transform=data_transform,
        sample_rate=sample_rate,
        challenge=self.hparams.challenge,
    )

    is_train = (data_partition == 'train')
    if is_train:
        sampler = DistributedSampler(dataset)
    else:
        sampler = VolumeSampler(dataset)

    return DataLoader(
        dataset=dataset,
        batch_size=self.hparams.batch_size,
        num_workers=4,
        pin_memory=False,
        drop_last=is_train,
        sampler=sampler,
    )
Example 2: train
# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import DistributedSampler [as alias]
def train(self,
          data_loader: Iterable or DataLoader,
          mode: str = TRAIN):
    """ Training the model for an epoch.

    :param data_loader:
    :param mode: Name of this loop. Default is `train`. Passed to callbacks.
    """

    self._is_train = True
    self._epoch += 1
    self.model.train()
    if hasattr(self.loss_f, "train"):
        self.loss_f.train()

    with torch.enable_grad():
        self._loop(data_loader, mode=mode)

    if self.scheduler is not None and self.update_scheduler_by_epoch:
        self.scheduler.step()

    if isinstance(data_loader, DataLoader) and isinstance(data_loader.sampler, DistributedSampler):
        data_loader.sampler.set_epoch(self.epoch)
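A note on Example 2: DistributedSampler.set_epoch only changes the seed used the next time the loader is iterated, so calling it after the epoch's loop, as above, effectively reshuffles the data for the following epoch.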
Example 3: __init__
# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import DistributedSampler [as alias]
def __init__(self, root, batch_size, train=True):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = datasets.MNIST(root, train=train, transform=transform, download=True)

    sampler = None
    if train and distributed_is_initialized():
        sampler = data.DistributedSampler(dataset)

    super(MNISTDataLoader, self).__init__(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
    )
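As a rough usage sketch (not from the original repository), the class above would typically be driven from a per-process entry point. distributed_is_initialized is the project's own helper; it is assumed here to simply query torch.distributed.

import torch.distributed as dist
from torch.utils.data import DistributedSampler

def distributed_is_initialized() -> bool:
    # assumption: the project's helper boils down to this check
    return dist.is_available() and dist.is_initialized()

def main(rank: int, world_size: int):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:23456",
                            rank=rank, world_size=world_size)
    loader = MNISTDataLoader(root="./data", batch_size=64, train=True)
    for epoch in range(2):
        if isinstance(loader.sampler, DistributedSampler):
            loader.sampler.set_epoch(epoch)  # different shuffle every epoch
        for images, targets in loader:
            pass  # forward/backward/step of a DistributedDataParallel model here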
Example 4: load_datasets
# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import DistributedSampler [as alias]
def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
                  max_sequence_length, random_sequence_length, epoch_size=None, token_dropout=None, seed=None):
    if fake_dataset == 'TWO':
        download(real_dataset, 'xl-1542M', 'xl-1542M-nucleus', data_dir=data_dir)
    elif fake_dataset == 'THREE':
        download(real_dataset, 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus', data_dir=data_dir)
    else:
        download(real_dataset, fake_dataset, data_dir=data_dir)

    real_corpus = Corpus(real_dataset, data_dir=data_dir)

    if fake_dataset == "TWO":
        real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in ['xl-1542M', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    elif fake_dataset == "THREE":
        real_train, real_valid = real_corpus.train * 3, real_corpus.valid * 3
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in
                        ['xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    else:
        fake_corpus = Corpus(fake_dataset, data_dir=data_dir)
        real_train, real_valid = real_corpus.train, real_corpus.valid
        fake_train, fake_valid = fake_corpus.train, fake_corpus.valid

    Sampler = DistributedSampler if distributed() and dist.get_world_size() > 1 else RandomSampler

    min_sequence_length = 10 if random_sequence_length else None
    train_dataset = EncodedDataset(real_train, fake_train, tokenizer, max_sequence_length, min_sequence_length,
                                   epoch_size, token_dropout, seed)
    train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)

    validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer)
    validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))

    return train_loader, validation_loader
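The Sampler selection above relies on a project-level distributed() helper that is not shown here. A plausible minimal definition, given as an assumption rather than the original code, is:

import torch.distributed as dist

def distributed() -> bool:
    # True only when a torch.distributed process group has been initialized
    return dist.is_available() and dist.is_initialized()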
Example 5: normal_dataloader
# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import DistributedSampler [as alias]
def normal_dataloader(train_dataset, valid_dataset, config, arg):
    num_workers = config.num_workers
    pin_memory = True
    logger.info("\n num_workers of dataloader is {}".format(num_workers))

    if arg.distributed:
        train_dist_sampler = DistributedSampler(train_dataset)
        # valid_sampler_dist = DistributedSampler(valid_dataset)
    else:
        train_dist_sampler = None

    train_queue = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=config.train.batchsize,
                                              num_workers=num_workers,
                                              pin_memory=pin_memory,
                                              shuffle=(train_dist_sampler is None),
                                              sampler=train_dist_sampler)
    valid_queue = torch.utils.data.DataLoader(valid_dataset,
                                              batch_size=config.test.batchsize,
                                              num_workers=num_workers,
                                              pin_memory=pin_memory,
                                              shuffle=False)

    if arg.distributed:
        return train_queue, None, valid_queue, train_dist_sampler
    else:
        return train_queue, None, valid_queue
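When arg.distributed is set, the extra return value matters: the caller is expected to reseed the sampler each epoch. A hedged usage sketch follows; names such as config.train.epochs are assumptions, not taken from the original project.

train_queue, _, valid_queue, train_dist_sampler = normal_dataloader(
    train_dataset, valid_dataset, config, arg)
for epoch in range(config.train.epochs):
    train_dist_sampler.set_epoch(epoch)  # reshuffle this rank's shard for the epoch
    for images, targets in train_queue:
        pass  # training step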
Example 6: _force_make_distributed_loader
# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import DistributedSampler [as alias]
def _force_make_distributed_loader(loader: DataLoader) -> DataLoader:
    """
    Transfers loader to distributed mode. Experimental feature.

    Args:
        loader (DataLoader): pytorch dataloader

    Returns:
        DataLoader: pytorch dataloader with a distributed sampler.
    """
    # keep a user-provided sampler by wrapping it; otherwise fall back
    # to a plain DistributedSampler over the dataset
    sampler = (
        DistributedSampler(dataset=loader.dataset)
        if getattr(loader, "sampler", None) is None
        else DistributedSamplerWrapper(sampler=loader.sampler)
    )
    loader = DataLoader(
        dataset=copy(loader.dataset),
        batch_size=loader.batch_size,
        # shuffle=loader.shuffle,
        sampler=sampler,
        # batch_sampler=loader.batch_sampler,
        num_workers=loader.num_workers,
        # collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
    )
    return loader
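A note on Example 6: the commented-out arguments hint at the constraints here. A constructed DataLoader exposes no shuffle attribute to read back, and batch_sampler cannot be combined with the explicit sampler and batch_size passed above.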
Example 7: validate_loaders
# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import DistributedSampler [as alias]
def validate_loaders(loaders: Dict[str, DataLoader]) -> Dict[str, DataLoader]:
    """
    Check pytorch dataloaders for distributed setup.
    Transfers them to distributed mode if necessary.
    (Experimental feature)

    Args:
        loaders (Dict[str, DataLoader]): dictionary with pytorch dataloaders

    Returns:
        Dict[str, DataLoader]: dictionary with pytorch dataloaders
            (with distributed samplers if necessary)
    """
    rank = get_rank()
    if rank >= 0:
        for key, value in loaders.items():
            if not isinstance(
                value.sampler, (DistributedSampler, DistributedSamplerWrapper)
            ):
                warnings.warn(
                    "With distributed training setup, "
                    "you need ``DistributedSampler`` for your ``DataLoader``. "
                    "Transferring to distributed mode. (Experimental feature)"
                )
                loaders[key] = _force_make_distributed_loader(value)
    return loaders
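The rank >= 0 guard above implies that get_rank returns a negative value outside of distributed runs. A minimal sketch of such a helper, given as an assumption rather than the library's exact implementation:

import torch.distributed as dist

def get_rank() -> int:
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return -1  # signals a non-distributed run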
Example 8: run
# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import DistributedSampler [as alias]
def run(self,
        train_loader: Iterable or DataLoader,
        val_loaders: Iterable or DataLoader or Dict[str, Iterable or DataLoader],
        total_iterations: int,
        val_intervals: int):
    """ Train the model for the given number of iterations. This method is roughly equivalent to ::

        for ep in range(total_iterations):
            trainer.train(train_loader)
            for k, v in val_loaders.items():
                trainer.test(v, k)

    :param train_loader:
    :param val_loaders:
    :param total_iterations:
    :param val_intervals:
    :return:
    """

    class ProxyLoader(object):
        def __init__(self, loader):
            self.loader = loader

        def __len__(self):
            return val_intervals

        def __iter__(self):
            counter = 0
            while True:
                for data in self.loader:
                    if counter == val_intervals:
                        return  # from Python 3.7, this is valid
                    yield data
                    counter += 1

    train_loader = ProxyLoader(train_loader)
    if not isinstance(val_loaders, Dict) and (isinstance(val_loaders, Iterable) or
                                              isinstance(val_loaders, DataLoader)):
        val_loaders = {'val': val_loaders}

    for ep in range(total_iterations // val_intervals):
        self.train(train_loader)
        if isinstance(train_loader.loader, DataLoader) \
                and isinstance(train_loader.loader.sampler, DistributedSampler):
            train_loader.loader.sampler.set_epoch(self.epoch)
        for name, loader in val_loaders.items():
            self.test(loader, name)
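With the ProxyLoader trick above, validation runs every val_intervals training iterations instead of once per full pass over the data. A hedged usage sketch, where trainer, train_loader and val_loader are placeholders:

trainer.run(train_loader,
            {'val': val_loader},
            total_iterations=10_000,
            val_intervals=500)  # 20 train/validate cycles of 500 iterations each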