本文整理汇总了Python中torch.utils.data.sampler.SequentialSampler类的典型用法代码示例。如果您正苦于以下问题：Python sampler.SequentialSampler类的具体用法？Python sampler.SequentialSampler怎么用？有没有sampler.SequentialSampler的使用示例？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torch.utils.data.sampler的用法示例。
下文共展示了sampler.SequentialSampler类的9个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将帮助系统推荐更好的Python代码示例。
示例1: test_respect_order
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import SequentialSampler [as 别名]
def test_respect_order(self):
    """GroupedBatchSampler must preserve the sampler's order inside each group.

    For every batch size the produced batches are compared against a
    hand-computed expectation. ``drop_uneven`` is disabled, so incomplete
    batches (e.g. ``[6, 9]`` for batch size 3) are kept.
    """
    group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
    sequential = SequentialSampler(list(range(10)))
    cases = (
        (1, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, [[0, 1, 3], [2, 4, 5], [6, 9], [7, 8]]),
        (4, [[0, 1, 3, 6], [2, 4, 5, 7], [8], [9]]),
    )
    for batch_size, want in cases:
        grouped = GroupedBatchSampler(sequential, group_ids, batch_size, False)
        self.assertEqual(list(grouped), want)
示例2: test_number_of_iters_and_elements
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import SequentialSampler [as 别名]
def test_number_of_iters_and_elements(self):
    """IterationBasedBatchSampler yields exactly ``num_iterations`` batches.

    The wrapped BatchSampler is expected to be cycled: batch ``i`` must equal
    batch ``i % len(batch_sampler)`` of a plain sequential pass over the data.
    Both the reported length and the number of batches actually produced by
    iteration are checked.
    """
    dataset = list(range(10))
    # SequentialSampler only reads len(dataset), so one instance can be reused.
    sampler = SequentialSampler(dataset)
    for batch_size in [2, 3, 4]:
        for num_iterations in [4, 10, 20]:
            for drop_last in [False, True]:
                batch_sampler = BatchSampler(
                    sampler, batch_size, drop_last=drop_last
                )
                iter_sampler = IterationBasedBatchSampler(
                    batch_sampler, num_iterations
                )
                self.assertEqual(len(iter_sampler), num_iterations)
                num_batches = 0
                for i, batch in enumerate(iter_sampler):
                    start = (i % len(batch_sampler)) * batch_size
                    end = min(start + batch_size, len(dataset))
                    self.assertEqual(batch, list(range(start, end)))
                    num_batches += 1
                # __len__ alone does not prove iteration produces that many
                # batches; an early stop would otherwise pass unnoticed.
                self.assertEqual(num_batches, num_iterations)
示例3: test_distributed_batch_sampler
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import SequentialSampler [as 别名]
def test_distributed_batch_sampler():
    """Each of 4 ranks receives a strided slice of every batch.

    Rank ``r`` gets indices ``r, r + 4, r + 8, ...`` from each of the two
    batches (``[0..9]`` and ``[10..14]``), and reports a length of 2.
    """
    batches = BatchSampler(SequentialSampler(list(range(15))), 10, False)
    expected_by_rank = [
        [[0, 4, 8], [10, 14]],
        [[1, 5, 9], [11]],
        [[2, 6], [12]],
        [[3, 7], [13]],
    ]
    for rank, expected in enumerate(expected_by_rank):
        per_rank = DistributedBatchSampler(batches, num_replicas=4, rank=rank)
        assert list(per_rank) == expected
        assert len(per_rank) == 2
示例4: get_dataloader
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import SequentialSampler [as 别名]
def get_dataloader(self, batch_size, type, num_workers, shuffle):
    """
    Build a DataLoader over the train or dev split.

    :param batch_size: number of examples per batch
    :param type: split name, 'train' or 'dev' (used as a key into ``self._data``)
    :param num_workers: number of DataLoader worker processes
    :param shuffle: if truthy, sample with ``SortedBatchSampler`` built from the
        example lengths; otherwise iterate the dataset in order with
        ``SequentialSampler``
    :return: the configured ``torch.utils.data.DataLoader``
    """
    split = self._data[type]
    dataset = CQA_Dataset(
        split['context'],
        split['question'],
        split['answer_range'],
        self.meta_data,
        self.global_config['preprocess'],
    )
    # NOTE(review): SortedBatchSampler is passed as a per-index `sampler`,
    # so it presumably yields single indices grouped by length -- confirm.
    sampler = (
        SortedBatchSampler(dataset.get_lengths(), batch_size)
        if shuffle
        else SequentialSampler(dataset)
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=self.collect_fun,
        num_workers=num_workers,
    )
示例5: get_data_loader
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import SequentialSampler [as 别名]
def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42):
    """ Create a DataLoader over (X_in, y_in), optionally with the extended
    batch sampler that enables larger epochs on small datasets and upsampling.

    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.
        batch_size: Batch size.
        extended_batch_sampler: If true, batches come from
            ``DeepMojiBatchSampler``; otherwise the indices of ``y_in`` are
            walked in order in batches of ``batch_size``, keeping the final
            partial batch.
        epoch_size: Number of samples in an epoch (extended sampler only).
        upsample: Whether upsampling should be done (extended sampler only).
            This flag should only be set on binary class problems.
        seed: Seed forwarded to ``DeepMojiBatchSampler`` (extended sampler only).

    # Returns:
        DataLoader that loads data in the main process (``num_workers=0``).
    """
    dataset = DeepMojiDataset(X_in, y_in)
    if not extended_batch_sampler:
        in_order = SequentialSampler(y_in)
        batch_sampler = BatchSampler(in_order, batch_size, drop_last=False)
    else:
        batch_sampler = DeepMojiBatchSampler(
            y_in,
            batch_size,
            epoch_size=epoch_size,
            upsample=upsample,
            seed=seed,
        )
    return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0)
示例6: build_train_sampler
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import SequentialSampler [as 别名]
def build_train_sampler(
    data_source, train_sampler, batch_size=32, num_instances=4, **kwargs
):
    """Builds a training sampler.

    Args:
        data_source (list): contains tuples of (img_path(s), pid, camid).
        train_sampler (str): sampler name; must be in ``AVAI_SAMPLERS``.
            This function constructs ``RandomIdentitySampler``,
            ``SequentialSampler`` or ``RandomSampler``.
        batch_size (int, optional): batch size. Default is 32.
        num_instances (int, optional): number of instances per identity in a
            batch (when using ``RandomIdentitySampler``). Default is 4.
        **kwargs: accepted and ignored.

    Returns:
        The sampler instance selected by ``train_sampler``.

    Raises:
        AssertionError: if ``train_sampler`` is not in ``AVAI_SAMPLERS``.
        ValueError: if ``train_sampler`` has no construction branch here
            (e.g. when asserts are disabled with ``python -O``).
    """
    assert train_sampler in AVAI_SAMPLERS, \
        'train_sampler must be one of {}, but got {}'.format(AVAI_SAMPLERS, train_sampler)

    if train_sampler == 'RandomIdentitySampler':
        sampler = RandomIdentitySampler(data_source, batch_size, num_instances)
    elif train_sampler == 'SequentialSampler':
        sampler = SequentialSampler(data_source)
    elif train_sampler == 'RandomSampler':
        sampler = RandomSampler(data_source)
    else:
        # The assert above is stripped under -O, and AVAI_SAMPLERS may list a
        # name with no branch here; either way `sampler` would be unbound at
        # the return and fail with an opaque UnboundLocalError.
        raise ValueError(
            'Unsupported train_sampler: {!r}; expected one of '
            "'RandomIdentitySampler', 'SequentialSampler', 'RandomSampler'".format(train_sampler)
        )
    return sampler
示例7: Our_Dataloader
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import SequentialSampler [as 别名]
def Our_Dataloader(dataset, batch_size, shuffle=True, num_workers=2, drop_last=True, max_iteration=100000000):
    """
    Build a near-endless DataLoader driven by ``BatchSampler_Our``.

    Each iteration yields one batch of data. ``max_iteration`` (default
    100,000,000) caps the total number of iterations; the original author
    recommends deciding the real number of iterations at the call site while
    pulling batches, since that is more flexible.

    :param dataset: the dataset to load; must expose an ``is_train``
        attribute, which is forwarded to ``BatchCollator``
    :param batch_size: number of samples per batch
    :param shuffle: if true, draw indices with ``RandomSampler``; otherwise
        in order with ``SequentialSampler``
    :param num_workers: number of DataLoader worker processes
    :param drop_last: forwarded to ``BatchSampler_Our``
    :param max_iteration: total number of iterations (default 100 million)
    :return: a ``DataLoader`` using the custom batch sampler and collator
    """
    index_sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    batch_sampler = BatchSampler_Our(
        sampler=index_sampler,
        batch_size=batch_size,
        max_iteration=max_iteration,
        drop_last=drop_last,
    )
    collator = BatchCollator(is_train=dataset.is_train)
    return DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=collator,
    )
示例8: test_respect_order_simple
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import SequentialSampler [as 别名]
def test_respect_order_simple(self):
    """Flattening the batches must reproduce the dataset order exactly.

    Group ids are contiguous runs of ten (0-9 -> 0, 10-19 -> 1, ...), so for
    every batch size tried the concatenated batches equal ``0..39`` in order.
    """
    dataset = list(range(40))
    group_ids = [value // 10 for value in dataset]
    sequential = SequentialSampler(dataset)
    for batch_size in (1, 3, 5, 6):
        batches = list(GroupedBatchSampler(sequential, group_ids, batch_size, False))
        flattened = [index for batch in batches for index in batch]
        self.assertEqual(flattened, dataset)
示例9: test_respect_order_drop_uneven
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import SequentialSampler [as 别名]
def test_respect_order_drop_uneven(self):
    """With ``drop_uneven`` enabled, only the full batches of size 3 remain.

    The incomplete groups that the non-dropping variant would emit
    (``[6, 9]`` and ``[7, 8]``) must be discarded.
    """
    sequential = SequentialSampler(list(range(10)))
    group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
    grouped = GroupedBatchSampler(sequential, group_ids, 3, True)
    self.assertEqual(list(grouped), [[0, 1, 3], [2, 4, 5]])