当前位置: 首页>>代码示例>>Python>>正文


Python data.Sampler方法代码示例

本文整理汇总了Python中torch.utils.data.Sampler方法的典型用法代码示例。如果您正苦于以下问题:Python data.Sampler方法的具体用法?Python data.Sampler怎么用?Python data.Sampler使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.utils.data的用法示例。


在下文中一共展示了data.Sampler方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __iter__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Sampler [as 别名]
def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        if isinstance(self.dataset, Sampler):
            orig_indices = list(iter(self.dataset))
            indices = [orig_indices[i] for i in indices]

        return iter(indices) 
开发者ID:cvlab-columbia,项目名称:oops,代码行数:24,代码来源:sampler.py

示例2: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Sampler [as 别名]
def __init__(self,
                 dataset: torch.utils.data.Dataset,
                 episodes_per_epoch: int = None,
                 n: int = None,
                 k: int = None,
                 q: int = None,
                 num_tasks: int = 1,
                 fixed_tasks: List[Iterable[int]] = None):
        """PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks.

        Each n-shot task contains a "support set" of `k` sets of `n` samples and a "query set" of `k` sets
        of `q` samples. The support set and the query set are all grouped into one Tensor such that the first n * k
        samples are from the support set while the remaining q * k samples are from the query set.

        The support and query sets are sampled such that they are disjoint i.e. do not contain overlapping samples.

        # Arguments
            dataset: Instance of torch.utils.data.Dataset from which to draw samples
            episodes_per_epoch: Arbitrary number of batches of n-shot tasks to generate in one epoch
            n_shot: int. Number of samples for each class in the n-shot classification tasks.
            k_way: int. Number of classes in the n-shot classification tasks.
            q_queries: int. Number query samples for each class in the n-shot classification tasks.
            num_tasks: Number of n-shot tasks to group into a single batch
            fixed_tasks: If this argument is specified this Sampler will always generate tasks from
                the specified classes
        """
        super(NShotTaskSampler, self).__init__(dataset)
        self.episodes_per_epoch = episodes_per_epoch
        self.dataset = dataset
        if num_tasks < 1:
            raise ValueError('num_tasks must be > 1.')

        self.num_tasks = num_tasks
        # TODO: Raise errors if initialise badly
        self.k = k
        self.n = n
        self.q = q
        self.fixed_tasks = fixed_tasks

        self.i_task = 0 
开发者ID:oscarknagg,项目名称:few-shot,代码行数:42,代码来源:core.py

示例3: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Sampler [as 别名]
def __init__(self, sampler, batch_size, max_iteration=100000000, drop_last=True):
        """
        数据加载,默认循环加载1亿次,几近无限迭代.
        每次迭代输出一个批次的数据.
        :param sampler:         采样器,传入 不同采样器 实现 不同的采样策略,    RandomSampler随机采样,SequentialSampler顺序采样
        :param batch_size:      批次大小
        :param max_iteration:   迭代次数
        :param drop_last:       是否弃掉最后的不够一批次的数据。True则弃掉;False保留,并返回,但是这一批次会小于指定批次大小。
        """
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(max_iteration, _int_classes) or isinstance(max_iteration, bool) or \
                max_iteration <= 0:
            raise ValueError("max_iter should be a positive integer value, "
                             "but got max_iter={}".format(max_iteration))

        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.max_iteration = max_iteration
        self.drop_last = drop_last 
开发者ID:yatengLG,项目名称:SSD-Pytorch,代码行数:31,代码来源:Dataloader.py

示例4: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Sampler [as 别名]
def __init__(self, sampler: Sampler):
        """
        Args:
            sampler (Sampler): @TODO: Docs. Contribution is welcome
        """
        self.sampler = sampler
        self.sampler_list = None 
开发者ID:catalyst-team,项目名称:catalyst,代码行数:9,代码来源:dataset.py


注:本文中的torch.utils.data.Sampler方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。