本文整理汇总了Python中torch.utils.data.sampler.BatchSampler方法的典型用法代码示例。如果您正苦于以下问题:Python sampler.BatchSampler方法的具体用法?Python sampler.BatchSampler怎么用?Python sampler.BatchSampler使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.utils.data.sampler的用法示例。
在下文中一共展示了sampler.BatchSampler方法的13个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_number_of_iters_and_elements
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def test_number_of_iters_and_elements(self):
    """Check the length and contents of IterationBasedBatchSampler.

    For every combination of batch size, requested iteration count and
    ``drop_last`` flag, a sequential BatchSampler over ``range(10)`` is
    wrapped in an IterationBasedBatchSampler. The wrapper must report
    ``num_iterations`` as its length, and its ``i``-th batch must equal the
    ``(i % len(batch_sampler))``-th batch of the wrapped sampler, i.e. the
    wrapped batches are replayed cyclically.
    """
    dataset = list(range(10))
    for batch_size in (2, 3, 4):
        for num_iterations in (4, 10, 20):
            for drop_last in (False, True):
                wrapped = BatchSampler(
                    SequentialSampler(dataset), batch_size, drop_last=drop_last
                )
                cycling = IterationBasedBatchSampler(wrapped, num_iterations)
                assert len(cycling) == num_iterations

                batches_per_pass = len(wrapped)
                for step, batch in enumerate(cycling):
                    # Position of this batch inside one pass over the data.
                    first = (step % batches_per_pass) * batch_size
                    last = min(first + batch_size, len(dataset))
                    self.assertEqual(batch, list(range(first, last)))
示例2: update
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def update(self):
    """Train ``self.act_net`` on the whole replay memory (DQN-style update).

    Nothing is trained until ``self.memory_count`` reaches ``self.capacity``;
    until then a notice is printed and the method returns. Once the memory
    is full:

    * every stored transition is converted to tensors,
    * rewards are standardised,
    * one-step targets ``reward + gamma * max_a target_net(next_state)`` are
      computed without gradient tracking,
    * one optimiser step is taken per random mini-batch of
      ``self.batch_size`` transitions (the last mini-batch may be smaller),
    * each step's loss is logged to ``self.writer`` under
      ``'loss/value_loss'``, and every 100 optimiser steps the weights of
      ``self.act_net`` are copied into ``self.target_net``.
    """
    if self.memory_count < self.capacity:
        print("Memory Buff is too less")
        return

    state = torch.tensor([t.state for t in self.memory]).float()
    action = torch.LongTensor([t.action for t in self.memory]).view(-1, 1)
    reward = torch.tensor([t.reward for t in self.memory]).float()
    next_state = torch.tensor([t.next_state for t in self.memory]).float()

    # Standardise rewards; the epsilon avoids division by a zero std.
    reward = (reward - reward.mean()) / (reward.std() + 1e-7)
    with torch.no_grad():
        # NOTE(review): no terminal-state mask is applied to the bootstrap
        # term; confirm that episodes never end inside the memory.
        target_v = reward + self.gamma * self.target_net(next_state).max(1)[0]

    sampler = BatchSampler(
        SubsetRandomSampler(range(len(self.memory))),
        batch_size=self.batch_size,
        drop_last=False,
    )
    for index in sampler:
        # Forward only the sampled rows. The previous version ran act_net over
        # the entire memory twice per mini-batch (once into an unused local)
        # and then discarded everything outside ``index``.
        # NOTE(review): equivalent only if act_net treats rows independently
        # (no batch-statistics layers such as BatchNorm in train mode).
        q_pred = self.act_net(state[index]).gather(1, action[index])
        loss = self.loss_func(target_v[index].unsqueeze(1), q_pred)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.writer.add_scalar('loss/value_loss', loss, self.update_count)
        self.update_count += 1
        # Periodically sync the target network with the online network.
        if self.update_count % 100 == 0:
            self.target_net.load_state_dict(self.act_net.state_dict())
示例3: sample
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def sample(self, advantages, num_mini_batch):
    """Yield shuffled mini-batches drawn from the flattened rollout storage.

    The first two dimensions of ``self.rewards`` are read as
    ``(num_steps, num_processes)``. Every stored tensor is flattened over
    those two dimensions (``observations``, ``returns`` and ``masks`` first
    drop their last entry along dim 0) and indexed with the same random
    indices, so the fields of one yielded tuple stay aligned.

    Yields:
        ``(observations, actions, returns, masks, old_action_log_probs,
        advantages)`` for each mini-batch of
        ``num_steps * num_processes // num_mini_batch`` indices; because
        ``drop_last`` is False, a smaller remainder batch is also yielded
        when the division is not exact.
    """
    steps, processes = self.rewards.size()[0:2]
    total = processes * steps
    # Make sure we have at least enough for a bunch of batches of size 1.
    assert total >= num_mini_batch
    index_batches = BatchSampler(
        SubsetRandomSampler(range(total)), total // num_mini_batch, drop_last=False
    )
    for idx in index_batches:
        obs_flat = self.observations[:-1].view(-1, *self.observations.size()[2:])
        yield (
            obs_flat[idx],
            self.actions.view(-1, self.actions.size(-1))[idx],
            self.returns[:-1].view(-1, 1)[idx],
            self.masks[:-1].view(-1, 1)[idx],
            self.action_log_probs.view(-1, 1)[idx],
            advantages.view(-1, 1)[idx],
        )
示例4: test_distributed_batch_sampler
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def test_distributed_batch_sampler():
    """Check how DistributedBatchSampler splits batches across 4 replicas.

    The wrapped BatchSampler yields ``[0..9]`` and ``[10..14]``. Each rank
    must receive, from every batch, the elements whose position inside that
    batch is congruent to the rank modulo 4, and must report a length of 2.
    """
    batch_sampler = BatchSampler(SequentialSampler(list(range(15))), 10, False)
    expected_by_rank = {
        0: [[0, 4, 8], [10, 14]],
        1: [[1, 5, 9], [11]],
        2: [[2, 6], [12]],
        3: [[3, 7], [13]],
    }
    for rank, expected in expected_by_rank.items():
        distributed_sampler = DistributedBatchSampler(batch_sampler, num_replicas=4, rank=rank)
        assert list(distributed_sampler) == expected
        assert len(distributed_sampler) == 2
示例5: feed_forward_generator
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def feed_forward_generator(self, advantages, num_mini_batch):
    """Yield shuffled PPO mini-batches from a recurrent rollout storage.

    The first two dimensions of ``self.rewards`` are read as
    ``(num_steps, num_processes)``. Every stored tensor is flattened over
    those dimensions (tensors sliced with ``[:-1]`` first drop their last
    entry along dim 0) and indexed with one shared set of random indices,
    so all fields of a yielded tuple refer to the same samples.

    Yields:
        ``(obs, recurrent_hidden_states, actions, value_preds, returns,
        masks, old_action_log_probs, adv_targ)`` per mini-batch of
        ``num_steps * num_processes // num_mini_batch`` indices (plus a
        smaller remainder batch when the division is not exact).

    Raises:
        AssertionError: if ``num_mini_batch`` exceeds
            ``num_steps * num_processes``.
    """
    num_steps, num_processes = self.rewards.size()[0:2]
    total_samples = num_processes * num_steps
    assert total_samples >= num_mini_batch, (
        "PPO requires the number of processes ({}) "
        "* number of steps ({}) = {} "
        "to be greater than or equal to the number of PPO mini batches ({})."
        "".format(num_processes, num_steps, num_processes * num_steps, num_mini_batch))
    per_batch = total_samples // num_mini_batch
    random_batches = BatchSampler(
        SubsetRandomSampler(range(total_samples)), per_batch, drop_last=False
    )
    for rows in random_batches:
        hidden_size = self.recurrent_hidden_states.size(-1)
        yield (
            self.obs[:-1].view(-1, *self.obs.size()[2:])[rows],
            self.recurrent_hidden_states[:-1].view(-1, hidden_size)[rows],
            self.actions.view(-1, self.actions.size(-1))[rows],
            self.value_preds[:-1].view(-1, 1)[rows],
            self.returns[:-1].view(-1, 1)[rows],
            self.masks[:-1].view(-1, 1)[rows],
            self.action_log_probs.view(-1, 1)[rows],
            advantages.view(-1, 1)[rows],
        )
示例6: feed_forward_generator
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def feed_forward_generator(self, advantages, num_mini_batch):
    """Yield shuffled PPO mini-batches (no value predictions) from storage.

    ``self.rewards`` supplies ``(num_steps, num_processes)`` through its
    first two dimensions. All stored tensors are flattened over those two
    dimensions (tensors sliced with ``[:-1]`` first drop their last entry
    along dim 0) and indexed with the same random indices per mini-batch.

    Yields:
        ``(obs, recurrent_hidden_states, actions, returns, masks,
        old_action_log_probs, adv_targ)`` per mini-batch of
        ``num_steps * num_processes // num_mini_batch`` indices (plus a
        smaller remainder batch when the division is not exact).

    Raises:
        AssertionError: if ``num_mini_batch`` exceeds
            ``num_steps * num_processes``.
    """
    num_steps, num_processes = self.rewards.size()[0:2]
    n_samples = num_processes * num_steps
    assert n_samples >= num_mini_batch, (
        "PPO requires the number of processes ({}) "
        "* number of steps ({}) = {} "
        "to be greater than or equal to the number of PPO mini batches ({})."
        "".format(num_processes, num_steps, num_processes * num_steps, num_mini_batch))
    chunk = n_samples // num_mini_batch
    index_sampler = BatchSampler(SubsetRandomSampler(range(n_samples)), chunk, drop_last=False)
    for batch_idx in index_sampler:
        flat_obs = self.obs[:-1].view(-1, *self.obs.size()[2:])
        obs_part = flat_obs[batch_idx]
        flat_hidden = self.recurrent_hidden_states[:-1].view(
            -1, self.recurrent_hidden_states.size(-1))
        hidden_part = flat_hidden[batch_idx]
        action_part = self.actions.view(-1, self.actions.size(-1))[batch_idx]
        return_part = self.returns[:-1].view(-1, 1)[batch_idx]
        mask_part = self.masks[:-1].view(-1, 1)[batch_idx]
        log_prob_part = self.action_log_probs.view(-1, 1)[batch_idx]
        adv_part = advantages.view(-1, 1)[batch_idx]
        yield (obs_part, hidden_part, action_part, return_part,
               mask_part, log_prob_part, adv_part)
示例7: feed_forward_generator
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def feed_forward_generator(self, advantages, num_mini_batch):
    """Yield shuffled mini-batches from the flattened rollout storage.

    ``self.rewards`` supplies ``(num_steps, num_processes)`` through its
    first two dimensions. All stored tensors are flattened over those two
    dimensions (``observations``, ``returns`` and ``masks`` first drop their
    last entry along dim 0) and indexed with one shared index tensor per
    mini-batch, so the fields of a yielded tuple stay aligned.

    Args:
        advantages: tensor flattened with ``view(-1, 1)``; its device also
            decides where the index tensor is placed.
        num_mini_batch: number of mini-batches of size
            ``num_steps * num_processes // num_mini_batch`` (a smaller
            remainder batch is also yielded when the division is not exact).

    Yields:
        ``(observations, actions, returns, masks, old_action_log_probs,
        adv_targ)`` per mini-batch.

    Raises:
        ValueError: if ``num_mini_batch`` exceeds ``num_steps * num_processes``
            (previously BatchSampler raised a less informative ValueError for
            the resulting zero batch size).
    """
    num_steps, num_processes = self.rewards.size()[0:2]
    batch_size = num_processes * num_steps
    if batch_size < num_mini_batch:
        raise ValueError(
            "PPO requires the number of processes ({}) * number of steps ({}) = {} "
            "to be greater than or equal to the number of PPO mini batches ({})."
            .format(num_processes, num_steps, batch_size, num_mini_batch))
    mini_batch_size = batch_size // num_mini_batch
    sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
    for indices in sampler:
        # Build the index tensor directly on the device of ``advantages``.
        # The previous ``.cuda()`` call used the *current* CUDA device, which
        # is wrong when ``advantages`` lives on a different GPU.
        indices = torch.as_tensor(indices, dtype=torch.long, device=advantages.device)
        observations_batch = self.observations[:-1].view(
            -1, *self.observations.size()[2:])[indices]
        actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
        return_batch = self.returns[:-1].view(-1, 1)[indices]
        masks_batch = self.masks[:-1].view(-1, 1)[indices]
        old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
        adv_targ = advantages.view(-1, 1)[indices]
        yield observations_batch, actions_batch, \
            return_batch, masks_batch, old_action_log_probs_batch, adv_targ
开发者ID:BoyuanYan,项目名称:Actor-Critic-Based-Resource-Allocation-for-Multimodal-Optical-Networks,代码行数:23,代码来源:storage.py
示例8: get_data_loader
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42):
    """ Build a DataLoader over ``DeepMojiDataset(X_in, y_in)``.

    Per the original description, the extended sampler enables larger
    epochs on small datasets and offers upsampling.

    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.
        batch_size: Batch size.
        extended_batch_sampler: If True, batches come from
            ``DeepMojiBatchSampler`` (which receives ``epoch_size``,
            ``upsample`` and ``seed``). If False, a plain sequential
            ``BatchSampler`` over ``y_in`` with ``drop_last=False`` is used
            and those three arguments are ignored.
        epoch_size: Number of samples in an epoch (extended sampler only).
        upsample: Whether upsampling should be done (extended sampler only).
            This flag should only be set on binary class problems.
        seed: Forwarded to ``DeepMojiBatchSampler`` (extended sampler only).
    # Returns:
        DataLoader using the chosen batch sampler and ``num_workers=0``.
    """
    dataset = DeepMojiDataset(X_in, y_in)
    batch_sampler = (
        DeepMojiBatchSampler(y_in, batch_size, epoch_size=epoch_size, upsample=upsample, seed=seed)
        if extended_batch_sampler
        else BatchSampler(SequentialSampler(y_in), batch_size, drop_last=False)
    )
    return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0)
示例9: feed_forward_generator
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def feed_forward_generator(self, advantages, num_mini_batch):
    """Yield shuffled PPO mini-batches from a storage with dict observations.

    ``self.rewards`` supplies ``(num_steps, num_processes)`` through its
    first two dimensions. Each sensor tensor in ``self.observations`` and
    every other stored tensor is flattened over those dimensions (tensors
    sliced with ``[:-1]`` first drop their last entry along dim 0) and
    indexed with one shared set of random indices per mini-batch.

    Yields:
        ``(observations_batch, states, actions, returns, masks,
        old_action_log_probs, adv_targ)`` where ``observations_batch`` is a
        new dict (same keys as ``self.observations``) for every mini-batch.

    Raises:
        AssertionError: if ``num_mini_batch`` exceeds
            ``num_steps * num_processes``.
    """
    num_steps, num_processes = self.rewards.size()[0:2]
    batch_size = num_processes * num_steps
    assert batch_size >= num_mini_batch, (
        f"PPO requires the number processes ({num_processes}) "
        f"* number of steps ({num_steps}) = {num_processes * num_steps} "
        f"to be greater than or equal to the number of PPO mini batches ({num_mini_batch}).")
    mini_batch_size = batch_size // num_mini_batch
    sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
    for indices in sampler:
        # Build a fresh dict for every mini-batch. The previous version created
        # one dict before the loop and mutated it in place, so every yielded
        # batch aliased the same object and earlier batches were overwritten
        # as soon as the consumer advanced the generator.
        observations_batch = {
            k: sensor_ob[:-1].view(-1, *sensor_ob.size()[2:])[indices]
            for k, sensor_ob in self.observations.items()
        }
        states_batch = self.states[:-1].view(-1, self.states.size(-1))[indices]
        actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
        return_batch = self.returns[:-1].view(-1, 1)[indices]
        masks_batch = self.masks[:-1].view(-1, 1)[indices]
        old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
        adv_targ = advantages.view(-1, 1)[indices]
        yield observations_batch, states_batch, actions_batch, \
            return_batch, masks_batch, old_action_log_probs_batch, adv_targ
示例10: feed_forward_generator
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def feed_forward_generator(self, advantages, num_mini_batch, sampler=None):
    """Yield PPO mini-batches, optionally driven by a caller-supplied sampler.

    ``self.rewards`` supplies ``(num_steps, num_processes)`` through its
    first two dimensions. Stored tensors are flattened over those dimensions
    (tensors sliced with ``[:-1]`` first drop their last entry along dim 0)
    and indexed with each batch of indices produced by ``sampler``.

    Args:
        advantages: tensor flattened with ``view(-1, 1)``.
        num_mini_batch: used for the size check and, when ``sampler`` is
            None, to size the default random mini-batches
            (``num_steps * num_processes // num_mini_batch``).
        sampler: optional iterable of index batches; when None a shuffled
            BatchSampler with ``drop_last=False`` is created.

    Yields:
        ``(obs, recurrent_hidden_states, actions, value_preds, returns,
        masks, old_action_log_probs, adv_targ)`` per index batch.

    Raises:
        AssertionError: if ``num_mini_batch`` exceeds
            ``num_steps * num_processes``.
    """
    num_steps, num_processes = self.rewards.size()[0:2]
    n_total = num_processes * num_steps
    assert n_total >= num_mini_batch, (
        "PPO requires the number of processes ({}) "
        "* number of steps ({}) = {} "
        "to be greater than or equal to the number of PPO mini batches ({})."
        "".format(num_processes, num_steps, num_processes * num_steps, num_mini_batch))
    group_size = n_total // num_mini_batch
    index_source = (
        BatchSampler(SubsetRandomSampler(range(n_total)), group_size, drop_last=False)
        if sampler is None
        else sampler
    )
    for picks in index_source:
        obs_rows = self.obs[:-1].view(-1, *self.obs.size()[2:])[picks]
        hidden_rows = self.recurrent_hidden_states[:-1].view(
            -1, self.recurrent_hidden_states.size(-1))[picks]
        action_rows = self.actions.view(-1, self.actions.size(-1))[picks]
        value_rows = self.value_preds[:-1].view(-1, 1)[picks]
        return_rows = self.returns[:-1].view(-1, 1)[picks]
        mask_rows = self.masks[:-1].view(-1, 1)[picks]
        log_prob_rows = self.action_log_probs.view(-1, 1)[picks]
        adv_rows = advantages.view(-1, 1)[picks]
        yield (obs_rows, hidden_rows, action_rows, value_rows,
               return_rows, mask_rows, log_prob_rows, adv_rows)
示例11: magent_feed_forward_generator
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def magent_feed_forward_generator(rollouts_list, advantages_list, num_mini_batch):
    """Yield random mini-batches shared across several agents' rollouts.

    One set of random flat indices is drawn per mini-batch and applied to
    every rollout in ``rollouts_list`` and every tensor in
    ``advantages_list``; the per-agent results are concatenated along dim 0
    in list order, so each yielded tensor holds one block of rows per agent.

    Args:
        rollouts_list: per-agent rollout storages. The first two dimensions
            of ``rollouts_list[0].rewards`` are read as
            ``(num_steps, num_processes)``; NOTE(review): all rollouts are
            assumed to share them, which is not checked here.
        advantages_list: per-agent advantage tensors, concatenated in the
            same order as ``rollouts_list``.
        num_mini_batch: number of mini-batches of size
            ``num_steps * num_processes // num_mini_batch`` per agent; with
            ``drop_last=False`` a smaller remainder batch is also yielded when
            the division is not exact.

    Yields:
        ``(obs, recurrent_hidden_states, actions, value_preds, returns, masks,
        old_action_log_probs, adv_targ)`` tensors.

    Raises:
        ValueError: if ``num_mini_batch`` exceeds ``num_steps * num_processes``
            (previously BatchSampler raised a less informative ValueError for
            the resulting zero batch size).
    """
    num_steps, num_processes = rollouts_list[0].rewards.size()[0:2]
    batch_size = num_processes * num_steps
    if batch_size < num_mini_batch:
        raise ValueError(
            "PPO requires the number of processes ({}) * number of steps ({}) = {} "
            "to be greater than or equal to the number of PPO mini batches ({})."
            .format(num_processes, num_steps, batch_size, num_mini_batch))
    # Size of the minibatch for each agent. Integer division instead of the
    # former int(batch_size / num_mini_batch), which went through a float.
    mini_batch_size = batch_size // num_mini_batch

    # Flatten every per-agent tensor once rather than once per mini-batch.
    # ``[:-1]`` drops the last entry along dim 0 (presumably the extra
    # bootstrap step these buffers keep — confirm against the storage class).
    obs = [r.obs[:-1].view(-1, *r.obs.size()[2:]) for r in rollouts_list]
    hidden = [r.recurrent_hidden_states[:-1].view(-1, r.recurrent_hidden_states.size(-1))
              for r in rollouts_list]
    actions = [r.actions.view(-1, r.actions.size(-1)) for r in rollouts_list]
    value_preds = [r.value_preds[:-1].view(-1, 1) for r in rollouts_list]
    returns = [r.returns[:-1].view(-1, 1) for r in rollouts_list]
    masks = [r.masks[:-1].view(-1, 1) for r in rollouts_list]
    old_log_probs = [r.action_log_probs.view(-1, 1) for r in rollouts_list]
    advs = [a.view(-1, 1) for a in advantages_list]
    fields = (obs, hidden, actions, value_preds, returns, masks, old_log_probs, advs)

    sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
    for indices in sampler:
        yield tuple(
            torch.cat([flat[indices] for flat in per_agent], 0)
            for per_agent in fields
        )
示例12: update_dataloader
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def update_dataloader(dataloader: DataLoader, new_batch_sampler: BatchSampler) -> DataLoader:
    """Helper function to replace current batch sampler of the dataloader by a new batch sampler. Function returns new
    dataloader with new batch sampler.

    The new loader is built by calling ``type(dataloader)`` with every public
    instance attribute of the original loader, except the batching-related
    ones (``batch_size``, ``sampler``, ``drop_last``, ``batch_sampler``,
    ``dataset_kind``) that conflict with passing ``batch_sampler``.

    Args:
        dataloader (torch.utils.data.DataLoader): input dataloader
        new_batch_sampler (torch.utils.data.sampler.BatchSampler): new batch sampler to use

    Returns:
        DataLoader: a new loader of the same type as ``dataloader``.
    """
    excluded = {"batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind"}
    params = {
        k: getattr(dataloader, k)
        for k in dataloader.__dict__
        if not k.startswith("_") and k not in excluded
    }
    # PyTorch's DataLoader stores its multiprocessing context in the
    # name-mangled private attribute ``_DataLoader__multiprocessing_context``
    # (exposed via the ``multiprocessing_context`` property), so the public
    # attribute scan above misses it and the copy would silently fall back
    # to the default start method. Only forward it when one was configured.
    if "multiprocessing_context" not in params:
        mp_context = getattr(dataloader, "multiprocessing_context", None)
        if mp_context is not None:
            params["multiprocessing_context"] = mp_context
    params["batch_sampler"] = new_batch_sampler
    return type(dataloader)(**params)
示例13: make_batch_data_sampler
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import BatchSampler [as 别名]
def make_batch_data_sampler(sampler, images_per_batch, num_iters=None, start_iter=0, drop_last=True):
    """Group the indices of ``sampler`` into batches of ``images_per_batch``.

    When ``num_iters`` is None the plain ``BatchSampler`` is returned and
    ``start_iter`` is unused. Otherwise the batch sampler is wrapped as
    ``IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)``.
    """
    batches = data.sampler.BatchSampler(sampler, images_per_batch, drop_last=drop_last)
    if num_iters is None:
        return batches
    return IterationBasedBatchSampler(batches, num_iters, start_iter)