本文整理汇总了 Python 中 torch.utils.data.Sampler 类的典型用法代码示例。如果您正苦于以下问题：Python data.Sampler 的具体用法是什么？data.Sampler 怎么用？有哪些 data.Sampler 的使用例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块 torch.utils.data 的用法示例。
下文一共展示了 data.Sampler 类的 4 个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的 Python 代码示例。
示例1: __iter__
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Sampler [as 别名]
def __iter__(self):
    """Yield this replica's share of the dataset indices for the current epoch.

    Indices are shuffled with a generator seeded by ``self.epoch`` (when
    ``self.shuffle`` is set), padded with repeated indices up to
    ``self.total_size`` so the total divides evenly across replicas, and then
    strided by ``self.num_replicas`` starting at ``self.rank``. If
    ``self.dataset`` is itself a ``Sampler``, the selected positions are
    mapped back through the indices that sampler yields.

    Returns:
        An iterator over ``self.num_samples`` indices.
    """
    # Deterministic shuffle: the permutation depends only on the epoch seed.
    g = torch.Generator()
    g.manual_seed(self.epoch)
    if self.shuffle:
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
    else:
        indices = list(range(len(self.dataset)))

    # Add extra samples to make the total evenly divisible. Repeat the whole
    # list as often as needed: a single ``indices[:padding_size]`` is too
    # short when padding_size > len(indices), i.e. when there are more
    # replicas than samples.
    padding_size = self.total_size - len(indices)
    if padding_size > 0 and indices:
        repeats = -(-padding_size // len(indices))  # ceil division
        indices += (indices * repeats)[:padding_size]
    assert len(indices) == self.total_size

    # Subsample: this replica takes every num_replicas-th index from its rank.
    indices = indices[self.rank:self.total_size:self.num_replicas]
    assert len(indices) == self.num_samples

    if isinstance(self.dataset, Sampler):
        # NOTE(review): if the wrapped sampler is randomized, each replica
        # materializes its own order here, so shards may overlap — confirm.
        orig_indices = list(iter(self.dataset))
        indices = [orig_indices[i] for i in indices]
    return iter(indices)
示例2: __init__
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Sampler [as 别名]
def __init__(self,
             dataset: torch.utils.data.Dataset,
             episodes_per_epoch: int = None,
             n: int = None,
             k: int = None,
             q: int = None,
             num_tasks: int = 1,
             fixed_tasks: List[Iterable[int]] = None):
    """PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks.

    Each n-shot task contains a "support set" of `k` sets of `n` samples and a "query set" of `k` sets
    of `q` samples. The support set and the query set are all grouped into one Tensor such that the first n * k
    samples are from the support set while the remaining q * k samples are from the query set.

    The support and query sets are sampled such that they are disjoint i.e. do not contain overlapping samples.

    # Arguments
        dataset: Instance of torch.utils.data.Dataset from which to draw samples
        episodes_per_epoch: Arbitrary number of batches of n-shot tasks to generate in one epoch
        n: int. Number of samples for each class in the n-shot classification tasks.
        k: int. Number of classes in the n-shot classification tasks.
        q: int. Number of query samples for each class in the n-shot classification tasks.
        num_tasks: Number of n-shot tasks to group into a single batch. Must be at least 1.
        fixed_tasks: If this argument is specified this Sampler will always generate tasks from
            the specified classes

    # Raises
        ValueError: if `num_tasks` is less than 1.
    """
    super(NShotTaskSampler, self).__init__(dataset)
    self.episodes_per_epoch = episodes_per_epoch
    self.dataset = dataset
    # The check accepts num_tasks == 1, so the message says ">= 1" (it
    # previously claimed "> 1", contradicting the condition).
    if num_tasks < 1:
        raise ValueError('num_tasks must be >= 1.')
    self.num_tasks = num_tasks
    # TODO: Raise errors if initialise badly (n, k, q are not validated here)
    self.k = k
    self.n = n
    self.q = q
    self.fixed_tasks = fixed_tasks
    # Task counter; starts at 0 (how it advances is not visible in this block).
    self.i_task = 0
示例3: __init__
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Sampler [as 别名]
def __init__(self, sampler, batch_size, max_iteration=100000000, drop_last=True):
    """
    Data loading; by default loops for 100,000,000 iterations, i.e. nearly unbounded.
    Each iteration outputs one batch of data.

    :param sampler: the sampler to draw indices from; pass different samplers for
        different sampling strategies, e.g. RandomSampler for random sampling,
        SequentialSampler for sequential sampling. Must be a
        torch.utils.data.Sampler instance.
    :param batch_size: batch size; must be a positive int (bool is rejected).
    :param max_iteration: number of iterations; must be a positive int (bool is rejected).
    :param drop_last: whether to drop the final data that does not fill a whole batch.
        True drops it; False keeps and returns it, but that batch will be smaller
        than batch_size. Must be a bool.
    :raises ValueError: if any argument fails the checks above.
    """
    if not isinstance(sampler, Sampler):
        raise ValueError("sampler should be an instance of "
                         "torch.utils.data.Sampler, but got sampler={}"
                         .format(sampler))
    # Builtin int replaces the legacy `_int_classes` alias (which is `int` on
    # Python 3). bool is a subclass of int, so it is excluded explicitly.
    if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
            batch_size <= 0:
        raise ValueError("batch_size should be a positive integer value, "
                         "but got batch_size={}".format(batch_size))
    if not isinstance(max_iteration, int) or isinstance(max_iteration, bool) or \
            max_iteration <= 0:
        # Message names the actual parameter (it previously said "max_iter").
        raise ValueError("max_iteration should be a positive integer value, "
                         "but got max_iteration={}".format(max_iteration))
    if not isinstance(drop_last, bool):
        raise ValueError("drop_last should be a boolean value, but got "
                         "drop_last={}".format(drop_last))
    self.sampler = sampler
    self.batch_size = batch_size
    self.max_iteration = max_iteration
    self.drop_last = drop_last
示例4: __init__
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Sampler [as 别名]
def __init__(self, sampler: Sampler):
    """Store ``sampler`` and reset the cached index list.

    Args:
        sampler (Sampler): the sampler being wrapped. ``sampler_list`` starts
            out as ``None``; presumably it is populated from the sampler later
            by other methods of this class — confirm against the rest of the class.
    """
    self.sampler, self.sampler_list = sampler, None