本文整理汇总了Python中torch.distributions.normal.Normal类的典型用法代码示例。如果您正苦于以下问题：Python normal.Normal类的具体用法？Python normal.Normal怎么用？Python normal.Normal使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torch.distributions.normal的用法示例。
在下文中一共展示了normal.Normal类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: predict_sentence
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def predict_sentence(self, sentence_input):
    """Compute sentence-score predictions.

    The raw output of ``self.sentence_pred`` (squeezed) is stored under
    ``const.SENTENCE_SCORES``. When ``self.sentence_sigma`` is set, that raw
    output is treated as the mean ``mu`` of a Gaussian truncated to [0, 1]:
    ``mu`` is kept under ``'SENT_MU'``, the predicted ``sigma`` under
    ``const.SENT_SIGMA``, and ``const.SENTENCE_SCORES`` is overwritten with
    the mean of the truncated distribution.

    Args:
        sentence_input: input fed to ``self.sentence_pred`` and, if set, to
            ``self.sentence_sigma``.

    Returns:
        OrderedDict holding the entries described above.
    """
    outputs = OrderedDict()
    sentence_scores = self.sentence_pred(sentence_input).squeeze()
    outputs[const.SENTENCE_SCORES] = sentence_scores
    if self.sentence_sigma:
        # Predict truncated Gaussian on [0, 1]
        sigma = self.sentence_sigma(sentence_input).squeeze()
        outputs[const.SENT_SIGMA] = sigma
        outputs['SENT_MU'] = outputs[const.SENTENCE_SCORES]
        # Detached copy: no gradient flows into mu through the
        # truncated-mean correction below.
        mean = outputs['SENT_MU'].clone().detach()
        normal = Normal(mean, sigma)
        # With argument validation enabled (PyTorch's default), cdf/log_prob
        # reject plain Python numbers, so evaluate at tensor bounds. The 0-dim
        # tensors broadcast against the batch shape of `mean`.
        lower = mean.new_zeros(())
        upper = mean.new_ones(())
        # Probability mass of N(mean, sigma) inside [0, 1]; used to renormalize.
        partition_function = (normal.cdf(upper) - normal.cdf(lower)).detach()
        # Mean of a Gaussian truncated to [a, b]:
        #   mu + sigma**2 * (pdf(a) - pdf(b)) / Z
        # where pdf is the density of N(mu, sigma) itself.
        pdf_lower = normal.log_prob(lower).exp()
        pdf_upper = normal.log_prob(upper).exp()
        outputs[const.SENTENCE_SCORES] = mean + (
            sigma ** 2 * (pdf_lower - pdf_upper) / partition_function
        )
    return outputs
示例2: __init__
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def __init__(self, vol_size, enc_nf, dec_nf, full_size=True):
    """
    Instantiate the 2018 model.

    :param vol_size: volume size of the atlas; its length gives the number
        of spatial dimensions and selects the ConvNd class used below
    :param enc_nf: number of feature maps for each encoding stage
    :param dec_nf: number of feature maps for each decoding stage
    :param full_size: boolean forwarded to unet_core (presumably enables the
        full set of decoding layers -- confirm in unet_core)
    """
    super(cvpr2018_net, self).__init__()

    ndims = len(vol_size)
    self.unet_model = unet_core(ndims, enc_nf, dec_nf, full_size)

    # A single convolution turns the last decoder features into the flow
    # field: one output channel per spatial dimension.
    ConvNd = getattr(nn, 'Conv%dd' % ndims)
    flow_conv = ConvNd(dec_nf[-1], ndims, kernel_size=3, padding=1)
    self.flow = flow_conv

    # Re-initialise the flow layer so its initial output is close to zero:
    # Gaussian weights with std 1e-5 and an all-zero bias.
    # (The original author noted this may not be necessary.)
    init_dist = Normal(0, 1e-5)
    flow_conv.weight = nn.Parameter(init_dist.sample(flow_conv.weight.shape))
    flow_conv.bias = nn.Parameter(torch.zeros(flow_conv.bias.shape))

    self.spatial_transform = SpatialTransformer(vol_size)
示例3: _log_prob
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def _log_prob(self, r, scale_log):
"""
Compute log probability from normal distribution the same way as
torch.distributions.normal.Normal, which is:
```
-((value - loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
```
In the context of this class, `value = loc + r * scale`. Therefore, this
function only takes `r` & `scale`; it can be reduced to below.
The primary reason we don't use Normal class is that it currently
cannot be exported through ONNX.
"""
return -(r ** 2) / 2 - scale_log - self.const
示例4: sample_reward_next_state_terminal
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def sample_reward_next_state_terminal(
    self, state: rlt.FeatureData, action: rlt.FeatureData, mem_net: MemoryNetwork
):
    """Sample one step of dynamics from the provided world model.

    Runs ``mem_net`` on ``(state, action)``, draws a mixture component from
    ``exp(logpi)``, and samples the next state from that component's
    Gaussian. Only index ``[0, 0]`` of the model outputs is used (presumably
    callers pass a single transition -- confirm).

    Returns:
        (reward, next_state, not_terminal, not_terminal_prob). When
        ``self.terminal_effective`` is false the transition is always
        reported as non-terminal: ``not_terminal`` is 1 and
        ``not_terminal_prob`` is the float 1.0. Otherwise
        ``not_terminal_prob`` is a sigmoid tensor and ``not_terminal`` a
        Bernoulli draw from it.
    """
    wm_output = mem_net(state, action)

    n_components = wm_output.logpi.shape[2]
    component_weights = torch.exp(wm_output.logpi.view(n_components))
    component = Categorical(probs=component_weights).sample().long().item()

    loc = wm_output.mus[0, 0, component]
    scale = wm_output.sigmas[0, 0, component]
    next_state = Normal(loc, scale).sample()

    reward = wm_output.reward[0, 0]

    if not self.terminal_effective:
        return reward, next_state, 1, 1.0

    not_terminal_prob = torch.sigmoid(wm_output.not_terminal[0, 0])
    not_terminal = Bernoulli(not_terminal_prob).sample().long().item()
    return reward, next_state, not_terminal, not_terminal_prob
示例5: synthesize
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def synthesize(model):
    """Generate and save audio for the first ``args.num_samples`` batches.

    For each of those batches from ``synth_loader``:
      * latent noise ``z`` is drawn from a standard normal shaped like the
        audio batch ``x`` and scaled by ``args.temp``;
      * ``model.reverse(z, c)`` is run under ``torch.no_grad()``;
      * the elapsed wall-clock time is printed;
      * the result is written as a 22.05 kHz WAV file under
        ``args.sample_path/args.model_name``.

    Relies on the module-level globals ``synth_loader``, ``args``, ``device``
    and ``global_step``.
    """
    global global_step
    for batch_idx, (x, c) in enumerate(synth_loader):
        # Stop once enough samples are written instead of walking (and
        # loading) the remainder of the loader for nothing.
        if batch_idx >= args.num_samples:
            break
        x, c = x.to(device), c.to(device)
        q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size()))
        z = q_0.sample() * args.temp
        # CUDA work is asynchronous, so synchronize around the timed region
        # to measure the real cost. Skipped for CPU tensors, where it is
        # unnecessary and torch.cuda.synchronize() can fail without CUDA.
        if x.is_cuda:
            torch.cuda.synchronize()
        start_time = time.time()
        with torch.no_grad():
            y_gen = model.reverse(z, c).squeeze()
        if x.is_cuda:
            torch.cuda.synchronize()
        print('{} seconds'.format(time.time() - start_time))
        wav = y_gen.to(torch.device("cpu")).data.numpy()
        wav_name = '{}/{}/generate_{}_{}_{}.wav'.format(args.sample_path, args.model_name,
                                                        global_step, batch_idx, args.temp)
        # NOTE(review): librosa.output.write_wav was removed in librosa 0.8;
        # this call needs librosa < 0.8 or a different WAV writer.
        librosa.output.write_wav(wav_name, wav, sr=22050)
        print('{} Saved!'.format(wav_name))
示例6: synthesize
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def synthesize(model):
    """Generate and save audio for the first batch of ``synth_loader``.

    Steps:
      * put ``model`` in eval mode;
      * sample latent noise ``z`` from a standard normal shaped like the
        audio batch ``x``;
      * run ``reverse(z, c)`` under ``torch.no_grad()``;
      * print the elapsed wall-clock time;
      * write the result as a 22.05 kHz WAV file.

    ``model`` may be wrapped (anything exposing the real network as
    ``.module``, such as ``nn.DataParallel``) or unwrapped. ``reverse`` is
    called on the inner module when one exists.

    Relies on the module-level globals ``synth_loader``, ``args``, ``device``
    and ``global_step``.
    """
    global global_step
    model.eval()
    for batch_idx, (x, c) in enumerate(synth_loader):
        x, c = x.to(device), c.to(device)
        q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size()))
        z = q_0.sample()
        # Wrappers like DataParallel do not expose custom methods such as
        # `reverse`, so use the wrapped module when there is one.
        net = getattr(model, 'module', model)
        start_time = time.time()
        with torch.no_grad():
            y_gen = net.reverse(z, c).squeeze()
        wav = y_gen.to(torch.device("cpu")).data.numpy()
        wav_name = '{}/{}/generate_{}_{}.wav'.format(args.sample_path, args.model_name, global_step, batch_idx)
        print('{} seconds'.format(time.time() - start_time))
        # NOTE(review): librosa.output.write_wav was removed in librosa 0.8.
        librosa.output.write_wav(wav_name, wav, sr=22050)
        print('{} Saved!'.format(wav_name))
        del x, c, z, q_0, y_gen, wav
        # Only the first batch is synthesized; don't walk the rest of the loader.
        break
示例7: synthesize
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def synthesize(model):
    """Generate and save audio for the first batch of ``synth_loader``.

    Steps:
      * put ``model`` in eval mode;
      * sample latent noise ``z`` from a standard normal shaped like the
        audio batch ``x``;
      * run ``reverse(z, c)`` under ``torch.no_grad()``;
      * print the elapsed wall-clock time;
      * write the result as a 22.05 kHz WAV file.

    With ``args.num_gpu == 1`` the model's own ``reverse`` is called;
    otherwise the model is expected to be wrapped and
    ``model.module.reverse`` is used.

    Relies on the module-level globals ``synth_loader``, ``args``, ``device``
    and ``global_step``.
    """
    global global_step
    model.eval()
    for batch_idx, (x, c) in enumerate(synth_loader):
        x, c = x.to(device), c.to(device)
        q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size()))
        z = q_0.sample()
        start_time = time.time()
        with torch.no_grad():
            if args.num_gpu == 1:
                y_gen = model.reverse(z, c).squeeze()
            else:
                y_gen = model.module.reverse(z, c).squeeze()
        wav = y_gen.to(torch.device("cpu")).data.numpy()
        wav_name = '{}/{}/generate_{}_{}.wav'.format(args.sample_path, args.model_name, global_step, batch_idx)
        print('{} seconds'.format(time.time() - start_time))
        # NOTE(review): librosa.output.write_wav was removed in librosa 0.8.
        librosa.output.write_wav(wav_name, wav, sr=22050)
        print('{} Saved!'.format(wav_name))
        del x, c, z, q_0, y_gen, wav
        # Only the first batch is synthesized; don't walk the rest of the loader.
        break
示例8: forward
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def forward(self, x, a=None, old_log_std=None, old_mu=None):
    """Run the Gaussian policy on observations ``x``.

    The policy is ``Normal(self.mu(x), exp(self.log_std))``. ``self.log_std``
    is used directly as a tensor, so the std does not depend on ``x``.

    Args:
        x: observations passed to ``self.mu``.
        a: optional actions whose log-likelihood under the current policy is
            returned as ``logp``.
        old_log_std, old_mu: parameters of a previous policy. When both are
            given, the mean KL(old || current) is returned as ``d_kl``.

    Returns:
        (pi, logp, logp_pi, info, d_kl), where:
          * ``pi`` is a sampled action;
          * ``logp`` is the log-prob of ``a``, or None;
          * ``logp_pi`` is the log-prob of ``pi``;
          * ``info`` holds the current mu / log_std as NumPy arrays under
            "old_mu" / "old_log_std" (presumably fed back later as the old
            policy);
          * ``d_kl`` is the KL above, or None.
        Log-probs are summed over dim 1.

    Raises:
        ValueError: if only one of ``old_mu`` / ``old_log_std`` is given.
    """
    mu = self.mu(x)
    policy = Normal(mu, self.log_std.exp())
    pi = policy.sample()
    logp_pi = policy.log_prob(pi).sum(dim=1)
    logp = policy.log_prob(a).sum(dim=1) if a is not None else None

    have_old_mu = old_mu is not None
    have_old_log_std = old_log_std is not None
    if have_old_mu != have_old_log_std:
        # Building the old policy needs both parameters; a half-specified
        # pair would otherwise fail later on the missing None.
        raise ValueError("old_mu and old_log_std must be provided together")
    if have_old_mu:
        old_policy = Normal(old_mu, old_log_std.exp())
        d_kl = kl_divergence(old_policy, policy).mean()
    else:
        d_kl = None

    # .cpu() lets the NumPy conversion work when the tensors live on a GPU.
    info = {
        "old_mu": np.squeeze(mu.detach().cpu().numpy()),
        "old_log_std": self.log_std.detach().cpu().numpy(),
    }
    return pi, logp, logp_pi, info, d_kl
示例9: forward
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def forward(self, x):
    """Sample a squashed, rescaled action from a state-dependent Gaussian.

    Returns ``(pi_scaled, mu_scaled, logp_pi)``:
      * the first two are a reparameterized sample and the Gaussian mean,
        both passed through ``self._apply_squashing_func`` and multiplied by
        ``self.action_scale``;
      * the third is the sample's log-probability summed over dim 1, as
        returned by the squashing helper.
    """
    features = self.net(x)

    mean = self.mu(features)
    if self.output_activation:
        mean = self.output_activation(mean)

    # Affine map from [-1, 1] onto [LOG_STD_MIN, LOG_STD_MAX]
    # (assumes the log_std head ends in a tanh -- confirm).
    raw_log_std = self.log_std(features)
    log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (raw_log_std + 1)

    dist = Normal(mean, torch.exp(log_std))
    # Must be rsample() (reparameterized), not sample(): gradients have to
    # flow through the drawn action.
    sample = dist.rsample()
    logp = torch.sum(dist.log_prob(sample), dim=1)

    mean, sample, logp = self._apply_squashing_func(mean, sample, logp)

    # Scale both the sample and the mean by self.action_scale.
    return sample * self.action_scale, mean * self.action_scale, logp
示例10: _distribution
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def _distribution(self, obs):
mu = self.mu_net(obs)
std = torch.exp(self.log_std)
return Normal(mu, std)
示例11: _log_prob_from_distribution
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def _log_prob_from_distribution(self, pi, act):
return pi.log_prob(act).sum(axis=-1) # Last axis sum needed for Torch Normal distribution
示例12: forward
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def forward(self, obs, deterministic=False, with_logprob=True):
    """Squashed-Gaussian actor: sample an action and optionally its log-prob.

    Args:
        obs: observation(s) passed to ``self.net``.
        deterministic: if True, use the Gaussian mean instead of sampling
            (only used for evaluating the policy at test time).
        with_logprob: if True, also compute the log-probability of the
            returned action; otherwise ``logp_pi`` is None.

    Returns:
        (pi_action, logp_pi):
          * ``pi_action`` is the tanh-squashed action multiplied by
            ``self.act_limit``;
          * ``logp_pi`` is its log-probability summed over the last axis,
            or None.
    """
    net_out = self.net(obs)
    mu = self.mu_layer(net_out)
    log_std = self.log_std_layer(net_out)
    log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
    std = torch.exp(log_std)

    # Pre-squash distribution and sample
    pi_distribution = Normal(mu, std)
    if deterministic:
        # Only used for evaluating policy at test time.
        pi_action = mu
    else:
        # rsample() keeps the sample differentiable w.r.t. mu and std.
        pi_action = pi_distribution.rsample()

    if with_logprob:
        # Gaussian log-prob, then the change-of-variables correction for tanh:
        #   log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)),
        # a numerically stable form of Eq. 21 in the SAC paper
        # (arXiv 1801.01290, appendix C).
        # Both terms are summed over the LAST axis. Summing the correction
        # over axis=1 raised for a single, unbatched (1-D) action and was
        # inconsistent with the first sum.
        logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
        logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(axis=-1)
    else:
        logp_pi = None

    pi_action = torch.tanh(pi_action)
    pi_action = self.act_limit * pi_action
    return pi_action, logp_pi
示例13: sentence_loss
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def sentence_loss(self, model_out, batch):
    """Compute the sentence-score loss.

    Without ``self.sentence_sigma`` this is ``self.mse_loss`` between the
    predicted and gold sentence scores. Otherwise it is the summed negative
    log-likelihood of the gold scores under N(mu, sigma) truncated to
    [0, 1]. ``mu`` and ``sigma`` are read from ``model_out['SENT_MU']`` and
    ``model_out[const.SENT_SIGMA]``.
    """
    sentence_pred = model_out[const.SENTENCE_SCORES]
    sentence_scores = batch.sentence_scores
    if not self.sentence_sigma:
        return self.mse_loss(sentence_pred, sentence_scores)

    sigma = model_out[const.SENT_SIGMA]
    mean = model_out['SENT_MU']
    # Compute log-likelihood of x given mu, sigma
    normal = Normal(mean, sigma)
    # With argument validation enabled (PyTorch's default), cdf rejects plain
    # Python numbers, so use tensor bounds. The 0-dim tensors broadcast
    # against the batch shape of `mean`.
    lower = mean.new_zeros(())
    upper = mean.new_ones(())
    # Renormalize on [0, 1] for the truncated Gaussian: Z = P(0 <= X <= 1).
    # NOTE(review): Z is detached, so the gradient ignores the truncation
    # term and matches an untruncated Gaussian NLL -- confirm this is intended.
    partition_function = (normal.cdf(upper) - normal.cdf(lower)).detach()
    # -log(pdf(x) / Z) = log Z - log pdf(x)
    nll = partition_function.log() - normal.log_prob(sentence_scores)
    return nll.sum()
示例14: gmm_loss
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def gmm_loss(batch, mus, sigmas, logpi, reduce=True): # pylint: disable=too-many-arguments
    """ Computes the gmm loss.

    Compute minus the log probability of batch under the GMM model described
    by mus, sigmas, pi. Throughout:
      * bs1, bs2, ... are the sizes of the batch dimensions (several batch
        dimensions are useful when you have both a batch axis and a time
        step axis);
      * gs is the number of mixtures;
      * fs is the number of features.

    :args batch: (bs1, bs2, *, fs) torch tensor
    :args mus: (bs1, bs2, *, gs, fs) torch tensor
    :args sigmas: (bs1, bs2, *, gs, fs) torch tensor
    :args logpi: (bs1, bs2, *, gs) torch tensor
    :args reduce: if not reduce, the mean in the following formula is omitted

    :returns:
    loss(batch) = - mean_{i1=0..bs1, i2=0..bs2, ...} log(
        sum_{k=1..gs} pi[i1, i2, ..., k] * N(
            batch[i1, i2, ..., :] | mus[i1, i2, ..., k, :], sigmas[i1, i2, ..., k, :]))
    Without reduce, the per-element values have shape (bs1, bs2, *).

    NOTE: The loss is not reduced along the feature dimension (i.e. it should scale ~linearly
    with fs).
    """
    # Insert a mixture axis so batch broadcasts against every component.
    batch = batch.unsqueeze(-2)
    normal_dist = Normal(mus, sigmas)
    # Per-component log density. The features are independent, so sum
    # their log-probs, then add the component's log weight.
    g_log_probs = logpi + torch.sum(normal_dist.log_prob(batch), dim=-1)
    # Numerically stable log-sum-exp over the mixture axis. The previous
    # max-subtract + .squeeze() version also dropped size-1 batch dims and
    # broadcast to the wrong shape. It also produced NaN when every
    # component was -inf.
    log_prob = torch.logsumexp(g_log_probs, dim=-1)
    if reduce:
        return - torch.mean(log_prob)
    return - log_prob
示例15: get_action
# 需要导入模块: from torch.distributions import normal [as 别名]
# 或者: from torch.distributions.normal import Normal [as 别名]
def get_action(self, x):
    """Sample a tanh-squashed action for ``x``.

    Returns ``(action, log_prob, mean)``:
      * ``action``: a reparameterized Gaussian sample, squashed by tanh and
        mapped through ``* self.action_scale + self.action_bias``;
      * ``log_prob``: its log-density including the squash/affine
        change-of-variables correction, summed over dim 1 (dim kept);
      * ``mean``: the Gaussian mean sent through the same squash/affine map.
    """
    mu, log_std = self.forward(x)
    gaussian = Normal(mu, log_std.exp())

    # Reparameterization trick (mean + std * N(0, 1)) keeps gradients.
    raw = gaussian.rsample()
    squashed = torch.tanh(raw)
    action = squashed * self.action_scale + self.action_bias

    # Enforce the action bound: subtract log|d action / d raw|. The 1e-6
    # keeps the log finite when tanh saturates at +/-1.
    jacobian = self.action_scale * (1 - squashed.pow(2))
    log_prob = gaussian.log_prob(raw) - torch.log(jacobian + 1e-6)
    log_prob = log_prob.sum(1, keepdim=True)

    squashed_mean = torch.tanh(mu) * self.action_scale + self.action_bias
    return action, log_prob, squashed_mean