This article collects typical usage examples of Python's torch.utils.data.distributed.DistributedSampler. If you have been wondering what distributed.DistributedSampler does, how to use it, or where to find examples of it, the curated code samples below may help. You can also explore the containing module, torch.utils.data.distributed, for further usage examples.
The following 15 code examples of distributed.DistributedSampler are listed, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
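Before the individual examples, here is a minimal sketch of the pattern most of them follow. It is not taken from any of the projects below, and it assumes torch.distributed.init_process_group has already been called in each worker process so the sampler can infer world size and rank from the default process group:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# assumption: torch.distributed.init_process_group(...) has already run in this process
dataset = TensorDataset(torch.arange(1000, dtype=torch.float32))
sampler = DistributedSampler(dataset, shuffle=True)  # rank/world size taken from the default group
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # make each epoch use a different shuffle
    for (batch,) in loader:
        pass  # training step goes here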
Example 1: prepare_dataloaders
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    trainset = TextMelLoader(hparams.training_files, hparams)
    valset = TextMelLoader(hparams.validation_files, hparams)
    collate_fn = TextMelCollate(hparams.n_frames_per_step)

    if hparams.distributed_run:
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn
Example 2: _get_sampler
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def _get_sampler(self, epoch: int):
    """
    Return a :class:`torch.utils.data.sampler.Sampler` to sample the data.

    This is used to distribute the data across the replicas. If shuffling
    is enabled, every epoch will have a different shuffle.

    Args:
        epoch: The epoch being fetched.

    Returns:
        A sampler which tells the data loader which sample to load next.
    """
    world_size = get_world_size()
    rank = get_rank()
    sampler = DistributedSampler(
        self, num_replicas=world_size, rank=rank, shuffle=self.shuffle
    )
    sampler.set_epoch(epoch)
    return sampler
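A possible call site for this method (an assumption, not code from the original project): rebuild the DataLoader at the start of every epoch so the DistributedSampler reshuffles with the epoch seed set by sampler.set_epoch:

from torch.utils.data import DataLoader

num_epochs = 10                                # placeholder value
for epoch in range(num_epochs):
    sampler = dataset._get_sampler(epoch)      # `dataset` is assumed to be an instance of the class above
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    for batch in loader:
        pass  # forward/backward pass goes here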
Example 3: get_loader
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def get_loader(self, force_update=False, override_settings=None, subset_indices=None):
    if force_update or self.regime.update(self.epoch, self.steps):
        setting = self.get_setting()
        if override_settings is not None:
            setting.update(override_settings)
        self._transform = get_transform(**setting['transform'])
        setting['data'].setdefault('transform', self._transform)
        self._data = get_dataset(**setting['data'])
        if subset_indices is not None:
            self._data = Subset(self._data, subset_indices)
        if setting['other'].get('distributed', False):
            setting['loader']['sampler'] = DistributedSampler(self._data)
            setting['loader']['shuffle'] = None
            # pin-memory currently broken for distributed
            setting['loader']['pin_memory'] = False
        self._sampler = setting['loader'].get('sampler', None)
        self._loader = torch.utils.data.DataLoader(
            self._data, **setting['loader'])
    return self._loader
Example 4: convert_to_distributed
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def convert_to_distributed(self, which_dataset=None, num_replicas=None, rank=None):
    samplers = {}
    if which_dataset is None:
        samplers["train"] = DistributedSampler(self.dataset_train, num_replicas=None, rank=None)
        self.loader_train = DataLoader(self.dataset_train, self.batch_size, False,
                                       sampler=samplers["train"])
    else:
        if which_dataset == "train":
            samplers["train"] = DistributedSampler(self.dataset_train, num_replicas=num_replicas, rank=rank)
            self.loader_train = DataLoader(self.dataset_train, self.batch_size, False,
                                           sampler=samplers["train"])
        elif which_dataset == "valid":
            samplers["valid"] = DistributedSampler(self.dataset_valid, num_replicas=num_replicas, rank=rank)
            self.loader_valid = DataLoader(self.dataset_valid, self.batch_size, False,
                                           sampler=samplers["valid"])
        elif which_dataset == "test":
            # create the test sampler before building the loader that uses it
            samplers["test"] = DistributedSampler(self.dataset_test, num_replicas=num_replicas, rank=rank)
            self.loader_test = DataLoader(self.dataset_test, self.batch_size, False,
                                          sampler=samplers["test"])
        else:
            raise ValueError(
                "param `which_dataset` can only be set to 'train', 'valid' or 'test'. Got %s instead" % which_dataset)
    return samplers
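A hypothetical call site for this helper (an assumption about how the factory is used; `data_factory` is a placeholder name for an instance of the class that defines convert_to_distributed): take the replica count and rank from the already-initialized torch.distributed process group and convert only the training loader:

import torch.distributed as dist

samplers = data_factory.convert_to_distributed(
    "train", num_replicas=dist.get_world_size(), rank=dist.get_rank())
train_sampler = samplers["train"]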
Example 5: get_loaders
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def get_loaders(traindir, valdir, sz, bs, fp16=True, val_bs=None, workers=8, rect_val=False, min_scale=0.08, distributed=False):
    val_bs = val_bs or bs
    train_tfms = [
        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip()
    ]
    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_tfms))
    train_sampler = (DistributedSampler(train_dataset, num_replicas=env_world_size(), rank=env_rank())
                     if distributed else None)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=True, collate_fn=fast_collate,
        sampler=train_sampler)

    val_dataset, val_sampler = create_validation_set(valdir, val_bs, sz, rect_val=rect_val, distributed=distributed)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        num_workers=workers, pin_memory=True, collate_fn=fast_collate,
        batch_sampler=val_sampler)

    train_loader = BatchTransformDataLoader(train_loader, fp16=fp16)
    val_loader = BatchTransformDataLoader(val_loader, fp16=fp16)
    return train_loader, val_loader, train_sampler, val_sampler
Example 6: _construct_loader
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last):
    """Constructs the data loader for the given dataset."""
    err_str = "Dataset '{}' not supported".format(dataset_name)
    assert dataset_name in _DATASETS and dataset_name in _PATHS, err_str
    # Retrieve the data path for the dataset
    data_path = os.path.join(_DATA_DIR, _PATHS[dataset_name])
    # Construct the dataset
    dataset = _DATASETS[dataset_name](data_path, split)
    # Create a sampler for multi-process training
    sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
    # Create a loader
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(False if sampler else shuffle),
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        drop_last=drop_last,
    )
    return loader
Example 7: get_train_dataloader
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def get_train_dataloader(self, train_examples, verbose=True):
    train_features = convert_examples_to_features(
        train_examples, self.label_map, self.rparams.max_seq_length, self.tokenizer,
        verbose=verbose,
    )
    train_data, train_tokens = convert_to_dataset(
        train_features, label_mode=get_label_mode(self.label_map),
    )
    if self.rparams.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(
        train_data, sampler=train_sampler, batch_size=self.rparams.train_batch_size,
    )
    return HybridLoader(train_dataloader, train_tokens)
Example 8: get_train_dataloader
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def get_train_dataloader(self, train_examples, verbose=True):
    train_features = convert_examples_to_features(
        examples=train_examples,
        max_seq_length=self.rparams.max_seq_length,
        tokenizer=self.tokenizer,
        select_prob=self.rparams.select_prob,
        verbose=verbose,
    )
    train_data, train_tokens = convert_to_dataset(train_features)
    if self.rparams.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(
        train_data, sampler=train_sampler, batch_size=self.rparams.train_batch_size,
    )
    return HybridLoader(train_dataloader, train_tokens)
Example 9: prepare_dataloaders
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    #pids = [id2sp[hparams.speaker_A], id2sp[hparams.speaker_B]]
    trainset = TextMelIDLoader(hparams.training_list, hparams.mel_mean_std,
                               hparams.speaker_A, hparams.speaker_B, pids=None)
    valset = TextMelIDLoader(hparams.validation_list, hparams.mel_mean_std,
                             hparams.speaker_A, hparams.speaker_B, pids=None)
    collate_fn = TextMelIDCollate(lcm(hparams.n_frames_per_step_encoder,
                                      hparams.n_frames_per_step_decoder))

    train_sampler = DistributedSampler(trainset) \
        if hparams.distributed_run else None

    # shuffle and sampler are mutually exclusive in DataLoader, so only
    # shuffle when no DistributedSampler is in use
    train_loader = DataLoader(trainset, num_workers=1,
                              shuffle=(train_sampler is None),
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn
Example 10: auto_add_sampler
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
    # don't do anything if it's not a dataloader
    is_dataloader = isinstance(dataloader, DataLoader)
    # don't manipulate iterable datasets
    is_iterable_ds = _has_iterable_dataset(dataloader)

    if not is_dataloader or is_iterable_ds:
        return dataloader

    need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
    if self.replace_sampler_ddp and need_dist_sampler:
        if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
            raise MisconfigurationException(
                'You seem to have configured a sampler in your DataLoader. This will be replaced'
                ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                ' distributed training. Either remove the sampler from your DataLoader or set'
                ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

        # replace with distributed sampler
        sampler = self._get_distributed_sampler(dataloader)
        dataloader = self.replace_sampler(dataloader, sampler)

    return dataloader
Example 11: _get_distributed_sampler
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def _get_distributed_sampler(self, dataloader):
    if self.use_tpu:
        kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
    elif self.use_horovod:
        kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
    else:
        world_size = {
            'ddp': self.num_nodes * self.num_processes,
            'ddp_spawn': self.num_nodes * self.num_processes,
            'ddp2': self.num_nodes,
            'ddp_cpu': self.num_processes * self.num_nodes,
        }
        assert self.distributed_backend is not None
        kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
    sampler = DistributedSampler(dataloader.dataset, **kwargs)
    return sampler
Example 12: read_eval_data
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def read_eval_data(args, tokenizer, logger):
    eval_path = os.path.join(args.data_dir, args.predict_file)
    eval_set = read_absa_data(eval_path)
    eval_examples = convert_absa_data(dataset=eval_set, verbose_logging=args.verbose_logging)

    eval_features = convert_examples_to_features(eval_examples, tokenizer, args.max_seq_length,
                                                 args.verbose_logging, logger)

    logger.info("Num orig examples = %d", len(eval_examples))
    logger.info("Num split features = %d", len(eval_features))
    logger.info("Batch size = %d", args.predict_batch_size)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)

    if args.local_rank == -1:
        eval_sampler = SequentialSampler(eval_data)
    else:
        eval_sampler = DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)
    return eval_examples, eval_features, eval_dataloader
Example 13: setup_loader
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def setup_loader(dataset: Dataset,
                 batch_size: int,
                 local_rank: int,
                 n_gpu: int,
                 gradient_accumulation_steps: int,
                 num_workers: int) -> DataLoader:
    sampler = DistributedSampler(dataset) if local_rank != -1 else RandomSampler(dataset)
    batch_size = get_effective_batch_size(
        batch_size, local_rank, n_gpu, gradient_accumulation_steps) * n_gpu
    # WARNING: this will fail if the primary sequence is not the first thing the dataset returns
    batch_sampler = BucketBatchSampler(
        sampler, batch_size, False, lambda x: len(x[0]), dataset)

    loader = DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=dataset.collate_fn,  # type: ignore
        batch_sampler=batch_sampler)

    return loader
Example 14: shuffle_dataset
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def shuffle_dataset(loader, cur_epoch):
    """
    Shuffles the data.

    Args:
        loader (loader): data loader to perform shuffle.
        cur_epoch (int): number of the current epoch.
    """
    sampler = (
        loader.batch_sampler.sampler
        if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
        else loader.sampler
    )
    assert isinstance(
        sampler, (RandomSampler, DistributedSampler)
    ), "Sampler type '{}' not supported".format(type(sampler))
    # RandomSampler handles shuffling automatically
    if isinstance(sampler, DistributedSampler):
        # DistributedSampler shuffles data based on epoch
        sampler.set_epoch(cur_epoch)
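A usage sketch (an assumption, not code from the original project; train_loader, start_epoch and max_epoch are placeholders): call shuffle_dataset at the start of every epoch so a DistributedSampler re-seeds its shuffle, while a plain RandomSampler is left untouched:

for cur_epoch in range(start_epoch, max_epoch):  # placeholder epoch bounds
    shuffle_dataset(train_loader, cur_epoch)     # train_loader is an existing DataLoader
    for inputs, labels in train_loader:
        pass  # forward/backward pass goes here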
Example 15: __init__
# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def __init__(self, dataset, batch_size, distributed=False, num_workers=0, timeout=1000):
    if not distributed:
        super(ChunkDataloader, self).__init__(dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers,
                                              collate_fn=self.collate_fn)
    else:
        import horovod.torch as hvd
        sampler = DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank())
        super(ChunkDataloader, self).__init__(dataset,
                                              batch_size=batch_size,
                                              sampler=sampler,
                                              num_workers=num_workers,
                                              collate_fn=self.collate_fn,
                                              drop_last=False,
                                              timeout=timeout)