This article collects typical usage examples of the Python class torch.distributions.categorical.Categorical. If you are wondering how categorical.Categorical is used in practice, or looking for working examples, the curated code samples below may help. You can also read further about the containing module, torch.distributions.categorical, for more context.
Below are 15 code examples of categorical.Categorical, sorted by popularity by default.
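Before the examples, here is a minimal sketch of the core Categorical API; the tensor values are invented purely for illustration. Categorical accepts either probs or logits and exposes sample, log_prob, and entropy:

import torch
from torch.distributions.categorical import Categorical

# Build the same 3-way distribution two ways: from probabilities
# (non-negative, normalized internally) or from unnormalized logits.
dist = Categorical(probs=torch.tensor([0.2, 0.5, 0.3]))
dist = Categorical(logits=torch.tensor([-1.6, -0.7, -1.2]))

action = dist.sample()                 # LongTensor index in {0, 1, 2}
logp = dist.log_prob(action)           # log-probability of that index
entropy = dist.entropy()               # scalar entropy of the distribution
batch = dist.sample(torch.Size([4]))   # four i.i.d. draws, shape (4,)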
Example 1: generate_iters_indices
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def generate_iters_indices(self, num_of_iters):
    from_iter = len(self.iter_indices_per_iteration)
    for iter_num in range(from_iter, from_iter + num_of_iters):
        # Get a random number of samples per task (according to the iteration distribution)
        tsks = Categorical(probs=self.tasks_probs_over_iterations[iter_num]).sample(torch.Size([self.samples_in_batch]))
        # Generate sample indices for iter_num
        iter_indices = torch.zeros(0, dtype=torch.int32)
        for task_idx in range(self.num_of_tasks):
            if self.tasks_probs_over_iterations[iter_num][task_idx] > 0:
                num_samples_from_task = (tsks == task_idx).sum().item()
                self.samples_distribution_over_time[task_idx].append(num_samples_from_task)
                # Randomize indices for each task (to allow creation of a random task batch)
                tasks_inner_permute = np.random.permutation(len(self.tasks_samples_indices[task_idx]))
                rand_indices_of_task = tasks_inner_permute[:num_samples_from_task]
                iter_indices = torch.cat([iter_indices, self.tasks_samples_indices[task_idx][rand_indices_of_task]])
            else:
                self.samples_distribution_over_time[task_idx].append(0)
        self.iter_indices_per_iteration.append(iter_indices.tolist())
Example 2: step
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def step(self, action):
    """ One step forward """
    with torch.no_grad():
        action = torch.Tensor(action).unsqueeze(0)
        mu, sigma, pi, r, d, n_h = self._rnn(action, self._lstate, self._hstate)
        pi = pi.squeeze()
        # pi holds log mixture weights; exponentiate to get probabilities
        mixt = Categorical(torch.exp(pi)).sample().item()
        # Deterministic transition: take the chosen component's mean.
        self._lstate = mu[:, mixt, :]  # + sigma[:, mixt, :] * torch.randn_like(mu[:, mixt, :])
        self._hstate = n_h
        self._obs = self._decoder(self._lstate)
        np_obs = self._obs.numpy()
        np_obs = np.clip(np_obs, 0, 1) * 255
        np_obs = np.transpose(np_obs, (0, 2, 3, 1))
        np_obs = np_obs.squeeze()
        np_obs = np_obs.astype(np.uint8)
        self._visual_obs = np_obs
        return np_obs, r.item(), d.item() > 0
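The commented-out term above would make the latent transition stochastic instead of deterministic. A sketch of that variant, assuming (as in the snippet) that mu and sigma are shaped (batch, n_mixtures, latent_dim) and pi holds log mixture weights:

# Hypothetical stochastic variant of the latent update above.
mixt = Categorical(torch.exp(pi)).sample().item()   # pick a mixture component
eps = torch.randn_like(mu[:, mixt, :])              # standard Gaussian noise
latent = mu[:, mixt, :] + sigma[:, mixt, :] * eps   # sample from that Gaussian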
Example 3: forward
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def forward(self):
    inputs, h0 = self.input_vars, None
    log_probs, entropys, sampled_arch = [], [], []
    for iedge in range(self.num_edge):
        outputs, h0 = self.w_lstm(inputs, h0)
        logits = self.w_pred(outputs)
        logits = logits / self.temperature
        logits = self.tanh_constant * torch.tanh(logits)
        # distribution over candidate operations
        op_distribution = Categorical(logits=logits)
        op_index = op_distribution.sample()
        sampled_arch.append(op_index.item())
        op_log_prob = op_distribution.log_prob(op_index)
        log_probs.append(op_log_prob.view(-1))
        op_entropy = op_distribution.entropy()
        entropys.append(op_entropy.view(-1))
        # obtain the input embedding for the next step
        inputs = self.w_embd(op_index)
    return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch
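The summed log-probabilities and entropies returned here are exactly what a REINFORCE-style trainer consumes. A hedged sketch of such an update; controller, reward, baseline, entropy_weight, and optimizer are assumed names, not part of the snippet:

# Hypothetical REINFORCE update driven by the controller's outputs.
log_prob_sum, entropy_sum, arch = controller()    # sample an architecture
advantage = reward - baseline                     # scalar advantage signal
loss = -advantage * log_prob_sum - entropy_weight * entropy_sum
optimizer.zero_grad()
loss.backward()
optimizer.step()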
Example 4: forward
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def forward(self, x, a=None, old_logits=None):
    logits = self.logits(x)
    policy = Categorical(logits=logits)
    pi = policy.sample()
    logp_pi = policy.log_prob(pi).squeeze()
    if a is not None:
        logp = policy.log_prob(a).squeeze()
    else:
        logp = None
    if old_logits is not None:
        old_policy = Categorical(logits=old_logits)
        # kl_divergence is torch.distributions.kl.kl_divergence
        d_kl = kl_divergence(old_policy, policy).mean()
    else:
        d_kl = None
    # current logits are returned so the caller can feed them back as old_logits
    info = {"old_logits": logits.detach().numpy()}
    return pi, logp, logp_pi, info, d_kl
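The kl_divergence used above is the registered KL helper shipped with torch.distributions; a self-contained check with made-up logits:

import torch
from torch.distributions.categorical import Categorical
from torch.distributions.kl import kl_divergence

p = Categorical(logits=torch.tensor([0.1, 1.0, 0.5]))
q = Categorical(logits=torch.tensor([0.2, 0.8, 0.9]))
print(kl_divergence(p, q))   # KL(p || q), a non-negative scalar tensor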
Example 5: forward
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def forward(self, obs, memory):
    x = obs.image.transpose(1, 3).transpose(2, 3)
    x = self.image_conv(x)
    x = x.reshape(x.shape[0], -1)
    if self.use_memory:
        hidden = (memory[:, :self.semi_memory_size], memory[:, self.semi_memory_size:])
        hidden = self.memory_rnn(x, hidden)
        embedding = hidden[0]
        memory = torch.cat(hidden, dim=1)
    else:
        embedding = x
    if self.use_text:
        embed_text = self._get_embed_text(obs.text)
        embedding = torch.cat((embedding, embed_text), dim=1)
    x = self.actor(embedding)
    # Categorical normalizes its logits internally, so the log_softmax output
    # is accepted unchanged as already-normalized log-probabilities.
    dist = Categorical(logits=F.log_softmax(x, dim=1))
    x = self.critic(embedding)
    value = x.squeeze(1)
    return dist, value, memory
Example 6: __init__
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def __init__(self, scores, mask=None):
    self.mask = mask
    if mask is None:
        # TorchCategorical is torch.distributions.categorical.Categorical under an alias
        self.cat_distr = TorchCategorical(F.softmax(scores, dim=-1))
        self.n = scores.shape[0]
        self.log_n = math.log(self.n)
    else:
        self.n = self.mask.sum(dim=-1)
        self.log_n = (self.n + 1e-17).log()
        # masked_softmax is a helper on this class; see the sketch after this example
        self.cat_distr = TorchCategorical(Categorical.masked_softmax(scores, self.mask))
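masked_softmax is not part of torch; it is a helper defined on the surrounding class and is not shown in the snippet. A common implementation looks roughly like this (an assumption, not the repository's exact code):

import torch.nn.functional as F

def masked_softmax(scores, mask):
    """Softmax over the last dim, restricted to positions where mask == 1."""
    probs = F.softmax(scores, dim=-1) * mask                   # zero out invalid entries
    return probs / (probs.sum(dim=-1, keepdim=True) + 1e-13)   # renormalize the rest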
Example 7: _distribution
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def _distribution(self, obs):
    logits = self.logits_net(obs)
    return Categorical(logits=logits)
Example 8: get_action
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def get_action(self, x, action=None):
    logits = self.actor(x)
    probs = Categorical(logits=logits)
    if action is None:
        action = probs.sample()
    return action, probs.log_prob(action), probs.entropy()
Example 9: get_action
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def get_action(self, x, action=None):
    logits = self.actor(self.forward(x))
    probs = Categorical(logits=logits)
    if action is None:
        action = probs.sample()
    return action, probs.log_prob(action), probs.entropy()
Example 10: get_action
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def get_action(self, x, action=None):
    logits = self.forward(x)
    probs = Categorical(logits=logits)
    if action is None:
        action = probs.sample()
    return action, probs.log_prob(action), probs.entropy()
Example 11: get_action
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def get_action(self, x, action=None):
    logits = self.actor(self.forward(x))
    # one Categorical per sub-action of the MultiDiscrete action space
    split_logits = torch.split(logits, envs.action_space.nvec.tolist(), dim=1)
    multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
    if action is None:
        action = torch.stack([categorical.sample() for categorical in multi_categoricals])
    logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
    entropy = torch.stack([categorical.entropy() for categorical in multi_categoricals])
    return action, logprob.sum(0), entropy.sum(0)
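This pattern targets a gym MultiDiscrete action space, where envs.action_space.nvec lists the number of choices per sub-action. A toy standalone version with a made-up nvec of [3, 2] and random logits:

import torch
from torch.distributions.categorical import Categorical

nvec = [3, 2]                                  # hypothetical sub-action sizes
logits = torch.randn(4, sum(nvec))             # batch of 4, flat logits
split_logits = torch.split(logits, nvec, dim=1)
dists = [Categorical(logits=l) for l in split_logits]
action = torch.stack([d.sample() for d in dists])               # shape (2, 4)
logprob = torch.stack([d.log_prob(a) for a, d in zip(action, dists)]).sum(0)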
Example 12: get_action
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def get_action(self, x):
    action_probs, action_logps = self.forward(x)
    dist = Categorical(probs=action_probs)
    return dist.sample(), action_probs, action_logps, dist.entropy().sum()
Example 13: predict_action
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def predict_action(self, img, memory):
    batch_size = img.size(0)
    # x = img.view(batch_size, -1)
    x = self.encoder(img)
    memory = self.rnn(x, memory)
    action_probs = self.action_probs(memory)
    # despite the name, action_probs are treated as unnormalized logits here
    dist = Categorical(logits=action_probs)
    return dist, memory
Example 14: discrete_policy
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def discrete_policy(self, obs):
    """
    Computes a discrete (categorical) distribution over actions given the agent's observations
    :param obs: agent's observations
    :return: policy, a distribution over actions built from the observations
    """
    logits = self.actor(obs)
    value = self.critic(obs)
    self.logits = logits.to(torch.device("cpu"))
    self.value = value.to(torch.device("cpu"))
    self.action_distribution = Categorical(logits=self.logits)
    return self.action_distribution
Example 15: discrete_policy
# Required import: from torch.distributions import categorical [as alias]
# Or: from torch.distributions.categorical import Categorical [as alias]
def discrete_policy(self, obs):
    """
    Calculates a discrete/categorical distribution over actions given observations
    :param obs: Agent's observation
    :return: policy, a distribution over actions for the given observation
    """
    logits = self.actor(obs)
    value = self.critic(obs)
    self.logits = logits.to(torch.device("cpu"))
    self.value = value.to(torch.device("cpu"))
    self.action_distribution = Categorical(logits=self.logits)
    return self.action_distribution
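Finally, a short hedged sketch of how the distribution returned by discrete_policy is typically consumed in a policy-gradient loop; agent, obs, and advantage are assumed names, not part of the snippet:

dist = agent.discrete_policy(obs)      # Categorical over actions
action = dist.sample()                 # action index to execute in the env
logp = dist.log_prob(action)           # kept for the policy-gradient loss
# e.g. loss = -(logp * advantage).mean() for a vanilla policy gradient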