當前位置: 首頁>>代碼示例>>Python>>正文


Python data.Sampler方法代碼示例

本文整理匯總了Python中torch.utils.data.Sampler方法的典型用法代碼示例。如果您正苦於以下問題:Python data.Sampler方法的具體用法?Python data.Sampler怎麽用?Python data.Sampler使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.utils.data的用法示例。


在下文中一共展示了data.Sampler方法的4個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: __iter__

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import Sampler [as 別名]
def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        if isinstance(self.dataset, Sampler):
            orig_indices = list(iter(self.dataset))
            indices = [orig_indices[i] for i in indices]

        return iter(indices) 
開發者ID:cvlab-columbia,項目名稱:oops,代碼行數:24,代碼來源:sampler.py

示例2: __init__

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import Sampler [as 別名]
def __init__(self,
                 dataset: torch.utils.data.Dataset,
                 episodes_per_epoch: int = None,
                 n: int = None,
                 k: int = None,
                 q: int = None,
                 num_tasks: int = 1,
                 fixed_tasks: List[Iterable[int]] = None):
        """PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks.

        Each n-shot task contains a "support set" of `k` sets of `n` samples and a "query set" of `k` sets
        of `q` samples. The support set and the query set are all grouped into one Tensor such that the first n * k
        samples are from the support set while the remaining q * k samples are from the query set.

        The support and query sets are sampled such that they are disjoint i.e. do not contain overlapping samples.

        # Arguments
            dataset: Instance of torch.utils.data.Dataset from which to draw samples
            episodes_per_epoch: Arbitrary number of batches of n-shot tasks to generate in one epoch
            n_shot: int. Number of samples for each class in the n-shot classification tasks.
            k_way: int. Number of classes in the n-shot classification tasks.
            q_queries: int. Number query samples for each class in the n-shot classification tasks.
            num_tasks: Number of n-shot tasks to group into a single batch
            fixed_tasks: If this argument is specified this Sampler will always generate tasks from
                the specified classes
        """
        super(NShotTaskSampler, self).__init__(dataset)
        self.episodes_per_epoch = episodes_per_epoch
        self.dataset = dataset
        if num_tasks < 1:
            raise ValueError('num_tasks must be > 1.')

        self.num_tasks = num_tasks
        # TODO: Raise errors if initialise badly
        self.k = k
        self.n = n
        self.q = q
        self.fixed_tasks = fixed_tasks

        self.i_task = 0 
開發者ID:oscarknagg,項目名稱:few-shot,代碼行數:42,代碼來源:core.py

示例3: __init__

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import Sampler [as 別名]
def __init__(self, sampler, batch_size, max_iteration=100000000, drop_last=True):
        """
        數據加載,默認循環加載1億次,幾近無限迭代.
        每次迭代輸出一個批次的數據.
        :param sampler:         采樣器,傳入 不同采樣器 實現 不同的采樣策略,    RandomSampler隨機采樣,SequentialSampler順序采樣
        :param batch_size:      批次大小
        :param max_iteration:   迭代次數
        :param drop_last:       是否棄掉最後的不夠一批次的數據。True則棄掉;False保留,並返回,但是這一批次會小於指定批次大小。
        """
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(max_iteration, _int_classes) or isinstance(max_iteration, bool) or \
                max_iteration <= 0:
            raise ValueError("max_iter should be a positive integer value, "
                             "but got max_iter={}".format(max_iteration))

        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.max_iteration = max_iteration
        self.drop_last = drop_last 
開發者ID:yatengLG,項目名稱:SSD-Pytorch,代碼行數:31,代碼來源:Dataloader.py

示例4: __init__

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import Sampler [as 別名]
def __init__(self, sampler: Sampler):
        """
        Args:
            sampler (Sampler): @TODO: Docs. Contribution is welcome
        """
        self.sampler = sampler
        self.sampler_list = None 
開發者ID:catalyst-team,項目名稱:catalyst,代碼行數:9,代碼來源:dataset.py


注:本文中的torch.utils.data.Sampler方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。