

Python sampler.BatchSampler Code Examples

This article collects typical usage examples of the torch.utils.data.sampler.BatchSampler class in Python. If you are wondering what sampler.BatchSampler does, how to call it, or what real-world code that uses it looks like, the curated examples below should help. You can also explore further usage examples from the enclosing torch.utils.data.sampler module.


The following presents 13 code examples of sampler.BatchSampler, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
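Before the numbered examples, here is a minimal, self-contained sketch of what BatchSampler itself does: it wraps any index sampler and regroups the yielded indices into lists of length batch_size (all values below are illustrative).

from torch.utils.data.sampler import BatchSampler, SequentialSampler

# Each element yielded by a BatchSampler is a list of dataset indices.
batch_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
print(list(batch_sampler))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
# With drop_last=True, the trailing partial batch [8, 9] would be discarded.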

Example 1: test_number_of_iters_and_elements

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def test_number_of_iters_and_elements(self):
        for batch_size in [2, 3, 4]:
            for num_iterations in [4, 10, 20]:
                for drop_last in [False, True]:
                    dataset = [i for i in range(10)]
                    sampler = SequentialSampler(dataset)
                    batch_sampler = BatchSampler(
                        sampler, batch_size, drop_last=drop_last
                    )

                    iter_sampler = IterationBasedBatchSampler(
                        batch_sampler, num_iterations
                    )
                    assert len(iter_sampler) == num_iterations
                    for i, batch in enumerate(iter_sampler):
                        start = (i % len(batch_sampler)) * batch_size
                        end = min(start + batch_size, len(dataset))
                        expected = [x for x in range(start, end)]
                        self.assertEqual(batch, expected) 
Developer: Res2Net, Project: Res2Net-maskrcnn, Lines: 21, Source: test_data_samplers.py
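IterationBasedBatchSampler is not part of torch.utils.data; it belongs to the maskrcnn-benchmark code base this test comes from. As orientation, a minimal sketch consistent with the assertions above, assuming the wrapper simply re-iterates the underlying BatchSampler until num_iterations batches have been yielded, could look like this (the real implementation may differ, e.g. by reseeding a shuffling sampler on each pass):

from torch.utils.data.sampler import BatchSampler

class IterationBasedBatchSampler(BatchSampler):
    """Sketch: re-iterate `batch_sampler` until `num_iterations` batches are yielded."""

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration < self.num_iterations:
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations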

Example 2: update

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def update(self):
        if self.memory_count >= self.capacity:
            state = torch.tensor([t.state for t in self.memory]).float()
            action = torch.LongTensor([t.action for t in self.memory]).view(-1,1).long()
            reward = torch.tensor([t.reward for t in self.memory]).float()
            next_state = torch.tensor([t.next_state for t in self.memory]).float()

            reward = (reward - reward.mean()) / (reward.std() + 1e-7)
            with torch.no_grad():
                target_v = reward + self.gamma * self.target_net(next_state).max(1)[0]

            # Update the policy network on shuffled mini-batches of transitions
            for index in BatchSampler(SubsetRandomSampler(range(len(self.memory))), batch_size=self.batch_size, drop_last=False):
                v = (self.act_net(state).gather(1, action))[index]
                loss = self.loss_func(target_v[index].unsqueeze(1), v)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.writer.add_scalar('loss/value_loss', loss, self.update_count)
                self.update_count += 1
                # Periodically sync the target network with the policy network
                if self.update_count % 100 == 0:
                    self.target_net.load_state_dict(self.act_net.state_dict())
        else:
            print("Memory buffer is not full yet")
Developer: sweetice, Project: Deep-reinforcement-learning-with-pytorch, Lines: 26, Source: DQN_CartPole-v0.py
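The BatchSampler(SubsetRandomSampler(...)) idiom above, which recurs in most of the reinforcement-learning examples below, involves no DataLoader at all: it is simply a convenient way to iterate over shuffled, non-overlapping mini-batches of indices into an in-memory buffer. Stripped of the training loop, the pattern is just:

from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

# Covers range(8) exactly once, in random order, three indices at a time.
for index in BatchSampler(SubsetRandomSampler(range(8)), batch_size=3, drop_last=False):
    print(index)  # e.g. [5, 0, 7], then [2, 4, 1], then [6, 3]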

Example 3: sample

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def sample(self, advantages, num_mini_batch):
        num_steps, num_processes = self.rewards.size()[0:2]
        batch_size = num_processes * num_steps
        # Make sure there is at least one sample for each mini-batch.
        assert batch_size >= num_mini_batch

        mini_batch_size = batch_size // num_mini_batch
        sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
        for indices in sampler:
            observations_batch = self.observations[:-1].view(-1,
                                        *self.observations.size()[2:])[indices]
            actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
            return_batch = self.returns[:-1].view(-1, 1)[indices]
            masks_batch = self.masks[:-1].view(-1, 1)[indices]
            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
            adv = advantages.view(-1, 1)[indices]

            yield observations_batch, actions_batch, \
                return_batch, masks_batch, old_action_log_probs_batch, adv 
Developer: ASzot, Project: ppo-pytorch, Lines: 21, Source: memory.py

Example 4: test_distributed_batch_sampler

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def test_distributed_batch_sampler():
    sampler = SequentialSampler(list(range(15)))
    batch_sampler = BatchSampler(sampler, 10, False)

    distributed_sampler = DistributedBatchSampler(batch_sampler, num_replicas=4, rank=0)
    assert list(distributed_sampler) == [[0, 4, 8], [10, 14]]
    assert len(distributed_sampler) == 2

    distributed_sampler = DistributedBatchSampler(batch_sampler, num_replicas=4, rank=1)
    assert list(distributed_sampler) == [[1, 5, 9], [11]]
    assert len(distributed_sampler) == 2

    distributed_sampler = DistributedBatchSampler(batch_sampler, num_replicas=4, rank=2)
    assert list(distributed_sampler) == [[2, 6], [12]]
    assert len(distributed_sampler) == 2

    distributed_sampler = DistributedBatchSampler(batch_sampler, num_replicas=4, rank=3)
    assert list(distributed_sampler) == [[3, 7], [13]]
    assert len(distributed_sampler) == 2 
Developer: PetrochukM, Project: PyTorch-NLP, Lines: 21, Source: test_distributed_batch_sampler.py
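DistributedBatchSampler here comes from PyTorch-NLP, not from torch itself. A minimal sketch consistent with the assertions above, assuming each batch is split round-robin so that rank r keeps batch[r::num_replicas], could look like this (the real class may additionally infer num_replicas and rank from the process group, which is not shown):

from torch.utils.data.sampler import BatchSampler

class DistributedBatchSampler(BatchSampler):
    """Sketch: split each batch of a wrapped BatchSampler across num_replicas workers."""

    def __init__(self, batch_sampler, num_replicas, rank):
        self.batch_sampler = batch_sampler
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        for batch in self.batch_sampler:
            # Round-robin slice: this rank keeps every num_replicas-th index.
            yield batch[self.rank::self.num_replicas]

    def __len__(self):
        return len(self.batch_sampler)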

Example 5: feed_forward_generator

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def feed_forward_generator(self, advantages, num_mini_batch):
        num_steps, num_processes = self.rewards.size()[0:2]
        batch_size = num_processes * num_steps
        assert batch_size >= num_mini_batch, (
            "PPO requires the number of processes ({}) "
            "* number of steps ({}) = {} "
            "to be greater than or equal to the number of PPO mini batches ({})."
            "".format(num_processes, num_steps, num_processes * num_steps, num_mini_batch))
        mini_batch_size = batch_size // num_mini_batch
        sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
        for indices in sampler:
            obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices]
            recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view(-1,
                self.recurrent_hidden_states.size(-1))[indices]
            actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
            value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices]
            return_batch = self.returns[:-1].view(-1, 1)[indices]
            masks_batch = self.masks[:-1].view(-1, 1)[indices]
            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
            adv_targ = advantages.view(-1, 1)[indices]

            yield obs_batch, recurrent_hidden_states_batch, actions_batch, \
                value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ 
Developer: montrealrobotics, Project: dal, Lines: 25, Source: storage.py

Example 6: feed_forward_generator

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def feed_forward_generator(self, advantages, num_mini_batch):
        num_steps, num_processes = self.rewards.size()[0:2]
        batch_size = num_processes * num_steps
        assert batch_size >= num_mini_batch, (
            "PPO requires the number of processes ({}) "
            "* number of steps ({}) = {} "
            "to be greater than or equal to the number of PPO mini batches ({})."
            "".format(num_processes, num_steps, num_processes * num_steps, num_mini_batch))
        mini_batch_size = batch_size // num_mini_batch
        sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
        for indices in sampler:
            obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices]
            recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view(-1,
                self.recurrent_hidden_states.size(-1))[indices]
            actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
            return_batch = self.returns[:-1].view(-1, 1)[indices]
            masks_batch = self.masks[:-1].view(-1, 1)[indices]
            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
            adv_targ = advantages.view(-1, 1)[indices]

            yield obs_batch, recurrent_hidden_states_batch, actions_batch, \
                return_batch, masks_batch, old_action_log_probs_batch, adv_targ 
Developer: maximecb, Project: gym-miniworld, Lines: 24, Source: storage.py

Example 7: feed_forward_generator

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def feed_forward_generator(self, advantages, num_mini_batch):
        num_steps, num_processes = self.rewards.size()[0:2]
        batch_size = num_processes * num_steps
        mini_batch_size = batch_size // num_mini_batch
        sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
        for indices in sampler:
            indices = torch.LongTensor(indices)

            if advantages.is_cuda:
                indices = indices.cuda()

            observations_batch = self.observations[:-1].view(-1,
                                        *self.observations.size()[2:])[indices]
            actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
            return_batch = self.returns[:-1].view(-1, 1)[indices]
            masks_batch = self.masks[:-1].view(-1, 1)[indices]
            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
            adv_targ = advantages.view(-1, 1)[indices]

            yield observations_batch, actions_batch, \
                return_batch, masks_batch, old_action_log_probs_batch, adv_targ 
Developer: BoyuanYan, Project: Actor-Critic-Based-Resource-Allocation-for-Multimodal-Optical-Networks, Lines: 23, Source: storage.py

Example 8: get_data_loader

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42):
    """ Returns a dataloader that enables larger epochs on small datasets and
        has upsampling functionality.

    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.
        batch_size: Batch size.
        extended_batch_sampler: Whether to use the extended DeepMojiBatchSampler
            (which supports epoch_size and upsample) or a plain BatchSampler.
        epoch_size: Number of samples in an epoch. Only used by the
            extended batch sampler.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems.
        seed: Random number generator seed used by the extended batch sampler.

    # Returns:
        DataLoader.
    """
    dataset = DeepMojiDataset(X_in, y_in)

    if extended_batch_sampler:
        batch_sampler = DeepMojiBatchSampler(y_in, batch_size, epoch_size=epoch_size, upsample=upsample, seed=seed)
    else:
        batch_sampler = BatchSampler(SequentialSampler(y_in), batch_size, drop_last=False)

    return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0) 
Developer: natashamjaques, Project: neural_chat, Lines: 25, Source: finetuning.py
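A hypothetical call, assuming X_train/y_train are already encoded the way DeepMojiDataset expects (names and values here are illustrative):

# Extended sampling: fixed-size epochs with upsampling of the minority class.
train_loader = get_data_loader(X_train, y_train, batch_size=32,
                               epoch_size=25000, upsample=True)

# Plain sequential batching via torch's BatchSampler instead:
val_loader = get_data_loader(X_val, y_val, batch_size=32,
                             extended_batch_sampler=False)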

Example 9: feed_forward_generator

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def feed_forward_generator(self, advantages, num_mini_batch):
        num_steps, num_processes = self.rewards.size()[0:2]
        batch_size = num_processes * num_steps
        assert batch_size >= num_mini_batch, (
            f"PPO requires the number processes ({num_processes}) "
            f"* number of steps ({num_steps}) = {num_processes * num_steps} "
            f"to be greater than or equal to the number of PPO mini batches ({num_mini_batch}).")
        mini_batch_size = batch_size // num_mini_batch
        observations_batch = {}
        sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
        for indices in sampler:
            for k, sensor_ob in self.observations.items():
                observations_batch[k] = sensor_ob[:-1].view(-1, *sensor_ob.size()[2:])[indices]
            states_batch = self.states[:-1].view(-1, self.states.size(-1))[indices]
            actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
            return_batch = self.returns[:-1].view(-1, 1)[indices]
            masks_batch = self.masks[:-1].view(-1, 1)[indices]
            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
            adv_targ = advantages.view(-1, 1)[indices]
            yield observations_batch, states_batch, actions_batch, \
                  return_batch, masks_batch, old_action_log_probs_batch, adv_targ 
Developer: alexsax, Project: midlevel-reps, Lines: 23, Source: rollout.py

Example 10: feed_forward_generator

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def feed_forward_generator(self, advantages, num_mini_batch, sampler=None):
        num_steps, num_processes = self.rewards.size()[0:2]
        batch_size = num_processes * num_steps
        assert batch_size >= num_mini_batch, (
            "PPO requires the number of processes ({}) "
            "* number of steps ({}) = {} "
            "to be greater than or equal to the number of PPO mini batches ({})."
            "".format(num_processes, num_steps, num_processes * num_steps, num_mini_batch))
        mini_batch_size = batch_size // num_mini_batch
        if sampler is None:
            sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
        for indices in sampler:
            obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices]
            recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view(-1, 
                                            self.recurrent_hidden_states.size(-1))[indices]
            actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
            value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices]
            return_batch = self.returns[:-1].view(-1, 1)[indices]
            masks_batch = self.masks[:-1].view(-1, 1)[indices]
            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
            adv_targ = advantages.view(-1, 1)[indices]

            yield obs_batch, recurrent_hidden_states_batch, actions_batch, value_preds_batch, return_batch, \
                  masks_batch, old_action_log_probs_batch, adv_targ 
Developer: sumitsk, Project: marl_transfer, Lines: 26, Source: storage.py

Example 11: magent_feed_forward_generator

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def magent_feed_forward_generator(rollouts_list, advantages_list, num_mini_batch):
    num_steps, num_processes = rollouts_list[0].rewards.size()[0:2]
    batch_size = num_processes * num_steps
    mini_batch_size = batch_size // num_mini_batch  # size of minibatch for each agent
    sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
    for indices in sampler:
        obs_batch = torch.cat([rollout.obs[:-1].view(-1, *rollout.obs.size()[2:])[indices] for rollout in rollouts_list], 0)
        recurrent_hidden_states_batch = torch.cat([rollout.recurrent_hidden_states[:-1].view(-1,
                    rollout.recurrent_hidden_states.size(-1))[indices] for rollout in rollouts_list], 0)
        actions_batch = torch.cat([rollout.actions.view(-1,
                    rollout.actions.size(-1))[indices] for rollout in rollouts_list], 0)
        value_preds_batch = torch.cat([rollout.value_preds[:-1].view(-1, 1)[indices] for rollout in rollouts_list], 0)
        return_batch = torch.cat([rollout.returns[:-1].view(-1, 1)[indices] for rollout in rollouts_list], 0)
        masks_batch = torch.cat([rollout.masks[:-1].view(-1, 1)[indices] for rollout in rollouts_list], 0)
        old_action_log_probs_batch = torch.cat([rollout.action_log_probs.view(-1, 1)[indices] for rollout in rollouts_list], 0)
        adv_targ = torch.cat([advantages.view(-1, 1)[indices] for advantages in advantages_list], 0)

        yield obs_batch, recurrent_hidden_states_batch, actions_batch, value_preds_batch, return_batch,\
              masks_batch, old_action_log_probs_batch, adv_targ 
Developer: sumitsk, Project: marl_transfer, Lines: 21, Source: ppo.py

Example 12: update_dataloader

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def update_dataloader(dataloader: DataLoader, new_batch_sampler: BatchSampler) -> DataLoader:
    """Helper function to replace current batch sampler of the dataloader by a new batch sampler. Function returns new
    dataloader with new batch sampler.

    Args:
        dataloader (torch.utils.data.DataLoader): input dataloader
        new_batch_sampler (torch.utils.data.sampler.BatchSampler): new batch sampler to use

    Returns:
        DataLoader
    """
    params_keys = [k for k in dataloader.__dict__.keys() if not k.startswith("_")]
    for k in ["batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind"]:
        if k in params_keys:
            params_keys.remove(k)
    params = {k: getattr(dataloader, k) for k in params_keys}
    params["batch_sampler"] = new_batch_sampler
    return type(dataloader)(**params) 
Developer: pytorch, Project: ignite, Lines: 20, Source: deterministic.py
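A quick usage sketch with illustrative values, swapping a DataLoader's sequential batches of 10 for random batches of 25 without rebuilding the rest of its configuration:

from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, RandomSampler

dataset = list(range(100))
loader = DataLoader(dataset, batch_size=10)

# Build the replacement batching strategy and graft it onto the loader.
new_batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=25, drop_last=True)
loader = update_dataloader(loader, new_batch_sampler)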

Example 13: make_batch_data_sampler

# Required import: from torch.utils.data import sampler [as alias]
# Or: from torch.utils.data.sampler import BatchSampler [as alias]
def make_batch_data_sampler(sampler, images_per_batch, num_iters=None, start_iter=0, drop_last=True):
    batch_sampler = data.sampler.BatchSampler(sampler, images_per_batch, drop_last=drop_last)
    if num_iters is not None:
        batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)
    return batch_sampler 
Developer: LikeLy-Journey, Project: SegmenTron, Lines: 7, Source: distributed.py


Note: The torch.utils.data.sampler.BatchSampler examples in this article were compiled by 纯净天空 from Github, MSDocs, and other open-source code and documentation platforms. The snippets were selected from open-source projects contributed by their developers, and the source code remains under its original authors' copyright. Consult the corresponding project's license before using or distributing the code; do not republish without permission.