

Python data.DistributedSampler Method Code Examples

This article collects typical usage examples of the Python method torch.utils.data.DistributedSampler. If you are wondering what data.DistributedSampler does and how to use it in practice, the curated code examples below should help. You can also browse other usage examples for the torch.utils.data module.


Below are 8 code examples of the data.DistributedSampler method, ordered by popularity.
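
Before the examples, here is a minimal sketch of the canonical usage pattern (variable names are illustrative; it assumes torch.distributed.init_process_group has already been called in each worker process): pass the sampler to the DataLoader in place of shuffle=True, and call set_epoch at the start of every epoch so the shuffle order changes between epochs.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# toy dataset; in practice this is your real Dataset instance
dataset = TensorDataset(torch.randn(1000, 8))

# each process (rank) iterates a distinct shard of the dataset
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(10):
    # reseed the sampler so every epoch uses a different shuffle
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # forward / backward / optimizer step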

Example 1: _create_data_loader

# Required imports: from torch.utils import data [as alias]
# or: from torch.utils.data import DistributedSampler [as alias]
def _create_data_loader(self, data_transform, data_partition, sample_rate=None):
        sample_rate = sample_rate or self.hparams.sample_rate
        dataset = SliceData(
            root=self.hparams.data_path / f'{self.hparams.challenge}_{data_partition}',
            transform=data_transform,
            sample_rate=sample_rate,
            challenge=self.hparams.challenge
        )

        is_train = (data_partition == 'train')
        if is_train:
            # training: shard the slices across distributed ranks
            sampler = DistributedSampler(dataset)
        else:
            # validation/test: fastMRI's VolumeSampler groups slices by
            # volume so per-volume evaluation stays consistent across ranks
            sampler = VolumeSampler(dataset)

        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=4,
            pin_memory=False,
            drop_last=is_train,
            sampler=sampler,
        ) 
Developer ID: facebookresearch, Project: fastMRI, Lines: 25, Source file: mri_model.py
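
Note the split: DistributedSampler is used only for training. It pads its index list with repeated samples so every rank receives the same number of batches, which would distort validation metrics; hence the project's own VolumeSampler for the evaluation partitions.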

Example 2: train

# Required imports: from torch.utils import data [as alias]
# or: from torch.utils.data import DistributedSampler [as alias]
def train(self,
              data_loader: Iterable or DataLoader,
              mode: str = TRAIN):
        """ Training the model for an epoch.

        :param data_loader:
        :param mode: Name of this loop. Default is `train`. Passed to callbacks.
        """

        self._is_train = True
        self._epoch += 1
        self.model.train()
        if hasattr(self.loss_f, "train"):
            self.loss_f.train()
        with torch.enable_grad():
            self._loop(data_loader, mode=mode)

        if self.scheduler is not None and self.update_scheduler_by_epoch:
            self.scheduler.step()

        if isinstance(data_loader, DataLoader) and isinstance(data_loader.sampler, DistributedSampler):
            data_loader.sampler.set_epoch(self.epoch) 
Developer ID: moskomule, Project: homura, Lines: 24, Source file: trainers.py
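
Note the set_epoch call at the end: DistributedSampler derives its shuffle order from a seed that includes its epoch counter, so without updating the counter every epoch would replay the same permutation on every rank. Calling it here prepares the sampler for the following epoch.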

Example 3: __init__

# Required imports: from torch.utils import data [as alias]
# or: from torch.utils.data import DistributedSampler [as alias]
def __init__(self, root, batch_size, train=True):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        dataset = datasets.MNIST(root, train=train, transform=transform, download=True)
        sampler = None
        if train and distributed_is_initialized():
            sampler = data.DistributedSampler(dataset)

        super(MNISTDataLoader, self).__init__(
            dataset,
            batch_size=batch_size,
            # DataLoader forbids shuffle=True when a sampler is given
            shuffle=(sampler is None),
            sampler=sampler,
        ) 
Developer ID: narumiruna, Project: pytorch-distributed-example, Lines: 19, Source file: main.py

Example 4: load_datasets

# Required imports: from torch.utils import data [as alias]
# or: from torch.utils.data import DistributedSampler [as alias]
def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
                  max_sequence_length, random_sequence_length, epoch_size=None, token_dropout=None, seed=None):
    if fake_dataset == 'TWO':
        download(real_dataset, 'xl-1542M', 'xl-1542M-nucleus', data_dir=data_dir)
    elif fake_dataset == 'THREE':
        download(real_dataset, 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus', data_dir=data_dir)
    else:
        download(real_dataset, fake_dataset, data_dir=data_dir)

    real_corpus = Corpus(real_dataset, data_dir=data_dir)

    if fake_dataset == "TWO":
        real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in ['xl-1542M', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    elif fake_dataset == "THREE":
        real_train, real_valid = real_corpus.train * 3, real_corpus.valid * 3
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in
                        ['xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    else:
        fake_corpus = Corpus(fake_dataset, data_dir=data_dir)

        real_train, real_valid = real_corpus.train, real_corpus.valid
        fake_train, fake_valid = fake_corpus.train, fake_corpus.valid

    Sampler = DistributedSampler if distributed() and dist.get_world_size() > 1 else RandomSampler

    min_sequence_length = 10 if random_sequence_length else None
    train_dataset = EncodedDataset(real_train, fake_train, tokenizer, max_sequence_length, min_sequence_length,
                                   epoch_size, token_dropout, seed)
    train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)

    validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer)
    validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))

    return train_loader, validation_loader 
Developer ID: openai, Project: gpt-2-output-dataset, Lines: 41, Source file: train.py
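
Choosing the sampler class up front (Sampler = DistributedSampler if ... else RandomSampler) keeps the single-process and distributed code paths identical: both classes accept the dataset as their only required argument, so the same Sampler(train_dataset) call works either way.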

Example 5: normal_dataloader

# Required imports: from torch.utils import data [as alias]
# or: from torch.utils.data import DistributedSampler [as alias]
def normal_dataloader(train_dataset, valid_dataset, config, arg):
    num_workers = config.num_workers
    pin_memory = True
    logger.info("\n num_workers of dataloader is {}".format(num_workers))

    if arg.distributed:
        train_dist_sampler = DistributedSampler(train_dataset)
        # valid_sampler_dist = DistributedSampler(valid_dataset)
    else:
        train_dist_sampler = None

    train_queue = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batchsize,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=(train_dist_sampler is None),
        sampler=train_dist_sampler,
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.test.batchsize,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    if arg.distributed:
        return train_queue, None, valid_queue, train_dist_sampler
    else:
        return train_queue, None, valid_queue
Developer ID: yangsenius, Project: PoseNFS, Lines: 31, Source file: dataloader.py

Example 6: _force_make_distributed_loader

# Required imports: from torch.utils import data [as alias]
# or: from torch.utils.data import DistributedSampler [as alias]
def _force_make_distributed_loader(loader: DataLoader) -> DataLoader:
    """
    Transfers the loader to distributed mode. Experimental feature.

    Args:
        loader (DataLoader): pytorch dataloader

    Returns:
        DataLoader: pytorch dataloader with a distributed sampler.
    """
    sampler = (
        # no sampler attached yet: shard the dataset directly
        DistributedSampler(dataset=loader.dataset)
        if getattr(loader, "sampler", None) is None
        # otherwise wrap the existing sampler so its logic is preserved
        else DistributedSamplerWrapper(sampler=loader.sampler)
    )
    loader = DataLoader(
        dataset=copy(loader.dataset),
        batch_size=loader.batch_size,
        # shuffle=loader.shuffle,
        sampler=sampler,
        # batch_sampler=loader.batch_sampler,
        num_workers=loader.num_workers,
        # collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
    )
    return loader
Developer ID: catalyst-team, Project: catalyst, Lines: 29, Source file: data.py
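
The DistributedSamplerWrapper branch exists because DistributedSampler can only shard a plain dataset; any custom sampling logic already attached to the loader (weighted sampling, class balancing, and so on) would otherwise be lost. A minimal sketch of the idea, assuming the wrapper is importable from catalyst.data as in the project above:

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from catalyst.data import DistributedSamplerWrapper

dataset = TensorDataset(torch.randn(1000, 8))

# oversample the last 100 items tenfold, then shard the resulting
# weighted index stream across the distributed ranks
weights = [1.0] * 900 + [10.0] * 100
base_sampler = WeightedRandomSampler(weights, num_samples=len(weights))
sampler = DistributedSamplerWrapper(base_sampler)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)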

Example 7: validate_loaders

# Required imports: from torch.utils import data [as alias]
# or: from torch.utils.data import DistributedSampler [as alias]
def validate_loaders(loaders: Dict[str, DataLoader]) -> Dict[str, DataLoader]:
    """
    Check pytorch dataloaders for a distributed setup.
    Transfers them to distributed mode if necessary.
    (Experimental feature)

    Args:
        loaders (Dict[str, DataLoader]): dictionary with pytorch dataloaders

    Returns:
        Dict[str, DataLoader]: dictionary
            with pytorch dataloaders (with distributed samplers if necessary)
    """
    rank = get_rank()
    if rank >= 0:
        for key, value in loaders.items():
            if not isinstance(
                value.sampler, (DistributedSampler, DistributedSamplerWrapper)
            ):
                warnings.warn(
                    "With a distributed training setup, "
                    "you need ``DistributedSampler`` for your ``DataLoader``. "
                    "Transferring to distributed mode. (Experimental feature)"
                )
                loaders[key] = _force_make_distributed_loader(value)
    return loaders
Developer ID: catalyst-team, Project: catalyst, Lines: 28, Source file: data.py

Example 8: run

# Required imports: from torch.utils import data [as alias]
# or: from torch.utils.data import DistributedSampler [as alias]
def run(self,
            train_loader: Iterable or DataLoader,
            val_loaders: Iterable or DataLoader or Dict[str, Iterable or DataLoader],
            total_iterations: int,
            val_intervals: int):

        """ Train the model for a given iterations. This module is almost equal to ::

            for ep in range(total_iterations):
                trainer.train(train_loader)
                for k, v in val_loaders.items():
                    trainer.test(v, k)

        :param train_loader:
        :param val_loaders:
        :param total_iterations:
        :param val_intervals:
        :return:
        """

        class ProxyLoader(object):
            """Cycles the wrapped loader endlessly, yielding exactly
            ``val_intervals`` batches per iteration pass."""

            def __init__(self, loader):
                self.loader = loader

            def __len__(self):
                return val_intervals

            def __iter__(self):
                counter = 0
                while True:
                    for data in self.loader:
                        if counter == val_intervals:
                            return  # from python 3.7, this is valid
                        yield data
                        counter += 1

        train_loader = ProxyLoader(train_loader)
        if not isinstance(val_loaders, Dict) and (isinstance(val_loaders, Iterable) or
                                                  isinstance(val_loaders, DataLoader)):
            val_loaders = {'val': val_loaders}

        for ep in range(total_iterations // val_intervals):
            self.train(train_loader)
            if isinstance(train_loader.loader, DataLoader) \
                and isinstance(train_loader.loader.sampler, DistributedSampler):
                train_loader.loader.sampler.set_epoch(self.epoch)
            for name, loader in val_loaders.items():
                self.test(loader, name) 
Developer ID: moskomule, Project: homura, Lines: 50, Source file: trainers.py


Note: The torch.utils.data.DistributedSampler method examples above were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by their respective developers; copyright in the source code remains with the original authors, and redistribution or use should follow each project's license. Please do not repost without permission.