

Python distributed.DistributedSampler Code Examples

This article collects typical usage examples of torch.utils.data.distributed.DistributedSampler in Python. If you are unsure what DistributedSampler does, how to use it, or want to see it in real code, the curated examples below should help. You can also explore further usage of the torch.utils.data.distributed module, where this class lives.


The following presents 15 code examples of distributed.DistributedSampler, sorted by popularity.
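
Before the individual examples, here is a minimal end-to-end sketch of the pattern they all share. The dataset, backend, and hyperparameters below are placeholders rather than values taken from any example: initialize the process group, hand the DataLoader a DistributedSampler instead of shuffle=True, and advance the sampler's epoch so each epoch gets a fresh shuffle.

# Minimal sketch, assuming one process per device launched via e.g. torchrun.
# TensorDataset stands in for a real dataset; batch size and epoch count are arbitrary.
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dist.init_process_group(backend="nccl")  # use "gloo" on CPU-only machines

dataset = TensorDataset(torch.randn(1000, 10))
sampler = DistributedSampler(dataset)  # num_replicas/rank inferred from the process group
loader = DataLoader(dataset, batch_size=32, sampler=sampler)  # sampler replaces shuffle=True

for epoch in range(3):
    sampler.set_epoch(epoch)  # re-seed so each epoch yields a different order
    for (batch,) in loader:
        pass  # forward/backward/step would go here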

Example 1: prepare_dataloaders

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    trainset = TextMelLoader(hparams.training_files, hparams)
    valset = TextMelLoader(hparams.validation_files, hparams)
    collate_fn = TextMelCollate(hparams.n_frames_per_step)

    if hparams.distributed_run:
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn 
Developer: alphacep, Project: tn2-wg, Lines: 20, Source: train.py

Example 2: _get_sampler

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def _get_sampler(self, epoch: int):
        """
        Return a :class:`torch.utils.data.sampler.Sampler` to sample the data.

        This is used to distribute the data across the replicas. If shuffling
        is enabled, every epoch will have a different shuffle.

        Args:
            epoch: The epoch being fetched.

        Returns:
            A sampler which tells the data loader which sample to load next.
        """
        world_size = get_world_size()
        rank = get_rank()
        sampler = DistributedSampler(
            self, num_replicas=world_size, rank=rank, shuffle=self.shuffle
        )
        sampler.set_epoch(epoch)
        return sampler 
Developer: facebookresearch, Project: ClassyVision, Lines: 22, Source: classy_dataset.py
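
A hedged sketch of how a per-epoch sampler like _get_sampler above is typically consumed: instead of calling set_epoch on one long-lived sampler, a fresh sampler (and loader) is built at the start of each epoch. Here torch.distributed's get_world_size/get_rank stand in for ClassyVision's helpers, TensorDataset is a placeholder dataset, and an initialized process group is assumed.

# Sketch only: torch.distributed equivalents replace the ClassyVision helpers.
# Assumes torch.distributed.init_process_group has already been called.
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))  # placeholder dataset

def make_epoch_loader(epoch: int) -> DataLoader:
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=True,
    )
    sampler.set_epoch(epoch)  # a different permutation every epoch
    return DataLoader(dataset, batch_size=16, sampler=sampler)

for epoch in range(3):
    for (batch,) in make_epoch_loader(epoch):
        pass  # training step would go here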

Example 3: get_loader

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def get_loader(self, force_update=False, override_settings=None, subset_indices=None):
        if force_update or self.regime.update(self.epoch, self.steps):
            setting = self.get_setting()
            if override_settings is not None:
                setting.update(override_settings)
            self._transform = get_transform(**setting['transform'])
            setting['data'].setdefault('transform', self._transform)
            self._data = get_dataset(**setting['data'])
            if subset_indices is not None:
                self._data = Subset(self._data, subset_indices)
            if setting['other'].get('distributed', False):
                setting['loader']['sampler'] = DistributedSampler(self._data)
                setting['loader']['shuffle'] = None
                # pin-memory currently broken for distributed
                setting['loader']['pin_memory'] = False
            self._sampler = setting['loader'].get('sampler', None)
            self._loader = torch.utils.data.DataLoader(
                self._data, **setting['loader'])
        return self._loader 
Developer: eladhoffer, Project: convNet.pytorch, Lines: 21, Source: data.py

Example 4: convert_to_distributed

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def convert_to_distributed(self, which_dataset=None, num_replicas=None, rank=None):
        samplers = {}
        if which_dataset is None:
            samplers["train"] = DistributedSampler(self.dataset_train, num_replicas=None, rank=None)
            self.loader_train = DataLoader(self.dataset_train, self.batch_size, False, sampler=samplers["train"])

        else:
            if which_dataset == "train":
                samplers["train"] = DistributedSampler(self.dataset_train, num_replicas=num_replicas, rank=rank)
                self.loader_train = DataLoader(self.dataset_train, self.batch_size, False,
                                               sampler=samplers["train"])
            elif which_dataset == "valid":
                samplers["valid"] = DistributedSampler(self.dataset_valid, num_replicas=num_replicas, rank=rank)
                self.loader_valid = DataLoader(self.dataset_valid, self.batch_size, False,
                                               sampler=samplers["valid"])
            elif which_dataset == "test":
                self.loader_test.sampler = samplers["test"]
                self.loader_test = DataLoader(self.dataset_test, self.batch_size, False,
                                              sampler=samplers["test"])
            else:
                ValueError(
                    "param `which_dataset` can only be set 'train, valid and test'. Got %s instead" % which_dataset)
        return samplers 
Developer: dingguanglei, Project: jdit, Lines: 25, Source: dataset.py

Example 5: get_loaders

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def get_loaders(traindir, valdir, sz, bs, fp16=True, val_bs=None, workers=8, rect_val=False, min_scale=0.08, distributed=False):
    val_bs = val_bs or bs
    train_tfms = [
            transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
            transforms.RandomHorizontalFlip()
    ]
    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_tfms))
    train_sampler = (DistributedSampler(train_dataset, num_replicas=env_world_size(), rank=env_rank()) if distributed else None)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=True, collate_fn=fast_collate, 
        sampler=train_sampler)

    val_dataset, val_sampler = create_validation_set(valdir, val_bs, sz, rect_val=rect_val, distributed=distributed)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        num_workers=workers, pin_memory=True, collate_fn=fast_collate, 
        batch_sampler=val_sampler)

    train_loader = BatchTransformDataLoader(train_loader, fp16=fp16)
    val_loader = BatchTransformDataLoader(val_loader, fp16=fp16)

    return train_loader, val_loader, train_sampler, val_sampler 
Developer: cybertronai, Project: imagenet18_old, Lines: 26, Source: dataloader.py

Example 6: _construct_loader

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last):
    """Constructs the data loader for the given dataset."""
    err_str = "Dataset '{}' not supported".format(dataset_name)
    assert dataset_name in _DATASETS and dataset_name in _PATHS, err_str
    # Retrieve the data path for the dataset
    data_path = os.path.join(_DATA_DIR, _PATHS[dataset_name])
    # Construct the dataset
    dataset = _DATASETS[dataset_name](data_path, split)
    # Create a sampler for multi-process training
    sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
    # Create a loader
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(False if sampler else shuffle),
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        drop_last=drop_last,
    )
    return loader 
Developer: facebookresearch, Project: pycls, Lines: 23, Source: loader.py

Example 7: get_train_dataloader

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def get_train_dataloader(self, train_examples, verbose=True):
        train_features = convert_examples_to_features(
            train_examples, self.label_map, self.rparams.max_seq_length, self.tokenizer,
            verbose=verbose,
        )
        train_data, train_tokens = convert_to_dataset(
            train_features, label_mode=get_label_mode(self.label_map),
        )
        if self.rparams.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(
            train_data, sampler=train_sampler, batch_size=self.rparams.train_batch_size,
        )
        return HybridLoader(train_dataloader, train_tokens) 
Developer: zphang, Project: bert_on_stilts, Lines: 18, Source: runners.py

Example 8: get_train_dataloader

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def get_train_dataloader(self, train_examples, verbose=True):
        train_features = convert_examples_to_features(
            examples=train_examples,
            max_seq_length=self.rparams.max_seq_length,
            tokenizer=self.tokenizer,
            select_prob=self.rparams.select_prob,
            verbose=verbose,
        )
        train_data, train_tokens = convert_to_dataset(train_features)
        if self.rparams.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(
            train_data, sampler=train_sampler, batch_size=self.rparams.train_batch_size,
        )
        return HybridLoader(train_dataloader, train_tokens) 
Developer: zphang, Project: bert_on_stilts, Lines: 19, Source: runners.py

Example 9: prepare_dataloaders

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    #pids = [id2sp[hparams.speaker_A], id2sp[hparams.speaker_B]]
    trainset = TextMelIDLoader(hparams.training_list, hparams.mel_mean_std, 
        hparams.speaker_A, hparams.speaker_B, pids=None)
    valset = TextMelIDLoader(hparams.validation_list, hparams.mel_mean_std,
        hparams.speaker_A, hparams.speaker_B, pids=None)
    collate_fn = TextMelIDCollate(lcm(hparams.n_frames_per_step_encoder,
                                      hparams.n_frames_per_step_decoder))

    train_sampler = DistributedSampler(trainset) \
        if hparams.distributed_run else None

    train_loader = DataLoader(trainset, num_workers=1, shuffle=train_sampler is None,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn 
Developer: jxzhanggg, Project: nonparaSeq2seqVC_code, Lines: 20, Source: train.py

Example 10: auto_add_sampler

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:

        # don't do anything if it's not a dataloader
        is_dataloader = isinstance(dataloader, DataLoader)
        # don't manipulate iterable datasets
        is_iterable_ds = _has_iterable_dataset(dataloader)

        if not is_dataloader or is_iterable_ds:
            return dataloader
        need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)

        if self.replace_sampler_ddp and need_dist_sampler:
            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    'You seem to have configured a sampler in your DataLoader. This will be replaced'
                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                    ' distributed training. Either remove the sampler from your DataLoader or set'
                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

            # replace with distributed sampler
            sampler = self._get_distributed_sampler(dataloader)
            dataloader = self.replace_sampler(dataloader, sampler)

        return dataloader 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 26, Source: data_loading.py
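
The error message above points at `replace_sampler_ddp`; in the Lightning version this snippet comes from, the opt-out is a Trainer flag. A hedged sketch (the flag names match the attributes the code above reads, but check the API of your installed release, which may have renamed them):

# Hedged sketch: keep a custom sampler by disabling automatic replacement.
# `replace_sampler_ddp` and `distributed_backend` follow the era of the snippet
# above; newer Lightning releases renamed both arguments.
from pytorch_lightning import Trainer

trainer = Trainer(gpus=2, distributed_backend='ddp', replace_sampler_ddp=False)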

Example 11: _get_distributed_sampler

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def _get_distributed_sampler(self, dataloader):
        if self.use_tpu:
            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.use_horovod:
            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
        else:
            world_size = {
                'ddp': self.num_nodes * self.num_processes,
                'ddp_spawn': self.num_nodes * self.num_processes,
                'ddp2': self.num_nodes,
                'ddp_cpu': self.num_processes * self.num_nodes
            }
            assert self.distributed_backend is not None
            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
        sampler = DistributedSampler(dataloader.dataset, **kwargs)
        return sampler 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 18, Source: data_loading.py

Example 12: read_eval_data

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def read_eval_data(args, tokenizer, logger):
    eval_path = os.path.join(args.data_dir, args.predict_file)
    eval_set = read_absa_data(eval_path)
    eval_examples = convert_absa_data(dataset=eval_set, verbose_logging=args.verbose_logging)

    eval_features = convert_examples_to_features(eval_examples, tokenizer, args.max_seq_length,
                                                 args.verbose_logging, logger)

    logger.info("Num orig examples = %d", len(eval_examples))
    logger.info("Num split features = %d", len(eval_features))
    logger.info("Batch size = %d", args.predict_batch_size)
    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
    if args.local_rank == -1:
        eval_sampler = SequentialSampler(eval_data)
    else:
        eval_sampler = DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)
    return eval_examples, eval_features, eval_dataloader 
Developer: huminghao16, Project: SpanABSA, Lines: 24, Source: run_joint_span.py

Example 13: setup_loader

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def setup_loader(dataset: Dataset,
                 batch_size: int,
                 local_rank: int,
                 n_gpu: int,
                 gradient_accumulation_steps: int,
                 num_workers: int) -> DataLoader:
    sampler = DistributedSampler(dataset) if local_rank != -1 else RandomSampler(dataset)
    batch_size = get_effective_batch_size(
        batch_size, local_rank, n_gpu, gradient_accumulation_steps) * n_gpu
    # WARNING: this will fail if the primary sequence is not the first thing the dataset returns
    batch_sampler = BucketBatchSampler(
        sampler, batch_size, False, lambda x: len(x[0]), dataset)

    loader = DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=dataset.collate_fn,  # type: ignore
        batch_sampler=batch_sampler)

    return loader 
Developer: songlab-cal, Project: tape, Lines: 22, Source: setup_utils.py
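
BucketBatchSampler above is tape's own length-bucketing batch sampler. A minimal sketch of the same wrap-a-DistributedSampler pattern using only the stock torch.utils.data.BatchSampler (no length bucketing; TensorDataset is a placeholder and an initialized process group is assumed):

# Same wrapping pattern with the standard BatchSampler instead of tape's
# BucketBatchSampler. Assumes torch.distributed.init_process_group was called.
import torch
from torch.utils.data import BatchSampler, DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(256, 8))  # placeholder dataset
sampler = DistributedSampler(dataset)         # per-replica stream of indices
batch_sampler = BatchSampler(sampler, batch_size=32, drop_last=False)

# When batch_sampler is given, batch_size/shuffle/sampler/drop_last must stay unset.
loader = DataLoader(dataset, batch_sampler=batch_sampler)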

Example 14: shuffle_dataset

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def shuffle_dataset(loader, cur_epoch):
    """"
    Shuffles the data.
    Args:
        loader (loader): data loader to perform shuffle.
        cur_epoch (int): number of the current epoch.
    """
    sampler = (
        loader.batch_sampler.sampler
        if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
        else loader.sampler
    )
    assert isinstance(
        sampler, (RandomSampler, DistributedSampler)
    ), "Sampler type '{}' not supported".format(type(sampler))
    # RandomSampler handles shuffling automatically
    if isinstance(sampler, DistributedSampler):
        # DistributedSampler shuffles data based on epoch
        sampler.set_epoch(cur_epoch) 
Developer: facebookresearch, Project: SlowFast, Lines: 21, Source: loader.py
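
The set_epoch call above is what makes the shuffle vary across epochs. A tiny self-contained check of that behavior (single process, so num_replicas=1 and rank=0 are passed explicitly and no process group is needed):

# Demonstrates what shuffle_dataset relies on: without set_epoch the sampler
# repeats the same permutation; after set_epoch the order changes.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

print(list(sampler))  # epoch-0 permutation
print(list(sampler))  # identical: the epoch was never advanced
sampler.set_epoch(1)
print(list(sampler))  # almost surely a different permutation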

Example 15: __init__

# Required import: from torch.utils.data import distributed [as alias]
# Or: from torch.utils.data.distributed import DistributedSampler [as alias]
def __init__(self, dataset, batch_size, distributed=False, num_workers=0, timeout=1000):
 
        if not distributed: 
            super(ChunkDataloader, self).__init__(dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers,
                                              collate_fn=self.collate_fn)
        else:
            import horovod.torch as hvd
            sampler = DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank())
            super(ChunkDataloader, self).__init__(dataset,
                                           batch_size=batch_size,
                                           sampler=sampler,
                                           num_workers=num_workers,
                                           collate_fn=self.collate_fn,
                                           drop_last=False,
                                           timeout=timeout) 
Developer: jzlianglu, Project: pykaldi2, Lines: 20, Source: dataloader.py


Note: The torch.utils.data.distributed.DistributedSampler examples in this article were compiled by 纯净天空 from open-source projects hosted on GitHub, MSDocs, and similar platforms. Copyright in the code snippets remains with their original authors; consult each project's License before using or redistributing the code, and do not reproduce this article without permission.