本文整理匯總了Python中torch.distributions.categorical.Categorical類的典型用法代碼示例。如果您正苦於以下問題:Python Categorical類的具體用法?Python Categorical怎麼用?Python Categorical使用的例子?那麼, 這裡精選的類代碼示例或許可以為您提供幫助。
在下文中一共展示了Categorical類的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: sample_relax
def sample_relax(logits, surrogate):
    """Sample a class via Gumbel-max and score both relaxations with the surrogate.

    Reads ``B``, ``C`` and ``x`` from the enclosing module and calls the
    sibling helper ``sample_relax_given_b``.
    NOTE(review): presumably ``B``/``C`` are batch size / number of classes
    and ``x`` is the observed batch -- confirm against the caller.

    Args:
        logits: unnormalised class scores fed to ``Categorical``.
        surrogate: object whose ``net`` is applied to ``[z, x, logits]``
            concatenated along dim 1.

    Returns:
        Tuple ``(b, logprob, cz, cz_tilde)``: the argmax class indices, their
        log-probability reshaped to ``(B, 1)``, the surrogate output on ``z``
        and the surrogate output on the conditional sample ``z_tilde``.
    """
    categorical = Categorical(logits=logits)

    # Perturb the logits with Gumbel noise; the clamp keeps both logs finite.
    uniform = torch.rand(B, C).clamp(1e-10, 1.-1e-10).cuda()
    gumbel_noise = -torch.log(-torch.log(uniform))
    z = logits + gumbel_noise
    b = torch.argmax(z, dim=1)
    logprob = categorical.log_prob(b).view(B, 1)

    # The surrogate sees the logits as plain inputs (no gradient through them).
    frozen_logits = logits.detach()
    cz = surrogate.net(torch.cat([z, x, frozen_logits], dim=1))

    # Average the surrogate over conditional samples (currently a single one).
    n_conditional = 1
    tilde_scores = []
    for _ in range(n_conditional):
        z_tilde = sample_relax_given_b(logits, b)
        tilde_input = torch.cat([z_tilde, x, frozen_logits], dim=1)
        tilde_scores.append(surrogate.net(tilde_input))
    cz_tilde = torch.mean(torch.stack(tilde_scores), dim=0)

    return b, logprob, cz, cz_tilde
示例2: sample_relax
def sample_relax(logits):
    """Draw ``z``, the hard sample ``b`` and a conditional relaxation ``z_tilde``.

    Reads ``B`` and ``C`` from the enclosing module.
    NOTE(review): ``softmax(logits).repeat(B, 1)`` only lines up with the
    ``(B, C)`` noise if ``logits`` is a single row -- confirm with the caller.

    Args:
        logits: unnormalised class scores.

    Returns:
        Tuple ``(z, b, logprob, z_tilde)`` where ``logprob`` has shape
        ``(B, 1)`` and ``z_tilde`` holds ``z_tilde_b`` at the selected class.
    """
    lo, hi = 1e-12, 1.-1e-12

    # Gumbel-perturbed logits and their argmax.
    u = torch.rand(B, C).clamp(lo, hi)
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    b = torch.argmax(z, dim=1)
    logprob = Categorical(logits=logits).log_prob(b).view(B, 1)

    # Fresh noise for the selected class.  The original author noted that
    # reusing the gathered entries of u / z at index b "seems biased even
    # though it shouldn't be", hence the independent draw.
    v_k = torch.rand(B, 1).clamp(lo, hi)
    z_tilde_b = -torch.log(-torch.log(v_k))

    # Noise for the remaining classes, conditioned on the selected one.
    v = torch.rand(B, C).clamp(lo, hi)
    probs = torch.softmax(logits, dim=1).repeat(B, 1)
    z_tilde = -torch.log((- torch.log(v) / probs) - torch.log(v_k))

    # Overwrite the selected-class column of each row in place.
    z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)
    return z, b, logprob, z_tilde
示例3: relax_grad2
def relax_grad2(x, logits, b, surrogate, mixtureweights):
    """Return the gradient of the RELAX-style loss w.r.t. ``logits`` and ``q(b)``.

    Calls the file-level helpers ``myclamp``, ``sample_relax_given_b`` and
    ``logprob_undercomponent``; the noise tensor is moved to CUDA.

    Args:
        x: observations, concatenated with ``z`` and the logits for the
            surrogate input.
        logits: 2-D class scores; must be part of the autograd graph.
        b: class indices whose log-probability is evaluated.
        surrogate: object exposing ``net``.
        mixtureweights: tensor indexed by ``b`` to obtain the prior weight.

    Returns:
        Tuple ``(grad, pb)``: gradient of the loss w.r.t. ``logits`` (built
        with ``create_graph=True``) and ``exp(log q(b))`` of shape ``(B, 1)``.
    """
    batch_size, n_classes = logits.shape[0], logits.shape[1]
    categorical = Categorical(logits=logits)

    # Gumbel-perturbed logits.
    noise = myclamp(torch.rand(batch_size, n_classes).cuda())
    z = logits + (-torch.log(-torch.log(noise)))

    logq = categorical.log_prob(b).view(batch_size, 1)

    # Surrogate evaluated on the unconditional and the conditional relaxation.
    frozen_logits = logits.detach()
    cz = surrogate.net(torch.cat([z, x, frozen_logits], dim=1))
    z_tilde = sample_relax_given_b(logits, b)
    cz_tilde = surrogate.net(torch.cat([z_tilde, x, frozen_logits], dim=1))

    # f = log p(x, z) - log q(z)
    logpx_given_z = logprob_undercomponent(x, component=b)
    logpz = torch.log(mixtureweights[b]).view(batch_size, 1)
    f = logpx_given_z + logpz - logq

    score_weight = f.detach() - cz_tilde.detach()
    net_loss = - torch.mean(score_weight * logq - logq + cz - cz_tilde)

    grad = torch.autograd.grad(
        [net_loss], [logits], create_graph=True, retain_graph=True
    )[0]
    return grad, torch.exp(logq)
示例4: sample_relax_given_class
def sample_relax_given_class(logits, samp):
    """Draw two relaxed samples conditioned on the given class ``samp``.

    Reads ``B`` and ``C`` from the enclosing module.
    NOTE(review): presumably batch size and number of classes -- confirm.

    Both returned relaxations share the same selected-class value
    (``z_tilde_b``, derived from the first noise draw) and differ only in the
    noise used for the remaining classes.

    The previous version also computed ``logits + gumbels`` and re-gathered
    ``u_b`` a second time; the former was overwritten before use and the
    latter reproduced the same value, so both were dropped without changing
    the result or the random-number consumption.

    Args:
        logits: unnormalised class scores.
        samp: class indices to condition on (viewable as ``(B, 1)``).

    Returns:
        Tuple ``(z, z_tilde, logprob)`` with ``logprob`` of shape ``(B, 1)``.
    """
    cat = Categorical(logits=logits)
    b = samp
    index = b.view(B, 1)
    logprob = cat.log_prob(b).view(B, 1)
    probs = torch.softmax(logits, dim=1)

    # First noise draw: also supplies the value for the selected class.
    u = torch.rand(B, C).clamp(1e-8, 1.-1e-8)
    u_b = torch.gather(input=u, dim=1, index=index)
    z_tilde_b = -torch.log(-torch.log(u_b))
    log_u_b = torch.log(u_b)

    z = -torch.log((- torch.log(u) / probs) - log_u_b)
    z.scatter_(dim=1, index=index, src=z_tilde_b)

    # Second, independent draw for the non-selected classes.
    u = torch.rand(B, C).clamp(1e-8, 1.-1e-8)
    z_tilde = -torch.log((- torch.log(u) / probs) - log_u_b)
    z_tilde.scatter_(dim=1, index=index, src=z_tilde_b)

    return z, z_tilde, logprob
示例5: sample_true2
def sample_true2():
    """Draw a single point from the true mixture together with its cluster.

    Reads the module-level ``true_mixture_weights``.  The chosen cluster index
    is scaled by 10 to give the Gaussian mean; the standard deviation is 5.

    Returns:
        Tuple ``(samp, cluster)``: a one-element sample tensor and the sampled
        cluster index.
    """
    weights = torch.tensor(true_mixture_weights)
    cluster = Categorical(probs=weights).sample()

    loc = torch.tensor([cluster*10.]).float()
    scale = torch.tensor([5.0]).float()
    samp = Normal(loc, scale).sample()
    return samp, cluster
示例6: sample_gmm
def sample_gmm(batch_size, mixture_weights, *, device="cuda"):
    """Sample ``batch_size`` points from a 1-D Gaussian mixture.

    Component ``i`` has mean ``10 * i`` and standard deviation ``5``.

    Args:
        batch_size: number of points to draw.
        mixture_weights: tensor of component probabilities passed to
            ``Categorical(probs=...)``.
        device: device on which the means, standard deviations and the
            returned samples live.  Defaults to ``"cuda"``, matching the
            previously hard-coded ``.cuda()`` calls; pass ``"cpu"`` to run
            without a GPU.

    Returns:
        Tensor of shape ``(batch_size, 1)`` on ``device``.
    """
    cat = Categorical(probs=mixture_weights)
    cluster = cat.sample([batch_size])  # [B]

    mean = (cluster * 10.).float().to(device)
    std = torch.full((batch_size,), 5., device=device)

    samp = Normal(mean, std).sample()
    return samp.view(batch_size, 1)
示例7: OneHotCategorical
class OneHotCategorical(Distribution):
    r"""
    One-hot categorical distribution built on top of :class:`Categorical`.

    Samples are one-hot coded vectors of size probs.size(-1).

    See also: :func:`torch.distributions.Categorical`

    Example::

        >>> m = OneHotCategorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
         0
         0
         1
         0
        [torch.FloatTensor of size 4]

    Args:
        probs (Tensor or Variable): event probabilities
        logits (Tensor or Variable): forwarded unchanged to ``Categorical``
    """
    params = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None):
        # All probability bookkeeping is delegated to the wrapped Categorical.
        self._categorical = Categorical(probs, logits)
        param_size = self._categorical.probs.size()
        # Leading dims are the batch; the last dim indexes the categories.
        super(OneHotCategorical, self).__init__(param_size[:-1], param_size[-1:])

    def sample(self, sample_shape=torch.Size()):
        """Draw one-hot samples of shape ``sample_shape + batch + event``."""
        sample_shape = torch.Size(sample_shape)
        full_shape = self._extended_shape(sample_shape)
        one_hot = self._categorical.probs.new(full_shape).zero_()
        indices = self._categorical.sample(sample_shape)
        # Indices lack the trailing event dimension that scatter_ needs.
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        """Log-probability of one-hot ``value`` (the argmax picks the class)."""
        _, indices = value.max(-1)
        return self._categorical.log_prob(indices)

    def entropy(self):
        """Entropy of the underlying categorical distribution."""
        return self._categorical.entropy()

    def enumerate_support(self):
        """Return every one-hot vector, expanded over the batch shape."""
        probs = self._categorical.probs
        n = self.event_shape[0]
        if isinstance(probs, Variable):
            values = Variable(torch.eye(n, out=probs.data.new(n, n)))
        else:
            values = torch.eye(n, out=probs.new(n, n))
        # Insert singleton batch dims so the identity can be broadcast.
        singleton_batch = (1,) * len(self.batch_shape)
        values = values.view((n,) + singleton_batch + (n,))
        return values.expand((n,) + self.batch_shape + (n,))
示例8: test_gmm_loss
def test_gmm_loss(self):
    """ Test case 1 """
    # Synthetic data: 10000 points from a 3-component 2-D Gaussian mixture.
    n_samples = 10000
    means = torch.Tensor([[0., 0.],
                          [1., 1.],
                          [-1., 1.]])
    stds = torch.Tensor([[.03, .05],
                         [.02, .1],
                         [.1, .03]])
    pi = torch.Tensor([.2, .3, .5])

    indices = Categorical(pi).sample((n_samples,)).long()
    rands = torch.randn(n_samples, 2)
    samples = means[indices] + rands * stds[indices]

    class _model(nn.Module):
        """Free mixture parameters; ``forward`` ignores its inputs."""

        def __init__(self, gaussians):
            super().__init__()
            self.means = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
            self.pre_stds = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
            self.pi = nn.Parameter(torch.Tensor(1, gaussians).normal_())

        def forward(self, *inputs):
            # exp keeps the stds positive, softmax normalises the weights.
            return self.means, torch.exp(self.pre_stds), f.softmax(self.pi, dim=1)

    model = _model(3)
    optimizer = torch.optim.Adam(model.parameters())

    iterations = 100000
    log_step = iterations // 10
    pbar = tqdm(total=iterations)
    cum_loss = 0

    for step in range(1, iterations + 1):
        # Random mini-batch of 128 points, drawn with replacement.
        batch = samples[torch.LongTensor(128).random_(0, n_samples)]
        m, s, p = model.forward()
        loss = gmm_loss(batch, m, s, p)
        cum_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix_str("avg_loss={:10.6f}".format(cum_loss / step))
        pbar.update(1)

        # Dump the current parameters ten times over the whole run.
        if step % log_step == 0:
            print(m)
            print(s)
            print(p)
示例9: sample_true
def sample_true(batch_size):
    """Sample ``batch_size`` points from the true 1-D Gaussian mixture.

    Reads the module-level ``true_mixture_weights``.  Component ``i`` has mean
    ``10 * i`` and standard deviation ``5``.

    Args:
        batch_size: number of points to draw.

    Returns:
        Tensor of shape ``(batch_size, 1)``.
    """
    weights = torch.tensor(true_mixture_weights)
    cluster = Categorical(probs=weights).sample([batch_size])  # [B]

    mean = (cluster*10.).float()
    std = torch.ones([batch_size]) *5.

    return Normal(mean, std).sample().view(batch_size, 1)
示例10: reinforce_baseline
def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    """REINFORCE loss with a learned baseline ``surrogate.net(x)``.

    Calls the file-level helper ``logprob_undercomponent``.  The loop runs
    ``k`` times, but every per-sample entry of the result (and ``surr_loss``)
    is overwritten each time, so only the last sample's values are reported;
    only the gradients are accumulated across samples.

    Args:
        surrogate: object exposing ``net``; its output on ``x`` is the baseline.
        x: observations.
        logits: 2-D class scores; must be part of the autograd graph.
        mixtureweights: tensor indexed by the sampled cluster.
        k: number of samples drawn.
        get_grad: when true, also collect the gradient of the loss w.r.t.
            ``logits`` for every sample and report its mean / std.

    Returns:
        Dict with keys ``logq``, ``logpx_given_z``, ``logpz``, ``f``,
        ``net_loss``, ``surr_loss`` and, if ``get_grad``, ``grad_avg`` and
        ``grad_std``.
    """
    B = logits.shape[0]
    cat = Categorical(probs=torch.softmax(logits, dim=1))
    outputs = {}
    grads = []

    for _ in range(k):
        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(B, 1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        surr_pred = surrogate.net(x)

        # f = log p(x, z) - log q(z) - 1, shape [B, 1]
        f = logpx_given_z + logpz - logq - 1.
        net_loss = - torch.mean((f.detach() - surr_pred.detach()) * logq)

        outputs.update({
            'logq': logq,
            'logpx_given_z': logpx_given_z,
            'logpz': logpz,
            'f': f,
            'net_loss': net_loss,
        })

        # The baseline is trained on the squared, score-weighted residual.
        grad_logq = torch.autograd.grad(
            [torch.mean(logq)], [logits], create_graph=True, retain_graph=True
        )[0]
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq)**2)

        if get_grad:
            grads.append(torch.autograd.grad(
                [net_loss], [logits], create_graph=True, retain_graph=True
            )[0])

    if get_grad:
        grads = torch.stack(grads)
        outputs['grad_avg'] = torch.mean(torch.mean(grads, dim=0), dim=0)
        outputs['grad_std'] = torch.std(grads, dim=0)[0]

    outputs['surr_loss'] = surr_loss
    return outputs
示例11: sample_relax
def sample_relax(probs):
    """Sample ``z``, the hard class ``b`` and a conditional relaxation ``z_tilde``.

    Reads ``B`` and ``C`` from the enclosing module; all noise is moved to CUDA.
    NOTE(review): presumably batch size and number of classes -- confirm.

    Args:
        probs: class probabilities.

    Returns:
        Tuple ``(z, b, logprob, z_tilde, gumbels)`` where ``logprob`` has shape
        ``(B, 1)`` and ``gumbels`` is the noise that produced ``z``.
    """
    cat = Categorical(probs=probs)

    # Sample z
    u = torch.rand(B, C).cuda()
    u = u.clamp(1e-8, 1.-1e-8)
    gumbels = -torch.log(-torch.log(u))
    z = torch.log(probs) + gumbels
    b = torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B, 1)

    # Sample z_tilde
    u_b = torch.rand(B, 1).cuda()
    u_b = u_b.clamp(1e-8, 1.-1e-8)
    z_tilde_b = -torch.log(-torch.log(u_b))
    u = torch.rand(B, C).cuda()
    u = u.clamp(1e-8, 1.-1e-8)
    z_tilde = -torch.log((- torch.log(u) / probs) - torch.log(u_b))

    # Write z_tilde_b[i] into column b[i] of row i only.  The former
    # ``z_tilde[:, b] = z_tilde_b`` indexed the columns with the whole vector
    # b, so each row had *every* column that occurs anywhere in b overwritten.
    z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)
    return z, b, logprob, z_tilde, gumbels
示例12: reinforce
def reinforce(x, logits, mixtureweights, k=1):
    """Plain REINFORCE loss averaged over ``k`` samples.

    Calls the file-level helper ``logprob_undercomponent``.

    Args:
        x: observations.
        logits: 2-D class scores.
        mixtureweights: tensor indexed by the sampled cluster.
        k: number of samples; the loss is divided by it.

    Returns:
        Tuple ``(net_loss, f, logpx_given_z, logpz, logq)``; everything except
        ``net_loss`` comes from the last sample drawn.
    """
    B = logits.shape[0]
    cat = Categorical(probs=torch.softmax(logits, dim=1))

    net_loss = 0
    for _ in range(k):
        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(B, 1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)

        # f = log p(x, z) - log q(z), shape [B, 1]
        f = logpx_given_z + logpz - logq
        net_loss += - torch.mean((f.detach() - 1.) * logq)

    net_loss = net_loss / k
    return net_loss, f, logpx_given_z, logpz, logq
示例13: sample_relax_given_class_k
def sample_relax_given_class_k(logits, samp, k):
    """Average ``k`` pairs of relaxed samples conditioned on class ``samp``.

    Reads ``B`` and ``C`` from the enclosing module.
    NOTE(review): presumably batch size and number of classes -- confirm.

    Within each of the ``k`` rounds the two relaxations share the
    selected-class value and differ only in the noise for the other classes.

    The previous version also computed ``logits + gumbels`` and re-gathered
    ``u_b`` inside the loop; the former was overwritten before use and the
    latter reproduced the same value, so both were dropped.  The
    loop-invariant softmax is now computed once.  Results and random-number
    consumption are unchanged.

    Args:
        logits: unnormalised class scores.
        samp: class indices to condition on (viewable as ``(B, 1)``).
        k: number of rounds to average over.

    Returns:
        Tuple ``(z, z_tilde, logprob)``: the two averaged relaxations and the
        log-probability of ``samp`` with shape ``(B, 1)``.
    """
    cat = Categorical(logits=logits)
    b = samp
    index = b.view(B, 1)
    logprob = cat.log_prob(b).view(B, 1)
    probs = torch.softmax(logits, dim=1)

    zs = []
    z_tildes = []
    for _ in range(k):
        # First draw: also supplies the value for the selected class.
        u = torch.rand(B, C).clamp(1e-8, 1.-1e-8)
        u_b = torch.gather(input=u, dim=1, index=index)
        z_tilde_b = -torch.log(-torch.log(u_b))
        log_u_b = torch.log(u_b)

        z = -torch.log((- torch.log(u) / probs) - log_u_b)
        z.scatter_(dim=1, index=index, src=z_tilde_b)

        # Second, independent draw for the non-selected classes.
        u = torch.rand(B, C).clamp(1e-8, 1.-1e-8)
        z_tilde = -torch.log((- torch.log(u) / probs) - log_u_b)
        z_tilde.scatter_(dim=1, index=index, src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    z = torch.mean(torch.stack(zs), dim=0)
    z_tilde = torch.mean(torch.stack(z_tildes), dim=0)
    return z, z_tilde, logprob
示例14: relax_grad
def relax_grad(x, logits, b, surrogate, mixtureweights):
    """Return the squared RELAX-style gradient estimate and ``q(b)``.

    Calls the file-level helpers ``myclamp``, ``sample_relax_given_b`` and
    ``logprob_undercomponent``; the noise tensor is moved to CUDA.

    Args:
        x: observations, concatenated with ``z`` and the logits for the
            surrogate input.
        logits: 2-D class scores; must be part of the autograd graph.
        b: class indices whose log-probability is evaluated.
        surrogate: object exposing ``net``.
        mixtureweights: tensor indexed by ``b`` to obtain the prior weight.

    Returns:
        Tuple ``(surr_loss, pb)``: the element-wise squared gradient estimate
        and ``exp(log q(b))`` of shape ``(B, 1)``.
    """
    batch_size, n_classes = logits.shape[0], logits.shape[1]
    categorical = Categorical(logits=logits)

    def _grad_of_mean(output):
        # d mean(output) / d logits, kept differentiable for the outer loss.
        return torch.autograd.grad(
            [torch.mean(output)], [logits], create_graph=True, retain_graph=True
        )[0]

    # Gumbel-perturbed logits.
    noise = myclamp(torch.rand(batch_size, n_classes).cuda())
    z = logits + (-torch.log(-torch.log(noise)))

    logq = categorical.log_prob(b).view(batch_size, 1)

    # Surrogate evaluated on the unconditional and the conditional relaxation.
    frozen_logits = logits.detach()
    cz = surrogate.net(torch.cat([z, x, frozen_logits], dim=1))
    z_tilde = sample_relax_given_b(logits, b)
    cz_tilde = surrogate.net(torch.cat([z_tilde, x, frozen_logits], dim=1))

    # f = log p(x, z) - log q(z), shape [B, 1]
    logpx_given_z = logprob_undercomponent(x, component=b)
    logpz = torch.log(mixtureweights[b]).view(batch_size, 1)
    f = logpx_given_z + logpz - logq

    grad_logq = _grad_of_mean(logq)
    grad_surr_z = _grad_of_mean(cz)
    grad_surr_z_tilde = _grad_of_mean(cz_tilde)

    estimate = (f.detach() - cz_tilde) * grad_logq - grad_logq + grad_surr_z - grad_surr_z_tilde
    surr_loss = estimate**2
    return surr_loss, torch.exp(logq)
示例15: OneHotCategorical
class OneHotCategorical(Distribution):
    r"""
    One-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`; every query is delegated to a wrapped
    :class:`Categorical`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: :attr:`probs` will be normalized to be summing to 1.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        inner = Categorical(probs, logits)
        self._categorical = inner
        # The last parameter dim indexes the categories (the event dim).
        super(OneHotCategorical, self).__init__(
            inner.batch_shape, inner.param_shape[-1:], validate_args=validate_args
        )

    def _new(self, *args, **kwargs):
        """Allocate a tensor via the wrapped categorical's ``_new``."""
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        """Event probabilities of the wrapped categorical."""
        return self._categorical.probs

    @property
    def logits(self):
        """Event logits of the wrapped categorical."""
        return self._categorical.logits

    @property
    def mean(self):
        """Per-category mean, i.e. the event probabilities."""
        return self._categorical.probs

    @property
    def variance(self):
        """Per-category variance ``p * (1 - p)``."""
        p = self._categorical.probs
        return p * (1 - p)

    @property
    def param_shape(self):
        """Shape of the wrapped categorical's parameter tensor."""
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """Draw one-hot samples of shape ``sample_shape + batch + event``."""
        sample_shape = torch.Size(sample_shape)
        full_shape = self._extended_shape(sample_shape)
        one_hot = self._categorical.probs.new(full_shape).zero_()
        indices = self._categorical.sample(sample_shape)
        # Indices lack the trailing event dimension that scatter_ needs.
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        """Log-probability of one-hot ``value`` (the argmax picks the class)."""
        if self._validate_args:
            self._validate_sample(value)
        _, indices = value.max(-1)
        return self._categorical.log_prob(indices)

    def entropy(self):
        """Entropy of the wrapped categorical distribution."""
        return self._categorical.entropy()

    def enumerate_support(self):
        """Return every one-hot vector, expanded over the batch shape."""
        n = self.event_shape[0]
        values = self._new((n, n))
        # Fill the freshly allocated buffer with the identity matrix.
        torch.eye(n, out=values)
        singleton_batch = (1,) * len(self.batch_shape)
        values = values.view((n,) + singleton_batch + (n,))
        return values.expand((n,) + self.batch_shape + (n,))