当前位置: 首页>>代码示例>>Python>>正文


Python sampler.RandomSampler方法代码示例

本文整理汇总了Python中torch.utils.data.sampler.RandomSampler方法的典型用法代码示例。如果您正苦于以下问题:Python sampler.RandomSampler方法的具体用法?Python sampler.RandomSampler怎么用?Python sampler.RandomSampler使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.utils.data.sampler的用法示例。


在下文中一共展示了sampler.RandomSampler方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_len

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def test_len(self):
        """Check that ``len(batch_sampler)`` matches the number of batches yielded.

        A ``GroupedBatchSampler`` is built over a ``RandomSampler`` of ten
        items with a random 0/1 group id per item.  The length is compared
        against the materialised batch list both when it is queried after
        iterating and when it is queried before iterating.
        """
        items = list(range(10))
        # One random group id (0 or 1) per element of the dataset.
        group_ids = [random.randint(0, 1) for _ in items]
        sampler = RandomSampler(items)
        batch_size, drop_uneven = 3, True

        # Iterate first, then ask for the length -- twice, so two successive
        # len() calls are both compared against the same batch list.
        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        batches = list(batch_sampler)
        for _ in range(2):
            self.assertEqual(len(batches), len(batch_sampler))

        # Ask for the length before iterating, then again afterwards.
        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        length_before_iteration = len(batch_sampler)
        batches = list(batch_sampler)
        self.assertEqual(len(batches), length_before_iteration)
        self.assertEqual(len(batches), len(batch_sampler))
开发者ID:Res2Net,项目名称:Res2Net-maskrcnn,代码行数:19,代码来源:test_data_samplers.py

示例2: __iter__

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def __iter__(self):
        """Yield padded ``(src, tgt)`` batches of exactly ``self.batch_size`` pairs.

        Indices are drawn from ``self.sampler`` and the corresponding
        ``(src, tgt)`` pairs are read from ``self.sampler.data_source``.  Once
        ``self.batch_size`` pairs have been collected, both lists are padded
        with ``rnn.pad_sequence(..., batch_first=True,
        padding_value=self.pad_id)`` and yielded.  A trailing group with fewer
        than ``self.batch_size`` pairs is never yielded.

        Batch boundaries are decided from the size of the per-iteration
        buffers rather than from ``self.count``.  ``self.count`` is never
        reset here, so using ``self.count % self.batch_size`` made every
        epoch after a non-divisible one hit the boundary with a partially
        filled buffer and fail the size assertion.  ``self.count`` is still
        incremented once per drawn sample for any code that reads it.

        Yields:
            tuple: ``(src, tgt)`` as returned by ``rnn.pad_sequence``.
        """
        src_list = []
        tgt_list = []
        # sampler is RandomSampler
        for i in self.sampler:
            self.count += 1
            src, tgt = self.sampler.data_source[i]
            src_list.append(src)
            tgt_list.append(tgt)
            # Both lists grow in lockstep, so checking one of them is enough.
            if len(src_list) == self.batch_size:
                src = rnn.pad_sequence(src_list, batch_first=True, padding_value=self.pad_id)
                tgt = rnn.pad_sequence(tgt_list, batch_first=True, padding_value=self.pad_id)
                # The padded tensors are new objects, so the buffers can be
                # reused for the next batch.
                src_list.clear()
                tgt_list.clear()
                yield src, tgt
开发者ID:reppy4620,项目名称:Dialog,代码行数:18,代码来源:loader.py

示例3: build_train_sampler

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def build_train_sampler(
    data_source, train_sampler, batch_size=32, num_instances=4, **kwargs
):
    """Create the sampler used for training, selected by name.

    Args:
        data_source (list): contains tuples of (img_path(s), pid, camid).
        train_sampler (str): name of the sampler to build; must be listed in
            ``AVAI_SAMPLERS``.
        batch_size (int, optional): batch size, forwarded only to
            ``RandomIdentitySampler``. Default is 32.
        num_instances (int, optional): number of instances per identity in a
            batch, forwarded only to ``RandomIdentitySampler``. Default is 4.
        **kwargs: accepted but not used by this function.

    Returns:
        The sampler instance built for ``train_sampler``.
    """
    # Same check as a plain ``assert``: skipped when Python runs with -O.
    if __debug__ and train_sampler not in AVAI_SAMPLERS:
        raise AssertionError(
            'train_sampler must be one of {}, but got {}'.format(AVAI_SAMPLERS, train_sampler)
        )

    if train_sampler == 'RandomSampler':
        sampler = RandomSampler(data_source)

    elif train_sampler == 'SequentialSampler':
        sampler = SequentialSampler(data_source)

    elif train_sampler == 'RandomIdentitySampler':
        sampler = RandomIdentitySampler(data_source, batch_size, num_instances)

    return sampler
开发者ID:KaiyangZhou,项目名称:deep-person-reid,代码行数:27,代码来源:sampler.py

示例4: shuffle_dataset

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def shuffle_dataset(loader, cur_epoch):
    """
    Prepare the loader's sampler for shuffling in the given epoch.
    Args:
        loader (loader): data loader to perform shuffle.
        cur_epoch (int): number of the current epoch.
    """
    # With a ShortCycleBatchSampler the underlying sampler hangs off the
    # batch sampler; otherwise it is read from the loader directly.
    if isinstance(loader.batch_sampler, ShortCycleBatchSampler):
        sampler = loader.batch_sampler.sampler
    else:
        sampler = loader.sampler

    supported_types = (RandomSampler, DistributedSampler)
    assert isinstance(sampler, supported_types), "Sampler type '{}' not supported".format(
        type(sampler)
    )

    if not isinstance(sampler, DistributedSampler):
        # RandomSampler handles shuffling automatically, nothing to do.
        return
    # DistributedSampler shuffles data based on epoch.
    sampler.set_epoch(cur_epoch)
开发者ID:facebookresearch,项目名称:SlowFast,代码行数:21,代码来源:loader.py

示例5: Our_Dataloader

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def Our_Dataloader(dataset,batch_size,shuffle=True,num_workers=2,drop_last=True,max_iteration=100000000):
    """
    Build a DataLoader driven by a ``BatchSampler_Our`` batch sampler.

    According to the original author's note this behaves as a nearly
    infinite iterator (100 million iterations by default) that yields one
    batch per iteration; the actual stopping point is expected to be decided
    by the caller while consuming the data.

    :param dataset:         dataset to load from; its ``is_train`` attribute
                            is forwarded to ``BatchCollator``
    :param batch_size:      batch size
    :param shuffle:         use ``RandomSampler`` when true, otherwise
                            ``SequentialSampler``
    :param num_workers:     forwarded to ``DataLoader``
    :param drop_last:       forwarded to ``BatchSampler_Our``
    :param max_iteration:   total number of iterations, forwarded to
                            ``BatchSampler_Our`` (default 100 million)
    :return:                the constructed ``DataLoader``
    """
    # Random sampler when shuffling, sequential sampler otherwise.
    sampler_cls = RandomSampler if shuffle else SequentialSampler
    index_sampler = sampler_cls(dataset)

    batch_sampler = BatchSampler_Our(
        sampler=index_sampler,
        batch_size=batch_size,
        max_iteration=max_iteration,
        drop_last=drop_last,
    )
    collate = BatchCollator(is_train=dataset.is_train)
    return DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=collate,
    )
开发者ID:yatengLG,项目名称:SSD-Pytorch,代码行数:23,代码来源:Dataloader.py

示例6: make_data_loader

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def make_data_loader(args, corpus, train_time=False):
    """Wrap ``corpus`` in a ``MultiCorpusDataset`` and return a DataLoader for it.

    Args:
        args: options object; ``word_dict``, ``feature_dict``, ``para_mode``,
            ``batch_size`` and ``data_workers`` are read from it.
        corpus: corpus handed to ``data.MultiCorpusDataset``.
        train_time (bool): when true the data is sampled randomly, otherwise
            sequentially; also forwarded to the dataset and to
            ``vector.batchify``.

    Returns:
        torch.utils.data.DataLoader: loader over the constructed dataset.
    """
    dataset = data.MultiCorpusDataset(
        args, corpus, args.word_dict, args.feature_dict,
        single_answer=False, para_mode=args.para_mode, train_time=train_time,
    )

    if train_time:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)

    loader_options = {
        'batch_size': args.batch_size,
        'sampler': sampler,
        'num_workers': args.data_workers,
        'collate_fn': vector.batchify(args, args.para_mode, train_time=train_time),
        'pin_memory': True,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)
开发者ID:rajarshd,项目名称:Multi-Step-Reasoning,代码行数:24,代码来源:train_para_encoder.py

示例7: make_data_loader

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def make_data_loader(args, exs, train_time=False):
    """Wrap ``exs`` in a ``ReaderDataset`` and return a DataLoader for it.

    Args:
        args: options object; ``word_dict``, ``feature_dict``, ``batch_size``
            and ``test_batch_size`` are read from it.
        exs: examples handed to ``data.ReaderDataset``.
        train_time (bool): when true the data is sampled randomly with
            ``args.batch_size``; otherwise it is read sequentially with
            ``args.test_batch_size``.

    Returns:
        torch.utils.data.DataLoader: loader over the constructed dataset,
        created with ``num_workers=0``.
    """
    dataset = data.ReaderDataset(
        args, exs, args.word_dict, args.feature_dict,
        single_answer=False, train_time=train_time,
    )

    if train_time:
        sampler = RandomSampler(dataset)
        batch_size = args.batch_size
    else:
        sampler = SequentialSampler(dataset)
        batch_size = args.test_batch_size

    loader_options = {
        'batch_size': batch_size,
        'sampler': sampler,
        'num_workers': 0,
        'collate_fn': vector.batchify,
        'pin_memory': True,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)
开发者ID:rajarshd,项目名称:Multi-Step-Reasoning,代码行数:24,代码来源:train.py

示例8: __init__

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def __init__(self, origsilo, trainset, devset, testset):
        """Build train/dev/test loaders that reuse the settings of ``origsilo``.

        ``tensor_names``, ``processor`` and ``batch_size`` are copied from
        ``origsilo``.  The train loader samples randomly; the dev and test
        loaders read sequentially.

        Args:
            origsilo: object providing ``tensor_names``, ``processor`` and
                ``batch_size``.
            trainset: dataset for the "train" loader.
            devset: dataset for the "dev" loader.
            testset: dataset for the "test" loader.
        """
        self.tensor_names = origsilo.tensor_names
        self.data = {"train": trainset, "dev": devset, "test": testset}
        self.processor = origsilo.processor
        self.batch_size = origsilo.batch_size

        def build_loader(dataset, sampler):
            # All three loaders share the batch size and tensor names.
            return NamedDataLoader(
                dataset=dataset,
                sampler=sampler,
                batch_size=self.batch_size,
                tensor_names=self.tensor_names,
            )

        # A distributed sampler should not be necessary here: per the original
        # author, cross-validation makes no sense with huge data.
        sampler_train = RandomSampler(trainset)

        self.data_loader_train = build_loader(trainset, sampler_train)
        self.data_loader_dev = build_loader(devset, SequentialSampler(devset))
        self.data_loader_test = build_loader(testset, SequentialSampler(testset))

        self.loaders = {
            "train": self.data_loader_train,
            "dev": self.data_loader_dev,
            "test": self.data_loader_test,
        }
开发者ID:deepset-ai,项目名称:FARM,代码行数:34,代码来源:data_silo.py

示例9: __init__

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False):
        """Store the loader settings and resolve the sampler / batch sampler.

        When no ``batch_sampler`` is given, one is built with ``BatchSampler``
        from ``sampler``; when ``sampler`` is also missing it defaults to a
        ``RandomSampler`` (``shuffle`` true) or a ``SequentialSampler``.

        Raises:
            ValueError: if ``batch_sampler`` is combined with a ``batch_size``
                greater than 1, ``shuffle``, ``sampler`` or ``drop_last``, or
                if ``sampler`` is combined with ``shuffle``.
        """
        self.dataset, self.batch_size = dataset, batch_size
        self.num_workers, self.collate_fn = num_workers, collate_fn
        self.pin_memory, self.drop_last = pin_memory, drop_last

        # An explicit batch sampler already decides batching and ordering, so
        # the options that would do the same must be left at their defaults.
        if batch_sampler is not None and (
            batch_size > 1 or shuffle or sampler is not None or drop_last
        ):
            raise ValueError('batch_sampler is mutually exclusive with '
                             'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler, self.batch_sampler = sampler, batch_sampler
开发者ID:ignacio-rocco,项目名称:weakalign,代码行数:29,代码来源:dataloader.py

示例10: __init__

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def __init__(self, pairs, shuffle=False, batch_size=1, drop_last=False):
        """Store the batching options and build a sampler over ``pairs``.

        Args:
            pairs: data source handed to the sampler.
            shuffle (bool): use ``RandomSampler`` when true, otherwise
                ``SequentialSampler``.
            batch_size (int): stored as ``self.batch_size``.
            drop_last (bool): stored as ``self.drop_last``.
        """
        sampler_cls = RandomSampler if shuffle else SequentialSampler
        self.sampler = sampler_cls(pairs)
        self.batch_size, self.drop_last = batch_size, drop_last
开发者ID:shahsohil,项目名称:DCC,代码行数:9,代码来源:custom_data.py

示例11: __init__

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def __init__(self, data: Dataset, pad_id: int):
        """Initialise the parent with a random sampler over ``data``.

        The parent initialiser receives ``RandomSampler(data)``,
        ``Config.batch_size`` and ``True`` as positional arguments.

        Args:
            data: dataset to draw samples from.
            pad_id: stored as ``self.pad_id``.
        """
        shuffled_indices = RandomSampler(data)
        per_batch = Config.batch_size
        super().__init__(shuffled_indices, per_batch, True)
        # Sample counter, starts at zero.
        self.count = 0
        self.pad_id = pad_id
开发者ID:reppy4620,项目名称:Dialog,代码行数:6,代码来源:loader.py

示例12: get_data

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def get_data(split_id, data_dir, img_size, scale_size, batch_size,
             workers, train_list, val_list):
    """Build the training and validation data loaders.

    Training images go through a random resized crop to ``img_size`` and a
    random horizontal flip; validation images are resized to ``scale_size``
    and centre-cropped to ``img_size``.  Both are converted to tensors and
    normalised with the same mean/std.

    Args:
        split_id: accepted but not used by this function.
        data_dir: root directory handed to ``Preprocessor``.
        img_size: crop size.
        scale_size: resize target for validation images.
        batch_size: batch size of both loaders.
        workers: number of worker processes of both loaders.
        train_list: training samples; also the data source of the random
            sampler.
        val_list: validation samples.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    root = data_dir

    # RGB imagenet statistics (per the original comment).
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])

    # Training pipeline, with data augmentation.
    train_steps = [
        T.RandomResizedCrop(img_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),   # [0, 255] to [0.0, 1.0]
        normalizer,     # normalize each channel of the input
    ]
    train_transformer = T.Compose(train_steps)

    test_steps = [
        T.Resize(scale_size),
        T.CenterCrop(img_size),
        T.ToTensor(),
        normalizer,
    ]
    test_transformer = T.Compose(test_steps)

    train_set = Preprocessor(train_list, root=root, transform=train_transformer)
    train_sampler = RandomSampler(train_list)
    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=workers,
        sampler=train_sampler,
        pin_memory=True,
        drop_last=False,
    )

    val_set = Preprocessor(val_list, root=root, transform=test_transformer)
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        num_workers=workers,
        shuffle=False,
        pin_memory=True,
    )

    return train_loader, val_loader
开发者ID:aliyun,项目名称:alibabacloud-quantization-networks,代码行数:38,代码来源:main.py

示例13: get_data

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def get_data(split_id, data_dir, img_size, scale_size, batch_size,
             workers, train_list, val_list):
    """Build the training and validation data loaders.

    Training images are resized to ``scale_size``, randomly cropped to
    ``img_size`` and randomly flipped horizontally; validation images are
    resized to ``scale_size`` and centre-cropped to ``img_size``.  Both are
    converted to tensors and normalised with the same mean/std.

    Args:
        split_id: accepted but not used by this function.
        data_dir: root directory handed to ``Preprocessor``.
        img_size: crop size.
        scale_size: resize target applied before cropping.
        batch_size: batch size of both loaders.
        workers: number of worker processes of both loaders.
        train_list: training samples; also the data source of the random
            sampler.
        val_list: validation samples.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    root = data_dir

    # RGB imagenet statistics (per the original comment).
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])

    # Training pipeline, with data augmentation.
    train_steps = [
        T.Resize(scale_size),
        T.RandomCrop(img_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),   # [0, 255] to [0.0, 1.0]
        normalizer,     # normalize each channel of the input
    ]
    train_transformer = T.Compose(train_steps)

    test_steps = [
        T.Resize(scale_size),
        T.CenterCrop(img_size),
        T.ToTensor(),
        normalizer,
    ]
    test_transformer = T.Compose(test_steps)

    train_set = Preprocessor(train_list, root=root, transform=train_transformer)
    train_sampler = RandomSampler(train_list)
    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=workers,
        sampler=train_sampler,
        pin_memory=True,
        drop_last=False,
    )

    val_set = Preprocessor(val_list, root=root, transform=test_transformer)
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        num_workers=workers,
        shuffle=False,
        pin_memory=True,
    )

    return train_loader, val_loader
开发者ID:aliyun,项目名称:alibabacloud-quantization-networks,代码行数:39,代码来源:quan_weight_main.py

示例14: get_data

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def get_data(split_id, data_dir, img_size, scale_size, batch_size,
             workers, train_list, val_list):
    """Build the training and validation data loaders.

    Training images are resized to ``scale_size``, randomly cropped to
    ``img_size`` and randomly flipped horizontally; validation images are
    resized to ``scale_size`` and centre-cropped to ``img_size``.  Both are
    converted to tensors and normalised with the same mean/std.

    Args:
        split_id: accepted but not used by this function.
        data_dir: root directory handed to ``Preprocessor``.
        img_size: crop size.
        scale_size: resize target applied before cropping.
        batch_size: batch size of both loaders.
        workers: number of worker processes of both loaders.
        train_list: training samples; also the data source of the random
            sampler.
        val_list: validation samples.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    root = data_dir

    # RGB imagenet statistics (per the original comment).
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])

    # Training pipeline, with data augmentation.
    train_steps = [
        T.Resize(scale_size),
        T.RandomCrop(img_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),   # [0, 255] to [0.0, 1.0]
        normalizer,     # normalize each channel of the input
    ]
    train_transformer = T.Compose(train_steps)

    test_steps = [
        T.Resize(scale_size),
        T.CenterCrop(img_size),
        T.ToTensor(),
        normalizer,
    ]
    test_transformer = T.Compose(test_steps)

    train_set = Preprocessor(train_list, root=root, transform=train_transformer)
    train_sampler = RandomSampler(train_list)
    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=workers,
        sampler=train_sampler,
        pin_memory=True,
        drop_last=False,
    )

    val_set = Preprocessor(val_list, root=root, transform=test_transformer)
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        num_workers=workers,
        shuffle=False,
        pin_memory=True,
    )

    return train_loader, val_loader
开发者ID:aliyun,项目名称:alibabacloud-quantization-networks,代码行数:38,代码来源:quan_all_main.py

示例15: loader_setup

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import RandomSampler [as 别名]
def loader_setup(self, args):
        """Create the train, dev and display data loaders.

        The train and dev loaders each draw from a ``RandomSampler`` over
        their dataset; the display loader is created without a sampler.  All
        three share ``args.batch_size``, ``args.workers`` and
        ``args.pin_memory``.  The number of training batches is stored in
        ``self.nbatches``.

        Args:
            args: options object providing ``batch_size``, ``workers`` and
                ``pin_memory``.
        """
        logging.info("Creating samplers ...")
        samplers = {
            "train": RandomSampler(self.train_data),
            "dev": RandomSampler(self.dev_data),
        }

        logging.info("Creating data loaders ...")
        # Settings common to every loader built below.
        shared = dict(
            batch_size=args.batch_size,
            num_workers=args.workers,
            pin_memory=args.pin_memory,
        )
        self.train_loader = DataLoader(
            dataset=self.train_data, sampler=samplers["train"], **shared
        )
        self.dev_loader = DataLoader(
            dataset=self.dev_data, sampler=samplers["dev"], **shared
        )
        self.display_loader = DataLoader(
            dataset=self.display_data, drop_last=False, **shared
        )

        logging.debug("Determining batches ...")
        self.nbatches = len(self.train_loader)
        logging.info("Train Loader created, batches: {}".format(self.nbatches))
开发者ID:facebookresearch,项目名称:fastMRI,代码行数:35,代码来源:base_trainer.py


注:本文中的torch.utils.data.sampler.RandomSampler方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。