

Python sampler.SequentialSampler Method Code Examples

This article collects typical usage examples of the Python method torch.utils.data.sampler.SequentialSampler, gathered from open-source projects. If you are unsure how sampler.SequentialSampler is used in practice, the curated examples below may help. You can also explore other usage examples from the torch.utils.data.sampler module.


The following presents 9 code examples of the sampler.SequentialSampler method, sorted by popularity by default.
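
Before diving into the project examples, a quick illustration of what SequentialSampler itself does: it yields the indices 0 through len(dataset) - 1 in order (the indices, not the elements), and its length equals the dataset's length:

from torch.utils.data.sampler import SequentialSampler

data = ['a', 'b', 'c', 'd']
sampler = SequentialSampler(data)   # yields indices in order, not the elements
print(list(sampler))                # [0, 1, 2, 3]
print(len(sampler))                 # 4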

Example 1: test_respect_order

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import SequentialSampler [as alias]
def test_respect_order(self):
    drop_uneven = False
    dataset = [i for i in range(10)]
    group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
    sampler = SequentialSampler(dataset)

    expected = [
        [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]],
        [[0, 1, 3], [2, 4, 5], [6, 9], [7, 8]],
        [[0, 1, 3, 6], [2, 4, 5, 7], [8], [9]],
    ]

    for idx, batch_size in enumerate([1, 3, 4]):
        batch_sampler = GroupedBatchSampler(
            sampler, group_ids, batch_size, drop_uneven
        )
        result = list(batch_sampler)
        self.assertEqual(result, expected[idx])
Developer: Res2Net, Project: Res2Net-maskrcnn, Lines: 20, Source: test_data_samplers.py

Example 2: test_number_of_iters_and_elements

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import SequentialSampler [as alias]
def test_number_of_iters_and_elements(self):
    for batch_size in [2, 3, 4]:
        for num_iterations in [4, 10, 20]:
            for drop_last in [False, True]:
                dataset = [i for i in range(10)]
                sampler = SequentialSampler(dataset)
                batch_sampler = BatchSampler(
                    sampler, batch_size, drop_last=drop_last
                )

                iter_sampler = IterationBasedBatchSampler(
                    batch_sampler, num_iterations
                )
                assert len(iter_sampler) == num_iterations
                for i, batch in enumerate(iter_sampler):
                    start = (i % len(batch_sampler)) * batch_size
                    end = min(start + batch_size, len(dataset))
                    expected = [x for x in range(start, end)]
                    self.assertEqual(batch, expected)
Developer: Res2Net, Project: Res2Net-maskrcnn, Lines: 21, Source: test_data_samplers.py
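
IterationBasedBatchSampler is project-specific (it ships with maskrcnn-benchmark-style codebases, not with torch itself). The behavior this test asserts, namely cycling through the wrapped BatchSampler until a fixed number of batches has been produced, can be sketched with the standard library; iterate_batches below is a hypothetical helper, not the project's implementation:

import itertools
from torch.utils.data.sampler import BatchSampler, SequentialSampler

def iterate_batches(batch_sampler, num_iterations):
    # Repeat the finite batch sampler and stop after num_iterations batches.
    return list(itertools.islice(itertools.cycle(list(batch_sampler)), num_iterations))

batch_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
print(iterate_batches(batch_sampler, 5))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9], [0, 1, 2, 3], [4, 5, 6, 7]]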

Example 3: test_distributed_batch_sampler

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import SequentialSampler [as alias]
def test_distributed_batch_sampler():
    sampler = SequentialSampler(list(range(15)))
    batch_sampler = BatchSampler(sampler, 10, False)

    distributed_sampler = DistributedBatchSampler(batch_sampler, num_replicas=4, rank=0)
    assert list(distributed_sampler) == [[0, 4, 8], [10, 14]]
    assert len(distributed_sampler) == 2

    distributed_sampler = DistributedBatchSampler(batch_sampler, num_replicas=4, rank=1)
    assert list(distributed_sampler) == [[1, 5, 9], [11]]
    assert len(distributed_sampler) == 2

    distributed_sampler = DistributedBatchSampler(batch_sampler, num_replicas=4, rank=2)
    assert list(distributed_sampler) == [[2, 6], [12]]
    assert len(distributed_sampler) == 2

    distributed_sampler = DistributedBatchSampler(batch_sampler, num_replicas=4, rank=3)
    assert list(distributed_sampler) == [[3, 7], [13]]
    assert len(distributed_sampler) == 2 
Developer: PetrochukM, Project: PyTorch-NLP, Lines: 21, Source: test_distributed_batch_sampler.py
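
DistributedBatchSampler here comes from PyTorch-NLP. The asserted outputs are consistent with each rank taking a strided slice of every batch; shard_batch below is a hypothetical plain-Python illustration of that slicing, an assumption about the behavior rather than the library's actual code:

def shard_batch(batch, num_replicas, rank):
    # Each replica keeps one element out of every num_replicas, offset by its rank.
    return batch[rank::num_replicas]

print(shard_batch(list(range(10)), 4, 0))      # [0, 4, 8]
print(shard_batch(list(range(10)), 4, 1))      # [1, 5, 9]
print(shard_batch(list(range(10, 15)), 4, 2))  # [12]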

Example 4: get_dataloader

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import SequentialSampler [as alias]
def get_dataloader(self, batch_size, type, num_workers, shuffle):
    """
    Get a dataloader on the train or dev dataset.
    :param batch_size: batch size
    :param type: 'train' or 'dev'
    :param num_workers: number of dataloader workers
    :param shuffle: if True, use length-sorted batches; otherwise sample sequentially
    :return: torch.utils.data.DataLoader
    """
    data = self._data[type]
    dataset = CQA_Dataset(data['context'],
                          data['question'],
                          data['answer_range'],
                          self.meta_data,
                          self.global_config['preprocess'])
    if shuffle:
        sampler = SortedBatchSampler(dataset.get_lengths(), batch_size)
    else:
        sampler = SequentialSampler(dataset)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             sampler=sampler,
                                             collate_fn=self.collect_fun,
                                             num_workers=num_workers)
    return dataloader
Developer: laddie132, Project: Match-LSTM, Lines: 25, Source: squad_dataset.py

Example 5: get_data_loader

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import SequentialSampler [as alias]
def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42):
    """ Returns a dataloader that enables larger epochs on small datasets and
        has upsampling functionality.

    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.
        batch_size: Batch size.
        extended_batch_sampler: Whether to use the extended DeepMojiBatchSampler;
            if False, a plain BatchSampler over a SequentialSampler is used.
        epoch_size: Number of samples in an epoch.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems.
        seed: Random seed passed to the batch sampler.

    # Returns:
        DataLoader.
    """
    dataset = DeepMojiDataset(X_in, y_in)

    if extended_batch_sampler:
        batch_sampler = DeepMojiBatchSampler(y_in, batch_size, epoch_size=epoch_size, upsample=upsample, seed=seed)
    else:
        batch_sampler = BatchSampler(SequentialSampler(y_in), batch_size, drop_last=False)

    return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0) 
Developer: natashamjaques, Project: neural_chat, Lines: 25, Source: finetuning.py
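
The fallback branch uses only torch built-ins. With drop_last=False the final short batch is kept, which is easy to verify directly:

from torch.utils.data.sampler import BatchSampler, SequentialSampler

y_in = list(range(10))
print(list(BatchSampler(SequentialSampler(y_in), batch_size=4, drop_last=False)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]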

Example 6: build_train_sampler

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import SequentialSampler [as alias]
def build_train_sampler(
    data_source, train_sampler, batch_size=32, num_instances=4, **kwargs
):
    """Builds a training sampler.

    Args:
        data_source (list): contains tuples of (img_path(s), pid, camid).
        train_sampler (str): sampler name (default: ``RandomSampler``).
        batch_size (int, optional): batch size. Default is 32.
        num_instances (int, optional): number of instances per identity in a
            batch (when using ``RandomIdentitySampler``). Default is 4.
    """
    assert train_sampler in AVAI_SAMPLERS, \
        'train_sampler must be one of {}, but got {}'.format(AVAI_SAMPLERS, train_sampler)

    if train_sampler == 'RandomIdentitySampler':
        sampler = RandomIdentitySampler(data_source, batch_size, num_instances)

    elif train_sampler == 'SequentialSampler':
        sampler = SequentialSampler(data_source)

    elif train_sampler == 'RandomSampler':
        sampler = RandomSampler(data_source)

    return sampler 
Developer: KaiyangZhou, Project: deep-person-reid, Lines: 27, Source: sampler.py
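
A hypothetical call, assuming 'SequentialSampler' appears in AVAI_SAMPLERS and using made-up (img_path, pid, camid) tuples:

data_source = [('img_0001.jpg', 0, 0), ('img_0002.jpg', 0, 1), ('img_0003.jpg', 1, 0)]
sampler = build_train_sampler(data_source, 'SequentialSampler')
print(list(sampler))  # [0, 1, 2]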

Example 7: Our_Dataloader

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import SequentialSampler [as alias]
def Our_Dataloader(dataset, batch_size, shuffle=True, num_workers=2, drop_last=True, max_iteration=100000000):
    """
    A near-infinite iterator (100 million iterations by default) that yields one batch per iteration.
    :param dataset:         the dataset
    :param batch_size:      batch size
    :param max_iteration:   total number of iterations, 100 million by default; checking the
                            iteration count when fetching data gives more flexibility
    :param shuffle:
    :param num_workers:
    :param drop_last:
    :return:
    """
    if shuffle:
        sampler = RandomSampler(dataset)        # random sampler
    else:
        sampler = SequentialSampler(dataset)    # sequential sampler
    batch_sampler = BatchSampler_Our(sampler=sampler,
                                     batch_size=batch_size,
                                     max_iteration=max_iteration,
                                     drop_last=drop_last)
    loader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, num_workers=num_workers,
                        collate_fn=BatchCollator(is_train=dataset.is_train))
    return loader
Developer: yatengLG, Project: SSD-Pytorch, Lines: 23, Source: Dataloader.py
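
BatchSampler_Our itself is not shown in this snippet. Below is a minimal sketch of what such a max-iteration batch sampler might look like; the class name MaxIterBatchSampler and all details are assumptions, not the SSD-Pytorch implementation:

from torch.utils.data.sampler import BatchSampler

class MaxIterBatchSampler(BatchSampler):
    # Sketch: re-iterate the wrapped sampler until max_iteration batches are yielded.
    def __init__(self, sampler, batch_size, max_iteration, drop_last):
        super().__init__(sampler, batch_size, drop_last)
        self.max_iteration = max_iteration

    def __iter__(self):
        iteration = 0
        while iteration < self.max_iteration:
            for batch in super().__iter__():
                yield batch
                iteration += 1
                if iteration >= self.max_iteration:
                    return

    def __len__(self):
        return self.max_iteration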

Example 8: test_respect_order_simple

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import SequentialSampler [as alias]
# Also requires: import itertools
def test_respect_order_simple(self):
    drop_uneven = False
    dataset = [i for i in range(40)]
    group_ids = [i // 10 for i in dataset]
    sampler = SequentialSampler(dataset)
    for batch_size in [1, 3, 5, 6]:
        batch_sampler = GroupedBatchSampler(
            sampler, group_ids, batch_size, drop_uneven
        )
        result = list(batch_sampler)
        merged_result = list(itertools.chain.from_iterable(result))
        self.assertEqual(merged_result, dataset)
Developer: Res2Net, Project: Res2Net-maskrcnn, Lines: 14, Source: test_data_samplers.py

Example 9: test_respect_order_drop_uneven

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import SequentialSampler [as alias]
def test_respect_order_drop_uneven(self):
    batch_size = 3
    drop_uneven = True
    dataset = [i for i in range(10)]
    group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
    sampler = SequentialSampler(dataset)
    batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)

    result = list(batch_sampler)

    expected = [[0, 1, 3], [2, 4, 5]]
    self.assertEqual(result, expected)
Developer: Res2Net, Project: Res2Net-maskrcnn, Lines: 14, Source: test_data_samplers.py
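
Why the expected result is [[0, 1, 3], [2, 4, 5]]: iterating sequentially, group 0 collects indices [0, 1, 3, 6, 9] and group 1 collects [2, 4, 5, 7, 8]; with batch_size=3 and drop_uneven=True, only the first full batch of each group survives, and the uneven remainders [6, 9] and [7, 8] are dropped. The first full batches can be checked with plain Python:

group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
groups = {}
for idx, gid in enumerate(group_ids):
    groups.setdefault(gid, []).append(idx)
print([indices[:3] for indices in groups.values()])  # [[0, 1, 3], [2, 4, 5]]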


Note: The torch.utils.data.sampler.SequentialSampler examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their respective developers; copyright of the source code belongs to the original authors. Please follow each project's license when distributing or using the code, and do not reproduce this article without permission.