当前位置: 首页>>代码示例>>Python>>正文


Python data.BatchSampler方法代码示例

本文整理汇总了Python中torch.utils.data.BatchSampler方法的典型用法代码示例。如果您正苦于以下问题：Python data.BatchSampler方法的具体用法？Python data.BatchSampler怎么用？Python data.BatchSampler使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.utils.data的用法示例。


在下文中一共展示了data.BatchSampler方法的14个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: generate_batch

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def generate_batch(self, episodes, episode_labels):
        """Yield (frames, labels) minibatches with one random frame per sampled episode.

        Episode indices are drawn uniformly with replacement, once per frame in
        the whole dataset, and grouped into batches of ``self.batch_size``; an
        incomplete final batch is dropped.

        Args:
            episodes: sequence of episodes whose frames are combined with
                ``torch.stack``.
            episode_labels: parallel sequence; ``episode_labels[i][t]`` is the
                label entry for frame ``t`` of episode ``i``.

        Yields:
            ``(frames, labels)``: ``frames`` is a float tensor on
            ``self.device`` scaled by 1/255; ``labels`` is an
            ``appendabledict`` fed through ``append_update`` (presumably one
            list per label key -- TODO confirm).

        Raises:
            AssertionError: on first iteration, when the dataset does not hold
                more frames than ``self.batch_size``.
        """
        step_count = sum(len(ep) for ep in episodes)
        assert step_count > self.batch_size
        print('Total Steps: {}'.format(step_count))

        # Episode sampler: `step_count` draws with replacement, then grouped
        # into batches of `self.batch_size` episode indices.
        episode_draws = RandomSampler(range(len(episodes)),
                                      replacement=True, num_samples=step_count)
        index_batches = BatchSampler(episode_draws, self.batch_size, drop_last=True)

        for batch_ids in index_batches:
            chosen_episodes = [episodes[i] for i in batch_ids]
            chosen_labels = [episode_labels[i] for i in batch_ids]
            frames = []
            batch_labels = appendabledict()
            for episode, labels_of_episode in zip(chosen_episodes, chosen_labels):
                # One uniformly drawn frame from this episode.
                step = np.random.randint(len(episode))
                frames.append(episode[step])
                batch_labels.append_update(labels_of_episode[step])
            yield torch.stack(frames).float().to(self.device) / 255., batch_labels
开发者ID:mila-iqia,项目名称:atari-representation-learning,代码行数:22,代码来源:probe.py

示例2: generate_batch

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def generate_batch(self, episodes):
        """Yield ``(x_t, x_{t-1})`` minibatches of temporally adjacent frames.

        Episode indices are drawn uniformly with replacement (``total_steps``
        draws, one per frame in the dataset) and grouped into batches of
        ``self.batch_size``; an incomplete final batch is dropped.  From every
        drawn episode one step ``t`` is sampled and the pair
        ``(episode[t], episode[t - 1])`` is emitted.

        Args:
            episodes: sequence of episodes; frames are combined with
                ``torch.stack`` (presumably uint8 images given the 1/255
                scaling -- TODO confirm).

        Yields:
            ``(x_t, x_tprev)``: float tensors on ``self.device`` scaled by 1/255.
        """
        total_steps = sum(len(e) for e in episodes)
        print('Total Steps: {}'.format(total_steps))
        # Episode sampler: sample `total_steps` episodes, then batchify them
        # with `self.batch_size` episodes per batch.
        sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                             replacement=True, num_samples=total_steps),
                               self.batch_size, drop_last=True)
        for indices in sampler:
            x_t, x_tprev = [], []
            for episode in (episodes[i] for i in indices):
                # Draw t from [1, len) so episode[t - 1] is the real predecessor:
                # with t == 0, episode[-1] would wrap to the episode's LAST frame.
                # A length-1 episode can only use t == 0 (its frame is paired with
                # itself); an empty episode still makes randint raise ValueError.
                low = 1 if len(episode) > 1 else 0
                t = np.random.randint(low, len(episode))
                x_t.append(episode[t])
                x_tprev.append(episode[max(t - 1, 0)])
            yield (torch.stack(x_t).float().to(self.device) / 255.,
                   torch.stack(x_tprev).float().to(self.device) / 255.)
开发者ID:mila-iqia,项目名称:atari-representation-learning,代码行数:22,代码来源:stdim.py

示例3: generate_batch

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def generate_batch(self, episodes):
        """Yield ``(x_t, x_{t + pred_offset})`` minibatches of frame pairs.

        Only episodes longer than ``self.pred_offset`` can supply such a pair;
        shorter ones are skipped instead of crashing ``np.random.randint``
        with "low >= high".  Episode indices are drawn uniformly with
        replacement (one draw per frame of the usable episodes) and grouped
        into batches of ``self.batch_size``; an incomplete final batch is
        dropped.

        Args:
            episodes: sequence of episodes; frames are combined with
                ``torch.stack``.

        Yields:
            ``(x_t, x_tn)``: float tensors on ``self.device`` scaled by 1/255.

        Raises:
            ValueError: on first iteration, if no episode is longer than
                ``self.pred_offset``.
        """
        episodes = [e for e in episodes if len(e) > self.pred_offset]
        if not episodes:
            raise ValueError('no episode is longer than pred_offset={}'.format(self.pred_offset))
        total_steps = sum(len(e) for e in episodes)
        print('Total Steps: {}'.format(total_steps))
        # Episode sampler: sample `total_steps` episodes, then batchify them
        # with `self.batch_size` episodes per batch.
        sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                             replacement=True, num_samples=total_steps),
                               self.batch_size, drop_last=True)
        for indices in sampler:
            x_t, x_tn = [], []
            for episode in (episodes[i] for i in indices):
                # t is chosen so that t + pred_offset is still inside the episode.
                t = np.random.randint(0, len(episode) - self.pred_offset)
                x_t.append(episode[t])
                x_tn.append(episode[t + self.pred_offset])
            yield (torch.stack(x_t).float().to(self.device) / 255.,
                   torch.stack(x_tn).float().to(self.device) / 255.)
开发者ID:mila-iqia,项目名称:atari-representation-learning,代码行数:22,代码来源:no_action_feedforward_predictor.py

示例4: generate_batch

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def generate_batch(self, episodes):
        """Yield minibatches of single frames, one uniformly drawn per sampled episode.

        Episode indices are drawn uniformly with replacement (one draw per
        frame in the dataset) and grouped into batches of ``self.batch_size``;
        an incomplete final batch is dropped.

        Args:
            episodes: sequence of episodes; frames are combined with
                ``torch.stack``.

        Yields:
            A float tensor on ``self.device`` scaled by 1/255, with the
            batch along the first dimension.
        """
        total_steps = sum(len(e) for e in episodes)
        print('Total Steps: {}'.format(total_steps))
        # Episode sampler: sample `total_steps` episodes, then batchify them
        # with `self.batch_size` episodes per batch.
        sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                             replacement=True, num_samples=total_steps),
                               self.batch_size, drop_last=True)
        for indices in sampler:
            x_t = []
            for episode in (episodes[i] for i in indices):
                x_t.append(episode[np.random.randint(0, len(episode))])
            yield torch.stack(x_t).float().to(self.device) / 255.
开发者ID:mila-iqia,项目名称:atari-representation-learning,代码行数:19,代码来源:vae.py

示例5: test_engine_with_dataloader_no_auto_batching

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def test_engine_with_dataloader_no_auto_batching():
    """A DeterministicEngine must run 5 epochs x 10 iterations over a non-auto-batching loader."""
    # tests https://github.com/pytorch/ignite/issues/941
    from torch.utils.data import DataLoader, BatchSampler, RandomSampler

    samples = torch.rand(64, 4, 10)
    # batch_size=None turns automatic batching off, so each list of 8 indices
    # produced by the BatchSampler is handed to the dataset as a single key.
    index_batches = BatchSampler(RandomSampler(samples), batch_size=8, drop_last=True)
    data_loader = DataLoader(samples, batch_size=None, sampler=index_batches)

    calls = 0

    def on_iteration(engine, batch):
        nonlocal calls
        print("{}-{}: {}".format(engine.state.epoch, engine.state.iteration, batch))
        calls += 1

    # 64 samples / 8 per batch = 8 batches per pass, fewer than epoch_length=10.
    DeterministicEngine(on_iteration).run(data_loader, epoch_length=10, max_epochs=5)

    assert calls == 50
开发者ID:pytorch,项目名称:ignite,代码行数:21,代码来源:test_deterministic.py

示例6: test_engine_with_dataloader_no_auto_batching

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def test_engine_with_dataloader_no_auto_batching():
    """An Engine must run 5 epochs x 10 iterations over a non-auto-batching loader."""
    # tests https://github.com/pytorch/ignite/issues/941
    from torch.utils.data import DataLoader, BatchSampler, RandomSampler

    samples = torch.rand(64, 4, 10)
    # batch_size=None turns automatic batching off, so each list of 8 indices
    # produced by the BatchSampler is handed to the dataset as a single key.
    index_batches = BatchSampler(RandomSampler(samples), batch_size=8, drop_last=True)
    data_loader = DataLoader(samples, batch_size=None, sampler=index_batches)

    calls = 0

    def on_iteration(engine, batch):
        nonlocal calls
        print("{}-{}: {}".format(engine.state.epoch, engine.state.iteration, batch))
        calls += 1

    # 64 samples / 8 per batch = 8 batches per pass, fewer than epoch_length=10.
    Engine(on_iteration).run(data_loader, epoch_length=10, max_epochs=5)

    assert calls == 50
开发者ID:pytorch,项目名称:ignite,代码行数:21,代码来源:test_engine.py

示例7: generate_batch

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def generate_batch(self, episodes):
        """Yield ``(x_t, x_{t-1})`` minibatches of temporally adjacent frames.

        Episode indices are drawn uniformly with replacement (``total_steps``
        draws, one per frame in the dataset) and grouped into batches of
        ``self.batch_size``; an incomplete final batch is dropped.  From every
        drawn episode one step ``t`` is sampled and the pair
        ``(episode[t], episode[t - 1])`` is emitted.

        Args:
            episodes: sequence of episodes; frames are combined with
                ``torch.stack`` (presumably uint8 images given the 1/255
                scaling -- TODO confirm).

        Yields:
            ``(x_t, x_tprev)``: float tensors on ``self.device`` scaled by 1/255.
        """
        total_steps = sum(len(e) for e in episodes)
        print('Total Steps: {}'.format(total_steps))
        # Episode sampler: sample `total_steps` episodes, then batchify them
        # with `self.batch_size` episodes per batch.
        sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                             replacement=True, num_samples=total_steps),
                               self.batch_size, drop_last=True)
        for indices in sampler:
            x_t, x_tprev = [], []
            for episode in (episodes[i] for i in indices):
                # Draw t from [1, len) so episode[t - 1] is the real predecessor:
                # with t == 0, episode[-1] would wrap to the episode's LAST frame.
                # A length-1 episode can only use t == 0 (its frame is paired with
                # itself); an empty episode still makes randint raise ValueError.
                low = 1 if len(episode) > 1 else 0
                t = np.random.randint(low, len(episode))
                x_t.append(episode[t])
                x_tprev.append(episode[max(t - 1, 0)])
            yield (torch.stack(x_t).float().to(self.device) / 255.,
                   torch.stack(x_tprev).float().to(self.device) / 255.)
开发者ID:mila-iqia,项目名称:atari-representation-learning,代码行数:31,代码来源:global_infonce_stdim.py

示例8: generate_batch

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def generate_batch(self, episodes):
        """Yield temporal minibatches: adjacent frames, a random frame and their steps.

        Episode indices are drawn uniformly with replacement (``total_steps``
        draws, one per frame in the dataset) and grouped into batches of
        ``self.batch_size``; an incomplete final batch is dropped.  From every
        drawn episode a step ``t`` (with a real predecessor) and an
        independent step ``t_hat`` (uniform over the episode) are sampled.

        Args:
            episodes: sequence of episodes; frames are combined with
                ``torch.stack``.

        Yields:
            ``(x_t, x_tprev, x_that, ts, thats)``: the three frame batches are
            float tensors on ``self.device`` scaled by 1/255; ``ts`` and
            ``thats`` hold the sampled steps, one ``[step]`` row per sample.
        """
        total_steps = sum(len(e) for e in episodes)
        print('Total Steps: {}'.format(total_steps))
        # Episode sampler: sample `total_steps` episodes, then batchify them
        # with `self.batch_size` episodes per batch.
        sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                             replacement=True, num_samples=total_steps),
                               self.batch_size, drop_last=True)
        for indices in sampler:
            x_t, x_tprev, x_that, ts, thats = [], [], [], [], []
            for episode in (episodes[i] for i in indices):
                # Draw t from [1, len) so episode[t - 1] is the real predecessor:
                # with t == 0, episode[-1] would wrap to the episode's LAST frame.
                # A length-1 episode can only use t == 0 (its frame is paired with
                # itself); an empty episode still makes randint raise ValueError.
                low = 1 if len(episode) > 1 else 0
                t = np.random.randint(low, len(episode))
                t_hat = np.random.randint(0, len(episode))
                x_t.append(episode[t])
                x_tprev.append(episode[max(t - 1, 0)])
                x_that.append(episode[t_hat])
                ts.append([t])
                thats.append([t_hat])
            yield (torch.stack(x_t).float().to(self.device) / 255.,
                   torch.stack(x_tprev).float().to(self.device) / 255.,
                   torch.stack(x_that).float().to(self.device) / 255.,
                   torch.Tensor(ts).to(self.device),
                   torch.Tensor(thats).to(self.device))
开发者ID:mila-iqia,项目名称:atari-representation-learning,代码行数:33,代码来源:temporal_dim.py

示例9: generate_batch

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def generate_batch(self, episodes):
        """Yield minibatches of contiguous ``self.sequence_length``-step windows.

        Episodes shorter than ``self.sequence_length`` are skipped.  Episode
        indices are drawn uniformly with replacement
        (``len(episodes) * self.sequence_length`` draws over the kept
        episodes) and grouped into batches of ``self.batch_size``; an
        incomplete final batch is dropped.  Each drawn episode contributes
        one window starting at a uniformly chosen valid offset.

        Args:
            episodes: sequence of episodes; slices of an episode are combined
                with ``torch.stack``.

        Yields:
            A float tensor shaped ``(batch_size, sequence_length, *frame_shape)``.
            It is NOT moved to ``self.device``.

        Raises:
            ValueError: on first iteration, if no episode has at least
                ``self.sequence_length`` steps.
        """
        episodes = [episode for episode in episodes if len(episode) >= self.sequence_length]
        if not episodes:
            raise ValueError('no episode has at least sequence_length={} steps'.format(
                self.sequence_length))
        # Episode sampler: sample episodes, then batchify them with
        # `self.batch_size` episodes per batch.
        sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                             replacement=True,
                                             num_samples=len(episodes) * self.sequence_length),
                               self.batch_size, drop_last=True)
        for indices in sampler:
            sequences = []
            for episode in (episodes[i] for i in indices):
                # +1 because the last valid start is len - sequence_length itself.
                start_index = np.random.randint(0, len(episode) - self.sequence_length + 1)
                seq = episode[start_index: start_index + self.sequence_length]
                sequences.append(torch.stack(seq))
            yield torch.stack(sequences).float()
开发者ID:mila-iqia,项目名称:atari-representation-learning,代码行数:17,代码来源:cpc.py

示例10: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def __init__(self, dataset, batch_size, negative_sampling=False,
               num_sampling_users=0, num_workers=0, collate_fn=None):
    """Wrap ``dataset`` in a DataLoader that fetches whole groups of indices per step.

    Args:
      dataset: the dataset to iterate (annotated as RecommendationDataset).
      batch_size: batch size given to the BatchCollator; also the minimum
        allowed ``num_sampling_users``.
      negative_sampling: forwarded to the BatchCollator.
      num_sampling_users: number of dataset indices fetched per DataLoader
        step; 0 means "use batch_size".
      num_workers: forwarded to the DataLoader.
      collate_fn: custom collate function; when None the BatchCollator's
        ``collate`` is used and ``_use_default_data_generator`` is set.
    """
    self.dataset = dataset # type: RecommendationDataset
    self.batch_size = batch_size
    self.negative_sampling = negative_sampling
    self.num_workers = num_workers
    self.num_sampling_users = batch_size if num_sampling_users == 0 else num_sampling_users

    assert self.num_sampling_users >= batch_size, 'num_sampling_users should be at least equal to the batch_size'

    self.batch_collator = BatchCollator(batch_size=self.batch_size, negative_sampling=self.negative_sampling)

    # The inner BatchSampler groups `num_sampling_users` shuffled indices; the
    # outer one (batch_size=1) wraps each group in a one-element list, so the
    # dataset is indexed with a whole group at once rather than sample by sample.
    index_groups = BatchSampler(RandomSampler(dataset),
                                batch_size=self.num_sampling_users, drop_last=False)
    batch_sampler = BatchSampler(index_groups, batch_size=1, drop_last=False)

    use_default = collate_fn is None
    self._collate_fn = self.batch_collator.collate if use_default else collate_fn
    self._use_default_data_generator = use_default

    # NOTE(review): the DataLoader receives self._collate (defined elsewhere in
    # the class, presumably dispatching to self._collate_fn) -- confirm.
    self._dataloader = DataLoader(dataset, batch_sampler=batch_sampler,
                                  num_workers=num_workers, collate_fn=self._collate)
开发者ID:amoussawi,项目名称:recoder,代码行数:33,代码来源:data.py

示例11: test_reproducible_batch_sampler_wrong_input

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def test_reproducible_batch_sampler_wrong_input():
    """ReproducibleBatchSampler must reject a non-BatchSampler argument with TypeError."""
    expected_message = r"Argument batch_sampler should be torch.utils.data.sampler.BatchSampler"
    not_a_batch_sampler = "abc"
    with pytest.raises(TypeError, match=expected_message):
        ReproducibleBatchSampler(not_a_batch_sampler)
开发者ID:pytorch,项目名称:ignite,代码行数:5,代码来源:test_deterministic.py

示例12: _get_outputs

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def _get_outputs(self, x1, x2):
        """
        Private function to get the transformed data and the corresponding
        loss for the given inputs.

        The model is put in evaluation mode (and left there) and run under
        ``torch.no_grad()`` on consecutive, unshuffled batches of
        ``self.batch_size_`` rows; the last batch may be smaller.

        Parameters
        ----------
        x1 : torch.tensor
            Input view 1 data.
        x2 : torch.tensor
            Input view 2 data (same number of rows as ``x1``).

        Returns
        -------
        losses : list of float
            Loss of each batch, in batch order.
        outputs : list of numpy.ndarray
            ``outputs[0]`` / ``outputs[1]`` are the model outputs for view 1 /
            view 2, concatenated over all batches and converted to numpy
            arrays on the CPU.
        """
        losses = []
        outputs1 = []
        outputs2 = []
        with torch.no_grad():
            self.model_.eval()
            batches = BatchSampler(SequentialSampler(range(x1.size(0))),
                                   batch_size=self.batch_size_,
                                   drop_last=False)
            for batch_idx in batches:
                o1, o2 = self.model_(x1[batch_idx, :], x2[batch_idx, :])
                outputs1.append(o1)
                outputs2.append(o2)
                losses.append(self.loss_(o1, o2).item())
        outputs = [torch.cat(outputs1, dim=0).cpu().numpy(),
                   torch.cat(outputs2, dim=0).cpu().numpy()]

        return losses, outputs
开发者ID:neurodata,项目名称:mvlearn,代码行数:42,代码来源:dcca.py

示例13: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def __init__(
        self,
        federated_dataset,
        batch_size=8,
        shuffle=False,
        num_iterators=1,
        drop_last=False,
        collate_fn=default_collate,
        iter_per_worker=False,
        **kwargs,
    ):
        """Build one batch sampler per worker of a federated dataset.

        Args:
            federated_dataset: must expose ``workers`` and support
                ``federated_dataset[worker]`` with a length.
            batch_size: batch size of every per-worker BatchSampler.
            shuffle: random instead of sequential order within each worker.
            num_iterators: requested number of concurrent iterators; capped
                at ``len(workers) - 1`` when there are several workers.
            drop_last: forwarded to every per-worker BatchSampler.
            collate_fn: stored on ``self.collate_fn``.
            iter_per_worker: use one iterator per worker instead of
                ``num_iterators``.
            **kwargs: unsupported options; they are only logged as a warning.

        Raises:
            Exception: if ``federated_dataset`` has no ``workers`` attribute.
        """
        if kwargs:
            options = ", ".join(f"{k}: {v}" for k, v in kwargs.items())
            logging.warning("The following options are not supported: %s", options)

        try:
            self.workers = federated_dataset.workers
        except AttributeError as err:
            # Chain explicitly so the traceback reads "direct cause" rather
            # than suggesting a second failure inside the handler.
            raise Exception(
                "Your dataset is not a FederatedDataset, please use "
                "torch.utils.data.DataLoader instead."
            ) from err

        self.federated_dataset = federated_dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.iter_class = _DataLoaderOneWorkerIter if iter_per_worker else _DataLoaderIter

        # Build a batch sampler per worker
        self.batch_samplers = {}
        for worker in self.workers:
            data_range = range(len(federated_dataset[worker]))
            sampler = RandomSampler(data_range) if shuffle else SequentialSampler(data_range)
            self.batch_samplers[worker] = BatchSampler(sampler, batch_size, drop_last)

        if iter_per_worker:
            self.num_iterators = len(self.workers)
        elif len(self.workers) <= 1:
            # With one worker there is nothing to switch between; with zero
            # workers, len(self.workers) - 1 would give a negative count.
            self.num_iterators = 1
        else:
            # You can't have more iterators than n - 1 workers, because you always
            # need a worker idle in the worker switch process made by iterators
            self.num_iterators = min(num_iterators, len(self.workers) - 1)
开发者ID:OpenMined,项目名称:PySyft,代码行数:51,代码来源:dataloader.py

示例14: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import BatchSampler [as 别名]
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        """Store loader options, validate them, and resolve the batch sampler.

        A supplied ``batch_sampler`` cannot be combined with ``batch_size > 1``,
        ``shuffle``, ``sampler`` or ``drop_last``; in that case ``batch_size``
        and ``drop_last`` are reset to None.  ``sampler`` cannot be combined
        with ``shuffle``.  Without a ``batch_sampler``, one is built from
        ``sampler``, or from a random (``shuffle``) or sequential sampler over
        ``dataset``.

        Raises:
            ValueError: for a negative ``timeout`` or ``num_workers``, or for
                mutually exclusive options.
        """
        # Every option is recorded up front; validation below only reads them.
        self.worker_init_fn = worker_init_fn
        self.timeout = timeout
        self.drop_last = drop_last
        self.pin_memory = pin_memory
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.dataset = dataset

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            clashes = batch_size > 1 or shuffle or sampler is not None or drop_last
            if clashes:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            # These settings belong to the supplied batch sampler now.
            self.batch_size = None
            self.drop_last = None

        if shuffle and sampler is not None:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True
开发者ID:namisan,项目名称:mt-dnn,代码行数:44,代码来源:dataloader.py


注:本文中的torch.utils.data.BatchSampler方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。