This article collects typical usage examples of the Python method torch.distributions.kl.kl_divergence. If you are wondering what exactly kl.kl_divergence does, how to call it, or what real code that uses it looks like, the curated examples below may help. You can also explore the module torch.distributions.kl in which the method is defined.
The following 11 code examples of kl.kl_divergence are shown, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
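Before diving into the examples, here is a minimal, self-contained sketch (not taken from any of the projects below) of the basic call: kl_divergence dispatches on the pair of distribution types and returns the closed-form KL when a rule is registered for that pair, which can be sanity-checked against a Monte Carlo estimate.

import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

p = Normal(loc=torch.zeros(3), scale=torch.ones(3))
q = Normal(loc=torch.ones(3), scale=2 * torch.ones(3))

# Closed-form KL(p || q), shape (3,): a rule is registered for (Normal, Normal).
analytic_kl = kl_divergence(p, q)

# Monte Carlo estimate of the same quantity from samples of p.
z = p.sample((10000,))
mc_kl = (p.log_prob(z) - q.log_prob(z)).mean(dim=0)

print(analytic_kl, mc_kl)  # the two should agree up to sampling noise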
Example 1: kl_divergence
# Required import: from torch.distributions import kl [as alias]
# Or: from torch.distributions.kl import kl_divergence [as alias]
def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
    """
    Calculate the KL divergence KL(Q||P) between the posterior Q and the prior P.

    analytic: compute the KL analytically, or approximate it by sampling from the posterior
    calculate_posterior: if sampling is used to approximate the KL, sample here; otherwise supply a sample via z_posterior
    """
    if analytic:
        # Requires analytic KL support for the latent distributions in torch,
        # see: https://github.com/pytorch/pytorch/issues/13545
        kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)
    else:
        if calculate_posterior:
            z_posterior = self.posterior_latent_space.rsample()
        log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
        log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
        kl_div = log_posterior_prob - log_prior_prob
    return kl_div
Example 2: elbo
# Required import: from torch.distributions import kl [as alias]
# Or: from torch.distributions.kl import kl_divergence [as alias]
def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False):
    """
    Calculate the evidence lower bound (ELBO) of the log-likelihood of P(Y|X).
    """
    # Per-element loss (the deprecated size_average=False, reduce=False resolve to reduction='none')
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    z_posterior = self.posterior_latent_space.rsample()

    self.kl = torch.mean(self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior))

    # Here we use the posterior sample drawn above
    self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, calculate_posterior=False, z_posterior=z_posterior)

    reconstruction_loss = criterion(input=self.reconstruction, target=segm)
    self.reconstruction_loss = torch.sum(reconstruction_loss)
    self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

    return -(self.reconstruction_loss + self.beta * self.kl)
Example 3: surrogate_loss
# Required import: from torch.distributions import kl [as alias]
# Or: from torch.distributions.kl import kl_divergence [as alias]
async def surrogate_loss(self, train_futures, valid_futures, old_pi=None):
    # `await` is used below, so this must be a coroutine (async def)
    first_order = (old_pi is not None) or self.first_order
    params = await self.adapt(train_futures,
                              first_order=first_order)

    with torch.set_grad_enabled(old_pi is None):
        valid_episodes = await valid_futures
        pi = self.policy(valid_episodes.observations, params=params)

        if old_pi is None:
            old_pi = detach_distribution(pi)

        log_ratio = (pi.log_prob(valid_episodes.actions)
                     - old_pi.log_prob(valid_episodes.actions))
        ratio = torch.exp(log_ratio)

        losses = -weighted_mean(ratio * valid_episodes.advantages,
                                lengths=valid_episodes.lengths)
        kls = weighted_mean(kl_divergence(pi, old_pi),
                            lengths=valid_episodes.lengths)

    return losses.mean(), kls.mean(), old_pi
Example 4: forward
# Required import: from torch.distributions import kl [as alias]
# Or: from torch.distributions.kl import kl_divergence [as alias]
def forward(self, x, a=None, old_logits=None):
    logits = self.logits(x)
    policy = Categorical(logits=logits)
    pi = policy.sample()
    logp_pi = policy.log_prob(pi).squeeze()

    if a is not None:
        logp = policy.log_prob(a).squeeze()
    else:
        logp = None

    if old_logits is not None:
        old_policy = Categorical(logits=old_logits)
        d_kl = kl_divergence(old_policy, policy).mean()
    else:
        d_kl = None

    info = {"old_logits": logits.detach().numpy()}
    return pi, logp, logp_pi, info, d_kl
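As a standalone illustration of the d_kl term above (not part of the original repository), the KL between an old and a new categorical policy can be computed directly from logits:

import torch
from torch.distributions import Categorical
from torch.distributions.kl import kl_divergence

old_logits = torch.tensor([[1.0, 0.5, -0.5, 0.0]])
new_logits = torch.tensor([[0.8, 0.7, -0.4, 0.1]])

# Mean KL(old || new) over the batch, e.g. as a trust-region diagnostic.
d_kl = kl_divergence(Categorical(logits=old_logits),
                     Categorical(logits=new_logits)).mean()
print(d_kl)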
Example 5: kl_divergence
# Required import: from torch.distributions import kl [as alias]
# Or: from torch.distributions.kl import kl_divergence [as alias]
def kl_divergence(self, episodes, old_pis=None):
    kls = []
    if old_pis is None:
        old_pis = [None] * len(episodes)

    for (train_episodes, valid_episodes), old_pi in zip(episodes, old_pis):
        # this is the inner-loop update
        self.policy.reset_context()
        params, _ = self.adapt(train_episodes)
        pi = self.policy(valid_episodes.observations, params=params)

        if old_pi is None:
            old_pi = detach_distribution(pi)

        mask = valid_episodes.mask
        if valid_episodes.actions.dim() > 2:
            mask = mask.unsqueeze(2)

        kl = weighted_mean(kl_divergence(pi, old_pi), dim=0, weights=mask)
        kls.append(kl)

    return torch.mean(torch.stack(kls, dim=0))
Example 6: kl_divergence
# Required import: from torch.distributions import kl [as alias]
# Or: from torch.distributions.kl import kl_divergence [as alias]
def kl_divergence(self):
    variational_dist = self.variational_distribution
    prior_dist = self.prior_distribution

    mean_dist = Delta(variational_dist.mean)
    covar_dist = MultivariateNormal(
        torch.zeros_like(variational_dist.mean), variational_dist.lazy_covariance_matrix
    )
    return kl_divergence(mean_dist, prior_dist) + kl_divergence(covar_dist, prior_dist)
Example 7: meta_surrogate_loss
# Required import: from torch.distributions import kl [as alias]
# Or: from torch.distributions.kl import kl_divergence [as alias]
def meta_surrogate_loss(iteration_replays, iteration_policies, policy, baseline, tau, gamma, adapt_lr):
    mean_loss = 0.0
    mean_kl = 0.0
    for task_replays, old_policy in tqdm(zip(iteration_replays, iteration_policies),
                                         total=len(iteration_replays),
                                         desc='Surrogate Loss',
                                         leave=False):
        train_replays = task_replays[:-1]
        valid_episodes = task_replays[-1]
        new_policy = l2l.clone_module(policy)

        # Fast adapt
        for train_episodes in train_replays:
            new_policy = fast_adapt_a2c(new_policy, train_episodes, adapt_lr,
                                        baseline, gamma, tau, first_order=False)

        # Useful values
        states = valid_episodes.state()
        actions = valid_episodes.action()
        next_states = valid_episodes.next_state()
        rewards = valid_episodes.reward()
        dones = valid_episodes.done()

        # Compute KL
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
        kl = kl_divergence(new_densities, old_densities).mean()
        mean_kl += kl

        # Compute surrogate loss
        advantages = compute_advantages(baseline, tau, gamma, rewards, dones, states, next_states)
        advantages = ch.normalize(advantages).detach()
        old_log_probs = old_densities.log_prob(actions).mean(dim=1, keepdim=True).detach()
        new_log_probs = new_densities.log_prob(actions).mean(dim=1, keepdim=True)
        mean_loss += trpo.policy_loss(new_log_probs, old_log_probs, advantages)

    mean_kl /= len(iteration_replays)
    mean_loss /= len(iteration_replays)
    return mean_loss, mean_kl
Example 8: KL
# Required import: from torch.distributions import kl [as alias]
# Or: from torch.distributions.kl import kl_divergence [as alias]
def KL(normal_1, normal_2):
    kl = kl_divergence(normal_1, normal_2)
    kl = torch.mean(kl)
    return kl
Example 9: _kl_independent_independent
# Required import: from torch.distributions import kl [as alias]
# Or: from torch.distributions.kl import kl_divergence [as alias]
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)
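For kl_divergence to actually dispatch to a helper like this, the defining module presumably registers it with torch.distributions.kl.register_kl (the decorator itself is standard PyTorch API, and recent PyTorch releases already ship an equivalent built-in rule for Independent); _sum_rightmost here corresponds to the private helper in torch.distributions.utils. A hedged sketch of that registration:

from torch.distributions import Independent
from torch.distributions.kl import register_kl

# Make kl_divergence(Independent(...), Independent(...)) use the helper above.
register_kl(Independent, Independent)(_kl_independent_independent)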
Example 10: hessian_vector_product
# Required import: from torch.distributions import kl [as alias]
# Or: from torch.distributions.kl import kl_divergence [as alias]
def hessian_vector_product(self, episodes, damping=1e-2):
    """Hessian-vector product, based on Pearlmutter's method."""
    def _product(vector):
        kl = self.kl_divergence(episodes)
        grads = torch.autograd.grad(kl, self.policy.parameters(), create_graph=True)
        flat_grad_kl = parameters_to_vector(grads)

        grad_kl_v = torch.dot(flat_grad_kl, vector)
        grad2s = torch.autograd.grad(grad_kl_v, self.policy.parameters())
        flat_grad2_kl = parameters_to_vector(grad2s)

        return flat_grad2_kl + damping * vector
    return _product
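In TRPO-style code such as this, the returned closure is typically handed to a conjugate-gradient solver to compute the natural-gradient step direction. The sketch below is illustrative only: the conjugate_gradient helper and the variable names (agent, episodes, policy_gradient) are hypothetical, written just to show how the closure would be consumed.

import torch

def conjugate_gradient(hvp, b, n_steps=10, residual_tol=1e-10):
    # Iteratively solve H x = b, where H v is given implicitly by hvp(v).
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for _ in range(n_steps):
        Ap = hvp(p)
        alpha = rdotr / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

# Hypothetical usage:
# hvp = agent.hessian_vector_product(episodes)
# step_direction = conjugate_gradient(hvp, policy_gradient)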
Example 11: surrogate_loss
# Required import: from torch.distributions import kl [as alias]
# Or: from torch.distributions.kl import kl_divergence [as alias]
def surrogate_loss(self, episodes, old_pis=None):
    losses, kls, pis = [], [], []
    if old_pis is None:
        old_pis = [None] * len(episodes)

    for (train_episodes, valid_episodes), old_pi in zip(episodes, old_pis):
        # do the inner-loop update
        self.policy.reset_context()
        params, _ = self.adapt(train_episodes)

        with torch.set_grad_enabled(old_pi is None):
            # get the action distribution after the inner-loop update
            pi = self.policy(valid_episodes.observations, params=params)
            pis.append(detach_distribution(pi))

            if old_pi is None:
                old_pi = detach_distribution(pi)

            values = self.baseline(valid_episodes)
            advantages = valid_episodes.gae(values, tau=self.tau)
            advantages = weighted_normalize(advantages, weights=valid_episodes.mask)

            log_ratio = (pi.log_prob(valid_episodes.actions)
                         - old_pi.log_prob(valid_episodes.actions))
            if log_ratio.dim() > 2:
                log_ratio = torch.sum(log_ratio, dim=2)
            ratio = torch.exp(log_ratio)

            loss = -weighted_mean(ratio * advantages, dim=0, weights=valid_episodes.mask)
            losses.append(loss)

            mask = valid_episodes.mask
            if valid_episodes.actions.dim() > 2:
                mask = mask.unsqueeze(2)
            kl = weighted_mean(kl_divergence(pi, old_pi), dim=0, weights=mask)
            kls.append(kl)

    return torch.mean(torch.stack(losses, dim=0)), torch.mean(torch.stack(kls, dim=0)), pis