本文整理匯總了Python中torch.distributions.categorical.Categorical.log_prob方法的典型用法代碼示例。如果您正苦於以下問題:Python Categorical.log_prob方法的具體用法?Python Categorical.log_prob怎麼用?Python Categorical.log_prob使用的例子?那麼, 這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類torch.distributions.categorical.Categorical的用法示例。
在下文中一共展示了Categorical.log_prob方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: relax_grad2
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
def relax_grad2(x, logits, b, surrogate, mixtureweights):
    """Return the gradient of a RELAX-style loss w.r.t. ``logits`` and q(b).

    The loss combines a score-function term weighted by
    ``f - surrogate(z_tilde)`` with the surrogate evaluated on an
    unconditional Gumbel relaxation ``z`` and on a relaxation ``z_tilde``
    conditioned on ``b``.

    Args:
        x: observations; concatenated with the relaxations along ``dim=1``.
        logits: unnormalised class scores; indexed as ``[rows, classes]``.
        b: class indices whose log-probability under ``logits`` is scored.
        surrogate: object exposing a ``net`` callable (the control variate).
        mixtureweights: tensor indexed by ``b`` to obtain prior weights.

    Returns:
        ``(grad, pb)``: the gradient of the loss with respect to ``logits``
        and ``exp(log q(b))`` viewed as ``[rows, 1]``.
    """
    n_rows, n_classes = logits.shape[0], logits.shape[1]
    frozen_logits = logits.detach()

    # Unconditional relaxed sample: logits perturbed with Gumbel noise.
    noise = myclamp(torch.rand(n_rows, n_classes).cuda())
    z = logits + (-torch.log(-torch.log(noise)))
    cz = surrogate.net(torch.cat([z, x, frozen_logits], dim=1))

    # Relaxed sample conditioned on the given discrete choice b.
    z_tilde = sample_relax_given_b(logits, b)
    cz_tilde = surrogate.net(torch.cat([z_tilde, x, frozen_logits], dim=1))

    logq = Categorical(logits=logits).log_prob(b).view(n_rows, 1)

    # f = log p(x, b) - log q(b)
    log_joint = (logprob_undercomponent(x, component=b)
                 + torch.log(mixtureweights[b]).view(n_rows, 1))
    f = log_joint - logq

    objective = (f.detach() - cz_tilde.detach()) * logq - logq + cz - cz_tilde
    net_loss = -torch.mean(objective)

    grad = torch.autograd.grad(
        [net_loss], [logits], create_graph=True, retain_graph=True)[0]
    return grad, torch.exp(logq)
示例2: sample_relax
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
def sample_relax(logits):
    """Draw a Gumbel-max sample plus a relaxation conditioned on it.

    Relies on the module-level sizes ``B`` and ``C`` (defined outside this
    function). An earlier variant that reused the argmax entry of the first
    uniform draw was left disabled by the author with the remark that it
    "seems biased".

    Args:
        logits: unnormalised class scores; broadcast against a ``[B, C]``
            noise tensor (presumably ``[1, C]``, since the softmax is
            repeated ``B`` times along dim 0 -- confirm with the caller).

    Returns:
        ``(z, b, logprob, z_tilde)``: perturbed logits, their argmax,
        ``log q(b)`` viewed as ``[B, 1]``, and the conditional relaxation.
    """
    eps = 1e-12

    def _clamped_uniform(*shape):
        # Uniform noise kept away from 0 and 1 before taking logs.
        return torch.rand(*shape).clamp(eps, 1. - eps)

    def _gumbel(uniform):
        return -torch.log(-torch.log(uniform))

    z = logits + _gumbel(_clamped_uniform(B, C))
    b = torch.argmax(z, dim=1)
    logprob = Categorical(logits=logits).log_prob(b).view(B, 1)

    # Fresh noise for the selected entry and for all remaining entries.
    v_k = _clamped_uniform(B, 1)
    v = _clamped_uniform(B, C)
    probs = torch.softmax(logits, dim=1).repeat(B, 1)

    z_tilde = -torch.log(-torch.log(v) / probs - torch.log(v_k))
    # The entry at the sampled class gets the plain Gumbel value instead.
    z_tilde.scatter_(dim=1, index=b.view(B, 1), src=_gumbel(v_k))
    return z, b, logprob, z_tilde
示例3: sample_relax_given_class
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
def sample_relax_given_class(logits, samp):
    """Draw two relaxed samples that are both conditioned on class ``samp``.

    Relies on the module-level sizes ``B`` and ``C`` (defined outside this
    function). Compared with the earlier revision, the unconditional
    ``logits + gumbels`` value that was immediately overwritten is no longer
    computed, the duplicated ``gather`` / Gumbel transform is evaluated once,
    and the softmax is shared by both samples. Random draws happen in the
    same order, so returned values are unchanged.

    Args:
        logits: unnormalised class scores, softmaxed along ``dim=1``.
        samp: class indices; must be viewable as ``[B, 1]``.

    Returns:
        ``(z, z_tilde, logprob)``: two conditional relaxations that share the
        value at the sampled class, and ``log q(samp)`` viewed as ``[B, 1]``.
    """
    b = samp
    index = b.view(B, 1)
    logprob = Categorical(logits=logits).log_prob(b).view(B, 1)
    probs = torch.softmax(logits, dim=1)

    u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
    # Noise at the selected class; reused by both conditional samples.
    u_b = torch.gather(input=u, dim=1, index=index)
    z_b = -torch.log(-torch.log(u_b))

    def _conditional(uniform):
        # Relaxation for the non-selected classes, then pin the selected one.
        relaxed = -torch.log(-torch.log(uniform) / probs - torch.log(u_b))
        relaxed.scatter_(dim=1, index=index, src=z_b)
        return relaxed

    z = _conditional(u)
    z_tilde = _conditional(torch.rand(B, C).clamp(1e-8, 1. - 1e-8))
    return z, z_tilde, logprob
示例4: sample_relax
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
def sample_relax(logits, surrogate):
    """Sample a class via Gumbel-max and evaluate the surrogate twice.

    Relies on names defined outside this function: the sizes ``B`` and ``C``,
    the observation tensor ``x``, and ``sample_relax_given_b``.

    Args:
        logits: unnormalised class scores.
        surrogate: object exposing a ``net`` callable.

    Returns:
        ``(b, logprob, cz, cz_tilde)``: the sampled class indices, their
        log-probability viewed as ``[B, 1]``, the surrogate output on the
        unconditional relaxation, and the mean surrogate output over the
        conditional relaxation(s).
    """
    frozen_logits = logits.detach()

    noise = torch.rand(B, C).clamp(1e-10, 1. - 1e-10).cuda()
    z = logits + (-torch.log(-torch.log(noise)))
    b = torch.argmax(z, dim=1)
    logprob = Categorical(logits=logits).log_prob(b).view(B, 1)

    cz = surrogate.net(torch.cat([z, x, frozen_logits], dim=1))

    def _surrogate_on_conditional():
        z_tilde = sample_relax_given_b(logits, b)
        return surrogate.net(torch.cat([z_tilde, x, frozen_logits], dim=1))

    # Only a single conditional sample is averaged.
    n_conditional = 1
    stacked = torch.stack(
        [_surrogate_on_conditional() for _ in range(n_conditional)])
    cz_tilde = torch.mean(stacked, dim=0)
    return b, logprob, cz, cz_tilde
示例5: OneHotCategorical
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
class OneHotCategorical(Distribution):
    r"""
    One-hot view over a :class:`Categorical` distribution.

    Every operation is delegated to an internal ``Categorical``; samples are
    returned as one-hot vectors whose length is ``probs.size(-1)`` and
    ``log_prob`` accepts such vectors.

    See also: :func:`torch.distributions.Categorical`

    Example::

        >>> m = OneHotCategorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # one-hot vector; positions 0..3 are equally likely

    Args:
        probs (Tensor or Variable): event probabilities
        logits: forwarded to ``Categorical`` as its second argument
    """
    params = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None):
        """Wrap ``Categorical(probs, logits)``; last dim is the event dim."""
        self._categorical = Categorical(probs, logits)
        full_shape = self._categorical.probs.size()
        super(OneHotCategorical, self).__init__(full_shape[:-1],
                                                full_shape[-1:])

    def sample(self, sample_shape=torch.Size()):
        """Draw class indices and scatter them into a zeroed one-hot tensor."""
        sample_shape = torch.Size(sample_shape)
        out_shape = self._extended_shape(sample_shape)
        encoded = self._categorical.probs.new(out_shape).zero_()
        drawn = self._categorical.sample(sample_shape)
        # Indices lack the trailing event dimension; add it for scatter_.
        if drawn.dim() < encoded.dim():
            drawn = drawn.unsqueeze(-1)
        return encoded.scatter_(-1, drawn, 1)

    def log_prob(self, value):
        """Score ``value`` by the position of its maximum along the last dim."""
        _, hot_index = value.max(-1)
        return self._categorical.log_prob(hot_index)

    def entropy(self):
        """Return the entropy of the wrapped categorical distribution."""
        return self._categorical.entropy()

    def enumerate_support(self):
        """Return all one-hot vectors, expanded over the batch shape."""
        probs = self._categorical.probs
        n = self.event_shape[0]
        if isinstance(probs, Variable):
            eye = Variable(torch.eye(n, out=probs.data.new(n, n)))
        else:
            eye = torch.eye(n, out=probs.new(n, n))
        singleton_dims = (1,) * len(self.batch_shape)
        eye = eye.view((n,) + singleton_dims + (n,))
        return eye.expand((n,) + self.batch_shape + (n,))
示例6: reinforce_baseline
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    """Score-function loss with a learned baseline, repeated over ``k`` draws.

    Each iteration samples classes from ``softmax(logits)``, forms
    ``f = log p(x, z) - log q(z) - 1`` and uses ``surrogate.net(x)`` as the
    baseline. The returned dictionary holds the values of the *last*
    iteration, so ``k`` must be at least 1.

    Args:
        surrogate: object exposing a ``net`` callable applied to ``x``.
        x: observations passed to ``logprob_undercomponent`` and the baseline.
        logits: unnormalised class scores, softmaxed along ``dim=1``.
        mixtureweights: tensor indexed by the sampled classes.
        k: number of sampling rounds.
        get_grad: when true, also collect the gradient of the loss with
            respect to ``logits`` in every round and summarise them.

    Returns:
        dict with keys ``logq``, ``logpx_given_z``, ``logpz``, ``f``,
        ``net_loss``, ``surr_loss`` and, if ``get_grad``, ``grad_avg`` and
        ``grad_std``.
    """
    n_rows = logits.shape[0]
    cat = Categorical(probs=torch.softmax(logits, dim=1))
    outputs = {}
    per_round_grads = []

    def _grad_wrt_logits(scalar):
        return torch.autograd.grad(
            [scalar], [logits], create_graph=True, retain_graph=True)[0]

    for _ in range(k):
        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(n_rows, 1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(n_rows, 1)
        surr_pred = surrogate.net(x)
        f = logpx_given_z + logpz - logq - 1.
        net_loss = -torch.mean((f.detach() - surr_pred.detach()) * logq)

        outputs['logq'] = logq
        outputs['logpx_given_z'] = logpx_given_z
        outputs['logpz'] = logpz
        outputs['f'] = f
        outputs['net_loss'] = net_loss

        # The baseline is trained on the squared baseline-corrected score term.
        grad_logq = _grad_wrt_logits(torch.mean(logq))
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq) ** 2)

        if get_grad:
            per_round_grads.append(_grad_wrt_logits(net_loss))

    if get_grad:
        stacked = torch.stack(per_round_grads)
        outputs['grad_avg'] = torch.mean(torch.mean(stacked, dim=0), dim=0)
        outputs['grad_std'] = torch.std(stacked, dim=0)[0]

    outputs['surr_loss'] = surr_loss
    return outputs
示例7: sample_relax
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
def sample_relax(probs):
    """Draw a Gumbel-max sample and a relaxation conditioned on it.

    Relies on the module-level sizes ``B`` and ``C`` (defined outside this
    function) and places all noise on the current CUDA device.

    Fix: the conditional value for the sampled class is now written per row
    with ``scatter_``. The earlier ``z_tilde[:, b] = z_tilde_b`` indexed the
    columns listed in ``b`` for *every* row, so each row had all sampled
    columns overwritten rather than only its own.

    Args:
        probs: class probabilities; their log is perturbed with Gumbel noise.

    Returns:
        ``(z, b, logprob, z_tilde, gumbels)``: perturbed log-probabilities,
        their argmax along ``dim=1``, ``log q(b)`` viewed as ``[B, 1]``, the
        conditional relaxation, and the Gumbel noise used for ``z``.
    """
    cat = Categorical(probs=probs)

    # Sample z
    u = torch.rand(B, C).cuda()
    u = u.clamp(1e-8, 1. - 1e-8)
    gumbels = -torch.log(-torch.log(u))
    z = torch.log(probs) + gumbels
    b = torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B, 1)

    # Sample z_tilde
    u_b = torch.rand(B, 1).cuda()
    u_b = u_b.clamp(1e-8, 1. - 1e-8)
    z_tilde_b = -torch.log(-torch.log(u_b))

    u = torch.rand(B, C).cuda()
    u = u.clamp(1e-8, 1. - 1e-8)
    z_tilde = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
    # Row i receives z_tilde_b[i] only at its own sampled column b[i].
    z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)
    return z, b, logprob, z_tilde, gumbels
示例8: sample_relax_given_class_k
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
def sample_relax_given_class_k(logits, samp, k):
    """Average ``k`` pairs of relaxed samples conditioned on class ``samp``.

    Relies on the module-level sizes ``B`` and ``C`` (defined outside this
    function). Compared with the earlier revision, the unconditional
    ``logits + gumbels`` value that was immediately overwritten is no longer
    computed, the duplicated ``gather`` / Gumbel transform is evaluated once
    per round, and the loop-invariant softmax is hoisted. Random draws happen
    in the same order, so returned values are unchanged.

    Args:
        logits: unnormalised class scores, softmaxed along ``dim=1``.
        samp: class indices; must be viewable as ``[B, 1]``.
        k: number of rounds to average; ``torch.stack`` requires ``k >= 1``.

    Returns:
        ``(z, z_tilde, logprob)``: the two element-wise means over the rounds
        and ``log q(samp)`` viewed as ``[B, 1]``.
    """
    b = samp
    index = b.view(B, 1)
    logprob = Categorical(logits=logits).log_prob(b).view(B, 1)
    probs = torch.softmax(logits, dim=1)

    def _conditional(uniform, u_b, z_b):
        # Relaxation for the non-selected classes, then pin the selected one.
        relaxed = -torch.log(-torch.log(uniform) / probs - torch.log(u_b))
        relaxed.scatter_(dim=1, index=index, src=z_b)
        return relaxed

    zs = []
    z_tildes = []
    for _ in range(k):
        u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
        # Noise at the selected class; shared by both samples of this round.
        u_b = torch.gather(input=u, dim=1, index=index)
        z_b = -torch.log(-torch.log(u_b))

        zs.append(_conditional(u, u_b, z_b))
        fresh = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
        z_tildes.append(_conditional(fresh, u_b, z_b))

    z = torch.mean(torch.stack(zs), dim=0)
    z_tilde = torch.mean(torch.stack(z_tildes), dim=0)
    return z, z_tilde, logprob
示例9: reinforce
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
def reinforce(x, logits, mixtureweights, k=1):
    """Plain score-function (REINFORCE) loss averaged over ``k`` draws.

    Args:
        x: observations passed to ``logprob_undercomponent``.
        logits: unnormalised class scores, softmaxed along ``dim=1``.
        mixtureweights: tensor indexed by the sampled classes.
        k: number of sampling rounds; must be at least 1.

    Returns:
        ``(net_loss, f, logpx_given_z, logpz, logq)`` where ``net_loss`` is
        the mean over rounds and the remaining items come from the last
        round.
    """
    n_rows = logits.shape[0]
    cat = Categorical(probs=torch.softmax(logits, dim=1))

    round_losses = []
    for _ in range(k):
        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(n_rows, 1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(n_rows, 1)
        # f = log p(x, z) - log q(z)
        f = logpx_given_z + logpz - logq
        round_losses.append(-torch.mean((f.detach() - 1.) * logq))

    net_loss = sum(round_losses) / k
    return net_loss, f, logpx_given_z, logpz, logq
示例10: relax_grad
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
def relax_grad(x, logits, b, surrogate, mixtureweights):
    """Return the element-wise squared RELAX-style gradient estimate and q(b).

    The estimate combines the gradient of ``log q(b)`` weighted by
    ``f - surrogate(z_tilde)`` with the gradients of the surrogate evaluated
    on an unconditional relaxation ``z`` and on a relaxation ``z_tilde``
    conditioned on ``b``; all gradients are taken with respect to ``logits``.

    Args:
        x: observations; concatenated with the relaxations along ``dim=1``.
        logits: unnormalised class scores; indexed as ``[rows, classes]``.
        b: class indices whose log-probability under ``logits`` is scored.
        surrogate: object exposing a ``net`` callable (the control variate).
        mixtureweights: tensor indexed by ``b`` to obtain prior weights.

    Returns:
        ``(surr_loss, pb)``: the squared estimate (not reduced) and
        ``exp(log q(b))`` viewed as ``[rows, 1]``.
    """
    n_rows, n_classes = logits.shape[0], logits.shape[1]
    frozen_logits = logits.detach()

    def _grad_of_mean(tensor):
        return torch.autograd.grad(
            [torch.mean(tensor)], [logits],
            create_graph=True, retain_graph=True)[0]

    # Unconditional relaxed sample: logits perturbed with Gumbel noise.
    noise = myclamp(torch.rand(n_rows, n_classes).cuda())
    z = logits + (-torch.log(-torch.log(noise)))
    cz = surrogate.net(torch.cat([z, x, frozen_logits], dim=1))

    # Relaxed sample conditioned on the given discrete choice b.
    z_tilde = sample_relax_given_b(logits, b)
    cz_tilde = surrogate.net(torch.cat([z_tilde, x, frozen_logits], dim=1))

    logq = Categorical(logits=logits).log_prob(b).view(n_rows, 1)

    # f = log p(x, b) - log q(b)
    log_joint = (logprob_undercomponent(x, component=b)
                 + torch.log(mixtureweights[b]).view(n_rows, 1))
    f = log_joint - logq

    grad_logq = _grad_of_mean(logq)
    grad_cz = _grad_of_mean(cz)
    grad_cz_tilde = _grad_of_mean(cz_tilde)

    estimate = (f.detach() - cz_tilde) * grad_logq - grad_logq + grad_cz - grad_cz_tilde
    surr_loss = estimate ** 2
    return surr_loss, torch.exp(logq)
示例11: sample_reinforce_given_class
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
def sample_reinforce_given_class(logits, samp):
    """Return ``log q(samp)`` under a categorical parameterised by ``logits``.

    Args:
        logits: unnormalised class scores passed to ``Categorical``.
        samp: class indices to score.

    Returns:
        The log-probabilities as produced by ``Categorical.log_prob``.
    """
    return Categorical(logits=logits).log_prob(samp)
示例12: HLAX
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
# HLAX: builds a network loss and a surrogate loss that blend two
# score-function terms -- one on the relaxed sample (log q(z)) and one on the
# discretised sample (log q(b)) -- using a per-input weight
# alpha = sigmoid(surrogate2.net(x)).
#
# Parameters:
#   surrogate      -- object with a `net` callable applied to [relaxed sample, x]
#   surrogate2     -- object with a `net` callable applied to x; yields alpha
#   x              -- observations
#   logits         -- unnormalised class scores, softmaxed along dim=1
#   mixtureweights -- tensor indexed by the discretised sample
#   k              -- number of sampling rounds used for averaging
#
# Returns a 10-tuple: (net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss,
# surr_dif, grad_path, grad_score, mean(alpha)). Everything except net_loss and
# surr_loss holds the values of the most recent round.
#
# NOTE(review): indentation was lost in this listing. Presumably every
# statement from `cluster_S = cat.rsample()` down to the `grad_score = ...abs`
# line belongs to the `for` body and the two `/ k` lines follow the loop --
# confirm against the upstream file. Uses CUDA unconditionally and the external
# helpers `H` and `logprob_undercomponent`.
def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
B = logits.shape[0]
probs = torch.softmax(logits, dim=1)
# Relaxed (continuous) distribution with temperature 1 on the GPU.
cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
# Despite the name this is a Categorical over the same probs.
cat_bernoulli = Categorical(probs=probs)
net_loss = 0
surr_loss = 0
for jj in range(k):
# Reparameterised relaxed sample, then its discretisation via H.
cluster_S = cat.rsample()
cluster_H = H(cluster_S)
# Samples are detached so only the distribution parameters get gradients.
logq_z = cat.log_prob(cluster_S.detach()).view(B,1)
logq_b = cat_bernoulli.log_prob(cluster_H.detach()).view(B,1)
logpx_given_z = logprob_undercomponent(x, component=cluster_H)
logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
logpxz = logpx_given_z + logpz #[B,1]
# Learning signals for the relaxed and the discrete score terms.
f_z = logpxz - logq_z - 1.
f_b = logpxz - logq_b - 1.
surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
# surr_pred, alpha = surrogate.net(surr_input)
surr_pred = surrogate.net(surr_input)
alpha = torch.sigmoid(surrogate2.net(x))
# Network loss: alpha is detached here, so this term does not train surrogate2.
net_loss += - torch.mean( alpha.detach()*(f_z.detach() - surr_pred.detach()) * logq_z
+ alpha.detach()*surr_pred
+ (1-alpha.detach())*(f_b.detach() ) * logq_b)
# surr_loss += torch.mean(torch.abs(f_z.detach() - surr_pred))
# Gradients w.r.t. logits, each averaged over dim=1 with the dim kept.
grad_logq_z = torch.mean( torch.autograd.grad([torch.mean(logq_z)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
grad_logq_b = torch.mean( torch.autograd.grad([torch.mean(logq_b)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
# print (alpha.shape, f_z.shape, surr_pred.shape, grad_logq_z.shape, grad_surr.shape)
# fsdfa
# grad_surr = torch.autograd.grad([surr_pred[0]], [logits], create_graph=True, retain_graph=True)[0]
# print (grad_surr)
# fsdfasd
# Surrogate loss: mean squared blended gradient estimate; alpha is NOT
# detached here, so both surrogate and surrogate2 receive gradients.
surr_loss += torch.mean(
(alpha*(f_z.detach() - surr_pred) * grad_logq_z
+ alpha*grad_surr
+ (1-alpha)*(f_b.detach()) * grad_logq_b )**2
)
# Diagnostics: mean absolute gap between f_z and the surrogate prediction.
surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
# gradd = torch.autograd.grad([surr_loss], [alpha], create_graph=True, retain_graph=True)[0]
# print (gradd)
# fdsf
# Diagnostics: mean absolute size of the pathwise and score-function gradients.
grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
grad_score = torch.autograd.grad([torch.mean((f_z.detach() - surr_pred.detach()) * logq_z)], [logits], create_graph=True, retain_graph=True)[0]
grad_path = torch.mean(torch.abs(grad_path))
grad_score = torch.mean(torch.abs(grad_score))
# Average the accumulated losses over the k rounds.
net_loss = net_loss/ k
surr_loss = surr_loss/ k
return net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif, grad_path, grad_score, torch.mean(alpha)
示例13: LogitRelaxedBernoulli
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
# Script fragment: estimates the gradient of E[f(b)] with respect to
# `bern_param` using the score-function estimator on a two-class Categorical,
# then prints and records the mean / std of `n` single-sample estimates.
# `bern_param`, `n`, `f`, `np` and the two `reinforce_cat_grad_*` lists are
# defined outside this excerpt.
# NOTE(review): indentation was lost in this listing; presumably the `for`
# body ends at the `grads.append(...)` line -- confirm.
# dist = LogitRelaxedBernoulli(torch.Tensor([1.]), bern_param)
# dist_bernoulli = Bernoulli(bern_param)
C= 2
n_components = C
B=1
# Placeholder value; reassigned by the torch.cat below before any use.
probs = torch.ones(B,C)
bern_param = bern_param.view(B,1)
aa = 1 - bern_param
# Class probabilities are [1 - bern_param, bern_param].
probs = torch.cat([aa, bern_param], dim=1)
cat = Categorical(probs= probs)
grads = []
for i in range(n):
b = cat.sample()
logprob = cat.log_prob(b.detach())
# b_ = torch.argmax(z, dim=1)
# retain_graph=True because the same graph is differentiated every iteration.
logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]
# Single-sample score-function estimate: f(b) * d log q(b) / d bern_param.
grad = f(b) * logprobgrad
grads.append(grad[0][0].data.numpy())
print ('Grad Estimator: Reinfoce categorical')
print ('Grad mean', np.mean(grads))
print ('Grad std', np.std(grads))
print ()
reinforce_cat_grad_means.append(np.mean(grads))
reinforce_cat_grad_stds.append(np.std(grads))
示例14: simplax
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
def simplax():
def show_surr_preds():
    """Plot surrogate predictions against ``f`` for three sampled clusters.

    For each of three rows a data point is drawn, encoded, and a relaxed
    cluster assignment is sampled; the surrogate and ``f`` are then evaluated
    on a 40-point grid over ``[-9, 205]`` and plotted in that row. The figure
    is written to ``exp_dir + 'gmm_surr.png'``.

    Uses names from the enclosing scope: ``sample_true``, ``encoder``,
    ``surrogate``, ``H``, ``check_nan``, ``logprob_undercomponent``,
    ``needsoftmax_mixtureweight`` and ``exp_dir``; requires CUDA.
    """
    batch_size = 1
    rows, cols = 3, 1
    n_evals = 40
    grid = np.linspace(-9, 205, n_evals)

    plt.figure(figsize=(10+cols, 4+rows), facecolor='white')

    for row in range(rows):
        sample_x = sample_true(1).cuda()
        logits = encoder.net(sample_x)
        relaxed = RelaxedOneHotCategorical(
            probs=torch.softmax(logits, dim=1),
            temperature=torch.tensor([1.]).cuda())
        cluster_S = relaxed.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = relaxed.log_prob(cluster_S.detach()).view(batch_size, 1)
        check_nan(logprob_cluster)

        # Evaluate on the grid with the sampled assignment tiled to match it.
        grid_x = torch.from_numpy(grid).view(n_evals, 1).float().cuda()
        tiled_z = cluster_S.repeat(n_evals, 1)
        cluster_H = cluster_H.repeat(n_evals, 1)
        surr_in = torch.cat([tiled_z, grid_x], dim=1)

        # f is the joint log-probability only; log q is not subtracted here.
        f = logprob_undercomponent(
            grid_x, component=cluster_H,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        surr_pred = surrogate.net(surr_in).data.cpu().numpy()
        f = f.data.cpu().numpy()

        ax = plt.subplot2grid((rows, cols), (row, 0), frameon=False,
                              colspan=1, rowspan=1)
        ax.plot(grid, surr_pred, label='Surr')
        ax.plot(grid, f, label='f')
        ax.set_title(str(cluster_H[0]))
        ax.legend()

    plt_path = exp_dir+'gmm_surr.png'
    plt.savefig(plt_path)
    print ('saved training plot', plt_path)
    plt.close()
def plot_dist():
    """Plot every weighted mixture component and their sum, then save it.

    Component ``c`` is a ``Normal`` with mean ``c * 10`` and scale ``5``,
    weighted by ``softmax(needsoftmax_mixtureweight)[c]`` and evaluated on
    300 points over ``[-9, 205]``. The figure is written to
    ``exp_dir + 'gmm_plot_dist.png'``.

    Uses names from the enclosing scope: ``needsoftmax_mixtureweight``,
    ``n_components`` and ``exp_dir``.
    """
    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

    rows, cols = 1, 1
    plt.figure(figsize=(10+cols, 4+rows), facecolor='white')
    ax = plt.subplot2grid((rows, cols), (0, 0), frameon=False,
                          colspan=1, rowspan=1)

    xs = np.linspace(-9, 205, 300)
    sum_ = np.zeros(len(xs))

    for c in range(n_components):
        m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
        # Weighted density of this component at every grid point.
        densities = [
            (torch.exp(m.log_prob(x)) * mixture_weights[c]).detach().cpu().numpy()
            for x in xs
        ]
        ys = np.reshape(np.array(densities), [-1])
        sum_ += ys
        ax.plot(xs, ys, label='')

    ax.plot(xs, sum_, label='')

    plt_path = exp_dir+'gmm_plot_dist.png'
    plt.savefig(plt_path)
    print ('saved training plot', plt_path)
    plt.close()
#.........這裏部分代碼省略.........
示例15: range
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import log_prob [as 別名]
# Script fragment: one optimiser step per outer iteration; each step averages
# a per-sample objective over `batch_size` individually drawn data points.
# `n_steps`, `batch_size`, `optim`, `encoder`, `sample_true`,
# `logprob_undercomponent` and `needsoftmax_mixtureweight` are defined outside
# this excerpt.
# NOTE(review): indentation was lost in this listing; presumably the inner
# `for` body ends at `loss += -f` and the remaining statements run once per
# outer step -- confirm.
steps_list = []
for step in range(n_steps):
optim.zero_grad()
loss = 0
net_loss = 0
for i in range(batch_size):
x = sample_true()
logits = encoder.net(x)
# print (logits.shape)
# print (torch.softmax(logits, dim=0))
# fsfd
# Softmax over dim=0: assumes `logits` is 1-D for a single sample -- TODO confirm.
cat = Categorical(probs= torch.softmax(logits, dim=0))
cluster = cat.sample()
logprob_cluster = cat.log_prob(cluster.detach())
# print (logprob_cluster)
pxz = logprob_undercomponent(x, component=cluster, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=False)
# f = log p(x, z) - log q(z) for the sampled cluster.
f = pxz - logprob_cluster
# print (f)
# logprob = logprob_givenmixtureeweights(x, needsoftmax_mixtureweight)
# Score-function surrogate for the encoder (f treated as a constant).
net_loss += -f.detach() * logprob_cluster
loss += -f
loss = loss / batch_size
net_loss = net_loss / batch_size
# print (loss, net_loss)
# NOTE(review): within this excerpt only `loss` is backpropagated; `net_loss`
# is accumulated and averaged but never used afterwards. Verify whether a
# `net_loss.backward()` was intended for the encoder update.
loss.backward(retain_graph=True)
optim.step()