This article collects typical usage examples of the torch.utils.data.Sampler class in Python. If you want to see how data.Sampler is used in practice, the selected code examples below may help. For more context, see other usage examples from the torch.utils.data module.
Four code examples of data.Sampler are shown below.
Example 1: __iter__
# Required import: from torch.utils import data [as alias]
# or: from torch.utils.data import Sampler [as alias]
# Also needed: import torch
def __iter__(self):
    # deterministically shuffle based on epoch
    g = torch.Generator()
    g.manual_seed(self.epoch)
    if self.shuffle:
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
    else:
        indices = list(range(len(self.dataset)))
    # add extra samples to make it evenly divisible
    indices += indices[:(self.total_size - len(indices))]
    assert len(indices) == self.total_size
    # subsample: this rank keeps every num_replicas-th index, starting at self.rank
    indices = indices[self.rank:self.total_size:self.num_replicas]
    assert len(indices) == self.num_samples
    # when wrapping another Sampler, map positions back to that sampler's indices
    if isinstance(self.dataset, Sampler):
        orig_indices = list(iter(self.dataset))
        indices = [orig_indices[i] for i in indices]
    return iter(indices)
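The padding and strided subsampling in Example 1 can be checked without PyTorch or multiple processes. The sketch below is a hypothetical standalone function, partition_indices, that follows the same steps. It assumes num_samples = ceil(len(dataset) / num_replicas) and total_size = num_samples * num_replicas, as in PyTorch's DistributedSampler, and it swaps torch.Generator/torch.randperm for Python's random.Random seeded with the epoch. One deliberate difference: it pads by repeating the list as often as needed, whereas the single slice in Example 1 adds at most len(indices) items and therefore trips its assertion when the dataset is smaller than the padding it needs (e.g. 1 sample across 4 replicas).

```python
import math
import random
from typing import List


def partition_indices(dataset_len: int, num_replicas: int, rank: int,
                      epoch: int, shuffle: bool = True) -> List[int]:
    """Hypothetical standalone version of the index logic in Example 1."""
    # every rank gets the same count, rounded up
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    if shuffle:
        # same seed on every rank -> same permutation on every rank
        random.Random(epoch).shuffle(indices)
    # pad by repeating the list so its length divides evenly by num_replicas
    padding = total_size - len(indices)
    if padding > 0:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    assert len(indices) == total_size
    # rank r keeps positions r, r + num_replicas, r + 2 * num_replicas, ...
    shard = indices[rank:total_size:num_replicas]
    assert len(shard) == num_samples
    return shard
```

With 10 samples and 3 replicas, total_size is 12, so positions 10 and 11 reuse the first two indices of the (possibly shuffled) list, and those two samples are seen twice per epoch across the ranks. Because every rank seeds with the same epoch, all ranks agree on the order, which is what keeps the strided shards disjoint apart from the padding.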
Example 2: __init__
# Required import: from torch.utils import data [as alias]
# or: from torch.utils.data import Sampler [as alias]
# Also needed: import torch; from typing import Iterable, List
def __init__(self,
             dataset: torch.utils.data.Dataset,
             episodes_per_epoch: int = None,
             n: int = None,
             k: int = None,
             q: int = None,
             num_tasks: int = 1,
             fixed_tasks: List[Iterable[int]] = None):
    """PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks.

    Each n-shot task contains a "support set" of `k` sets of `n` samples and a "query set" of `k` sets
    of `q` samples. The support set and the query set are grouped into one Tensor such that the first n * k
    samples are from the support set while the remaining q * k samples are from the query set.

    The support and query sets are sampled such that they are disjoint, i.e. they do not contain overlapping samples.

    # Arguments
        dataset: Instance of torch.utils.data.Dataset from which to draw samples
        episodes_per_epoch: Arbitrary number of batches of n-shot tasks to generate in one epoch
        n: int. Number of samples for each class in the n-shot classification tasks.
        k: int. Number of classes in the n-shot classification tasks.
        q: int. Number of query samples for each class in the n-shot classification tasks.
        num_tasks: Number of n-shot tasks to group into a single batch
        fixed_tasks: If this argument is specified, this Sampler will always generate tasks from
            the specified classes
    """
    super(NShotTaskSampler, self).__init__(dataset)
    self.episodes_per_epoch = episodes_per_epoch
    self.dataset = dataset
    if num_tasks < 1:
        raise ValueError('num_tasks must be >= 1.')
    self.num_tasks = num_tasks
    # TODO: Raise errors if initialised badly
    self.k = k
    self.n = n
    self.q = q
    self.fixed_tasks = fixed_tasks
    self.i_task = 0
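The docstring in Example 2 fixes the batch layout (n * k support samples first, then q * k query samples, with no overlap), but the snippet does not show how the indices are drawn. Below is one minimal way to build a single task's index list, assuming the dataset's labels are available as a plain sequence of class ids. sample_episode and its labels argument are hypothetical and not part of the original sampler.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def sample_episode(labels: Sequence[int], n: int, k: int, q: int,
                   rng: random.Random) -> List[int]:
    """Hypothetical helper: index list for one n-shot, k-way, q-query task."""
    # group dataset positions by class id
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # a class is usable only if it can fill both its support and its query slots
    eligible = sorted(c for c, idxs in by_class.items() if len(idxs) >= n + q)
    if len(eligible) < k:
        raise ValueError(f"need {k} classes with at least {n + q} samples, "
                         f"found {len(eligible)}")
    support: List[int] = []
    query: List[int] = []
    for c in rng.sample(eligible, k):
        # one draw of n + q distinct samples keeps support and query disjoint
        picked = rng.sample(by_class[c], n + q)
        support.extend(picked[:n])
        query.extend(picked[n:])
    # first n * k entries are the support set, the remaining q * k the query set
    return support + query
```

Drawing all n + q samples of a class in a single rng.sample call is what guarantees the disjointness the docstring promises. A sampler built along these lines would yield such a list (or num_tasks of them concatenated) from __iter__, episodes_per_epoch times per epoch.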
Example 3: __init__
# Required import: from torch.utils import data [as alias]
# or: from torch.utils.data import Sampler [as alias]
def __init__(self, sampler, batch_size, max_iteration=100000000, drop_last=True):
    """
    Data loading that, by default, cycles for 100 million iterations, i.e. practically without end.
    Each iteration yields one batch of data.
    :param sampler: the sampler; pass a different sampler for a different sampling strategy,
        e.g. RandomSampler for random sampling or SequentialSampler for sequential sampling
    :param batch_size: batch size
    :param max_iteration: number of iterations (one batch per iteration)
    :param drop_last: whether to drop the final batch if it has fewer than batch_size samples. True drops it;
        False keeps and returns it, in which case that batch is smaller than batch_size.
    """
    if not isinstance(sampler, Sampler):
        raise ValueError("sampler should be an instance of "
                         "torch.utils.data.Sampler, but got sampler={}"
                         .format(sampler))
    # plain int replaces the int_classes alias from torch._six, which newer PyTorch releases no longer ship
    if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
            batch_size <= 0:
        raise ValueError("batch_size should be a positive integer value, "
                         "but got batch_size={}".format(batch_size))
    if not isinstance(max_iteration, int) or isinstance(max_iteration, bool) or \
            max_iteration <= 0:
        raise ValueError("max_iteration should be a positive integer value, "
                         "but got max_iteration={}".format(max_iteration))
    if not isinstance(drop_last, bool):
        raise ValueError("drop_last should be a boolean value, but got "
                         "drop_last={}".format(drop_last))
    self.sampler = sampler
    self.batch_size = batch_size
    self.max_iteration = max_iteration
    self.drop_last = drop_last
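Example 3 only shows the constructor. A plausible __iter__ for such a sampler re-reads the wrapped sampler pass after pass until max_iteration batches have been produced; the generator below sketches that under this assumption. iterate_batches is a hypothetical name, and it accepts any re-iterable sequence of indices, such as a range or a torch Sampler. It also raises an error instead of looping forever when a full pass yields no batch, which would otherwise happen with drop_last=True and a sampler shorter than batch_size.

```python
from typing import Iterable, Iterator, List


def iterate_batches(sampler: Iterable[int], batch_size: int,
                    max_iteration: int,
                    drop_last: bool = True) -> Iterator[List[int]]:
    """Yield up to max_iteration batches, re-reading the sampler after each pass."""
    iteration = 0
    while iteration < max_iteration:
        produced_batch = False  # did this pass over the sampler yield anything?
        batch: List[int] = []
        for idx in sampler:  # a fresh pass over the sampler
            batch.append(idx)
            if len(batch) == batch_size:
                yield batch
                produced_batch = True
                iteration += 1
                if iteration >= max_iteration:
                    return
                batch = []
        # leftover indices form a short batch only when drop_last is False
        if batch and not drop_last:
            yield batch
            produced_batch = True
            iteration += 1
        if not produced_batch:
            raise ValueError("sampler cannot fill a single batch; "
                             "iteration would never terminate")
```

With range(5), batch_size=2 and drop_last=True, the leftover index 4 is discarded at the end of each pass, so four iterations produce [0, 1], [2, 3], [0, 1], [2, 3]. With drop_last=False, the short batch [4] counts as one of the iterations.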
Example 4: __init__
# Required import: from torch.utils import data [as alias]
# or: from torch.utils.data import Sampler [as alias]
def __init__(self, sampler: Sampler):
    """
    Args:
        sampler (Sampler): @TODO: Docs. Contribution is welcome
    """
    self.sampler = sampler
    self.sampler_list = None
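Example 4 stores the sampler together with an empty sampler_list, which suggests that the sampler's output is cached as a list and then read by position. Under that assumption (SamplerAsList is a hypothetical name, and the original class's other methods are not shown), the wrapper could look like this. It performs the same position-to-index mapping that the last branch of Example 1 does with orig_indices when self.dataset is itself a Sampler.

```python
from typing import Collection, List, Optional


class SamplerAsList:
    """Hypothetical wrapper that exposes a sampler's output by position."""

    def __init__(self, sampler: Collection[int]):
        self.sampler = sampler
        self.sampler_list: Optional[List[int]] = None  # filled on first access

    def __getitem__(self, index: int) -> int:
        if self.sampler_list is None:
            # materialize one full pass of the sampler and cache it
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        return len(self.sampler)
```

Because the list is built only once, a randomized sampler's order is frozen after the first access; resetting sampler_list to None forces a fresh pass.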