

Python Categorical.log_prob Method Code Examples

This article collects typical usage examples of the Python method torch.distributions.categorical.Categorical.log_prob. If you are wondering how Categorical.log_prob is used in practice, the selected examples below may help. You can also explore further usage examples of the containing class, torch.distributions.categorical.Categorical.


The following presents 15 code examples of the Categorical.log_prob method, sorted by popularity by default.
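
Before diving in, here is a minimal, self-contained sketch of the method itself (an editorial illustration, not one of the 15 examples; the tensor values are arbitrary):

import torch
from torch.distributions.categorical import Categorical

cat = Categorical(logits=torch.tensor([[1.0, 2.0, 3.0]]))
b = cat.sample()          # a sampled class index, e.g. tensor([2])
print(cat.log_prob(b))    # log q(b) under the distribution, shape (1,)

log_prob accepts the same integer index tensors that sample returns, which is why the examples below can pass an argmax over perturbed logits straight to it.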

Example 1: relax_grad2

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
def relax_grad2(x, logits, b, surrogate, mixtureweights):
    # myclamp, sample_relax_given_b, and logprob_undercomponent are helper
    # functions defined elsewhere in the source script.
    B = logits.shape[0]
    C = logits.shape[1]

    cat = Categorical(logits=logits)
    # u = torch.rand(B,C).clamp(1e-10, 1.-1e-10).cuda()
    u = myclamp(torch.rand(B,C).cuda())
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    # b = torch.argmax(z, dim=1) #.view(B,1)
    logq = cat.log_prob(b).view(B,1)

    surr_input = torch.cat([z, x, logits.detach()], dim=1)
    cz = surrogate.net(surr_input)

    z_tilde = sample_relax_given_b(logits, b)
    surr_input = torch.cat([z_tilde, x, logits.detach()], dim=1)
    cz_tilde = surrogate.net(surr_input)

    logpx_given_z = logprob_undercomponent(x, component=b)
    logpz = torch.log(mixtureweights[b]).view(B,1)
    logpxz = logpx_given_z + logpz #[B,1]

    f = logpxz - logq 
    net_loss = - torch.mean( (f.detach() - cz_tilde.detach()) * logq - logq +  cz - cz_tilde )

    grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0] #[B,C]
    pb = torch.exp(logq)

    return grad, pb
Developer: chriscremer | Project: Other_Code | Lines: 32 | Source: gmm_cleaned_v6.py

Example 2: sample_relax

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
def sample_relax(logits): #, k=1):
    # B and C (batch size and number of categories) are globals in the source script.
    u = torch.rand(B,C).clamp(1e-12, 1.-1e-12) #.cuda()
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    b = torch.argmax(z, dim=1)

    cat = Categorical(logits=logits)
    logprob = cat.log_prob(b).view(B,1)

    v_k = torch.rand(B,1).clamp(1e-12, 1.-1e-12)
    z_tilde_b = -torch.log(-torch.log(v_k))
    # this way seems biased even though it shouldn't be
    # v_k = torch.gather(input=u, dim=1, index=b.view(B,1))
    # z_tilde_b = torch.gather(input=z, dim=1, index=b.view(B,1))

    v = torch.rand(B,C).clamp(1e-12, 1.-1e-12) #.cuda()
    probs = torch.softmax(logits, dim=1).repeat(B,1)  # note: .repeat(B,1) assumes logits is a single row (B == 1)
    z_tilde = -torch.log((- torch.log(v) / probs) - torch.log(v_k))

    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

    return z, b, logprob, z_tilde
Developer: chriscremer | Project: Other_Code | Lines: 36 | Source: plotting_cat_grads_dist_4.py
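
Examples 1, 2, and 4 rely on the Gumbel-max trick: adding independent Gumbel noise to the logits and taking the argmax yields an exact Categorical(logits) sample, whose log-probability is then scored with cat.log_prob(b). A small self-contained sanity check of that equivalence (an editorial sketch, not from the original repository; B, C, and the logits are made-up values):

import torch
from torch.distributions.categorical import Categorical

torch.manual_seed(0)
B, C = 10000, 4
logits = torch.tensor([0.5, 1.0, -0.5, 0.0]).expand(B, C)

u = torch.rand(B, C).clamp(1e-12, 1. - 1e-12)
b_gumbel = torch.argmax(logits - torch.log(-torch.log(u)), dim=1)
b_direct = Categorical(logits=logits).sample()

# Both empirical class frequencies converge to softmax(logits).
print(torch.bincount(b_gumbel, minlength=C).float() / B)
print(torch.bincount(b_direct, minlength=C).float() / B)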

Example 3: sample_relax_given_class

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
def sample_relax_given_class(logits, samp):

    cat = Categorical(logits=logits)

    u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels

    b = samp #torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B,1)


    # First pass: transform the noise u into a sample from p(z | b), used as z.
    u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
    z_tilde_b = -torch.log(-torch.log(u_b))

    z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

    z = z_tilde

    # Second pass: keep u_b but redraw the remaining noise to get the coupled z_tilde.
    u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
    z_tilde_b = -torch.log(-torch.log(u_b))
    
    u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
    z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

    return z, z_tilde, logprob
Developer: chriscremer | Project: Other_Code | Lines: 31 | Source: plotting_cat_grads_dist.py

Example 4: sample_relax

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
    def sample_relax(logits, surrogate):
        # x, B, C, and sample_relax_given_b come from the enclosing scope.
        cat = Categorical(logits=logits)
        u = torch.rand(B,C).clamp(1e-10, 1.-1e-10).cuda()
        gumbels = -torch.log(-torch.log(u))
        z = logits + gumbels
        b = torch.argmax(z, dim=1) #.view(B,1)
        logprob = cat.log_prob(b).view(B,1)


        # czs = []
        # for j in range(1):
        #     z = sample_relax_z(logits)
        #     surr_input = torch.cat([z, x, logits.detach()], dim=1)
        #     cz = surrogate.net(surr_input)
        #     czs.append(cz)
        # czs = torch.stack(czs)
        # cz = torch.mean(czs, dim=0)#.view(1,1)
        surr_input = torch.cat([z, x, logits.detach()], dim=1)
        cz = surrogate.net(surr_input)


        cz_tildes = []
        for j in range(1):
            z_tilde = sample_relax_given_b(logits, b)
            surr_input = torch.cat([z_tilde, x, logits.detach()], dim=1)
            cz_tilde = surrogate.net(surr_input)
            cz_tildes.append(cz_tilde)
        cz_tildes = torch.stack(cz_tildes)
        cz_tilde = torch.mean(cz_tildes, dim=0) #.view(B,1)

        return b, logprob, cz, cz_tilde
Developer: chriscremer | Project: Other_Code | Lines: 33 | Source: gmm_cleaned_v5.py

Example 5: OneHotCategorical

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by `probs`.

    Samples are one-hot coded vectors of size probs.size(-1).

    See also: :func:`torch.distributions.Categorical`

    Example::

        >>> m = OneHotCategorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
         0
         0
         1
         0
        [torch.FloatTensor of size 4]

    Args:
        probs (Tensor or Variable): event probabilities
    """
    params = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None):
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.probs.size()[:-1]
        event_shape = self._categorical.probs.size()[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape)

    def sample(self, sample_shape=torch.Size()):
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = probs.new(self._extended_shape(sample_shape)).zero_()
        indices = self._categorical.sample(sample_shape)
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self):
        probs = self._categorical.probs
        n = self.event_shape[0]
        if isinstance(probs, Variable):
            values = Variable(torch.eye(n, out=probs.data.new(n, n)))
        else:
            values = torch.eye(n, out=probs.new(n, n))
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        return values.expand((n,) + self.batch_shape + (n,))
Developer: lxlhh | Project: pytorch | Lines: 58 | Source: one_hot_categorical.py
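
For comparison with this early PyTorch snapshot (which still special-cases Variable), here is a brief usage sketch against the current torch.distributions API; the probabilities are arbitrary:

import torch
from torch.distributions.one_hot_categorical import OneHotCategorical

m = OneHotCategorical(probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))
value = m.sample()               # a one-hot vector, e.g. tensor([0., 0., 1., 0.])
print(m.log_prob(value))         # log-probability of the encoded category
print(m.enumerate_support())     # all 4 one-hot outcomes, shape (4, 4)

Note how log_prob above recovers the index with value.max(-1)[1] and defers to Categorical.log_prob, the method this article documents.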

Example 6: reinforce_baseline

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    outputs = {}

    cat = Categorical(probs=probs)

    grads =[]
    # net_loss = 0
    for jj in range(k):

        cluster_H = cat.sample()
        outputs['logq'] = logq = cat.log_prob(cluster_H).view(B,1)
        outputs['logpx_given_z'] = logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        outputs['logpz'] = logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]

        surr_pred = surrogate.net(x)

        outputs['f'] = f = logpxz - logq - 1. 
        # outputs['net_loss'] = net_loss = net_loss - torch.mean((f.detach() ) * logq)
        outputs['net_loss'] = net_loss = - torch.mean((f.detach() - surr_pred.detach()) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

        # surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))

        grad_logq =  torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq )**2)

        if get_grad:
            grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0]
            grads.append(grad)

    # net_loss = net_loss/ k

    if get_grad:
        grads = torch.stack(grads)
        # print (grads.shape)
        outputs['grad_avg'] = torch.mean(torch.mean(grads, dim=0),dim=0)
        outputs['grad_std'] = torch.std(grads, dim=0)[0]

    outputs['surr_loss'] = surr_loss
    # return net_loss, f, logpx_given_z, logpz, logq
    return outputs
Developer: chriscremer | Project: Other_Code | Lines: 46 | Source: gmm_cleaned_v5.py

Example 7: sample_relax

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
    def sample_relax(probs):
        cat = Categorical(probs=probs)
        #Sample z
        u = torch.rand(B,C).cuda()
        u = u.clamp(1e-8, 1.-1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = torch.log(probs) + gumbels

        b = torch.argmax(z, dim=1)
        logprob = cat.log_prob(b).view(B,1)

        #Sample z_tilde
        u_b = torch.rand(B,1).cuda()
        u_b = u_b.clamp(1e-8, 1.-1e-8)
        z_tilde_b = -torch.log(-torch.log(u_b))
        u = torch.rand(B,C).cuda()
        u = u.clamp(1e-8, 1.-1e-8)
        z_tilde = -torch.log((- torch.log(u) / probs) - torch.log(u_b))
        z_tilde[:,b] = z_tilde_b  # note: this fancy-indexing write assumes B == 1; scatter_ would be needed for B > 1
        return z, b, logprob, z_tilde, gumbels
Developer: chriscremer | Project: Other_Code | Lines: 22 | Source: gmm_cleaned_v3.py

Example 8: sample_relax_given_class_k

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
def sample_relax_given_class_k(logits, samp, k):

    cat = Categorical(logits=logits)
    b = samp #torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B,1)

    zs = []
    z_tildes = []
    for i in range(k):

        u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = logits + gumbels

        u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
        z_tilde_b = -torch.log(-torch.log(u_b))
        
        z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

        z = z_tilde

        u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
        z_tilde_b = -torch.log(-torch.log(u_b))
        
        u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    zs= torch.stack(zs)
    z_tildes= torch.stack(z_tildes)
    
    z = torch.mean(zs, dim=0)
    z_tilde = torch.mean(z_tildes, dim=0)

    return z, z_tilde, logprob
Developer: chriscremer | Project: Other_Code | Lines: 41 | Source: plotting_cat_grads_dist.py

Example 9: reinforce

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
def reinforce(x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = Categorical(probs=probs)

    net_loss = 0
    for jj in range(k):

        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(B,1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f = logpxz - logq
        net_loss += - torch.mean((f.detach() - 1.) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

    net_loss = net_loss/ k

    return net_loss, f, logpx_given_z, logpz, logq
Developer: chriscremer | Project: Other_Code | Lines: 24 | Source: gmm_cleaned_v3.py
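
The estimator above uses the score-function (REINFORCE) identity: the gradient of E_q[f(b)] with respect to the logits equals E_q[f(b) * d log q(b)/d logits], so minimizing -mean(f.detach() * logq) yields that gradient through autograd. A minimal sketch with a made-up reward and no baseline (illustrative only, not from the source project):

import torch
from torch.distributions.categorical import Categorical

logits = torch.zeros(3, requires_grad=True)
cat = Categorical(logits=logits)

def f(b):                            # made-up reward: 1 when class 2 is drawn
    return (b == 2).float()

b = cat.sample((10000,))
net_loss = - torch.mean(f(b).detach() * cat.log_prob(b))
net_loss.backward()
print(logits.grad)                   # approx. (1/9, 1/9, -2/9) for uniform logits

Subtracting a baseline from f, as reinforce does with the constant 1. and reinforce_baseline (Example 6) does with a learned surrogate, keeps the gradient unbiased while reducing its variance.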

Example 10: relax_grad

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
def relax_grad(x, logits, b, surrogate, mixtureweights):
    B = logits.shape[0]
    C = logits.shape[1]

    cat = Categorical(logits=logits)
    # u = torch.rand(B,C).clamp(1e-10, 1.-1e-10).cuda()
    u = myclamp(torch.rand(B,C).cuda())
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    # b = torch.argmax(z, dim=1) #.view(B,1)
    logq = cat.log_prob(b).view(B,1)

    surr_input = torch.cat([z, x, logits.detach()], dim=1)
    cz = surrogate.net(surr_input)

    z_tilde = sample_relax_given_b(logits, b)
    surr_input = torch.cat([z_tilde, x, logits.detach()], dim=1)
    cz_tilde = surrogate.net(surr_input)

    logpx_given_z = logprob_undercomponent(x, component=b)
    logpz = torch.log(mixtureweights[b]).view(B,1)
    logpxz = logpx_given_z + logpz #[B,1]

    f = logpxz - logq 

    grad_logq =  torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
    grad_surr_z =  torch.autograd.grad([torch.mean(cz)], [logits], create_graph=True, retain_graph=True)[0]
    grad_surr_z_tilde = torch.autograd.grad([torch.mean(cz_tilde)], [logits], create_graph=True, retain_graph=True)[0]
    # surr_loss = torch.mean(((f.detach() - cz_tilde) * grad_logq - grad_logq + grad_surr_z - grad_surr_z_tilde)**2, dim=1, keepdim=True)
    surr_loss = ((f.detach() - cz_tilde) * grad_logq - grad_logq + grad_surr_z - grad_surr_z_tilde)**2

    return surr_loss, torch.exp(logq)
Developer: chriscremer | Project: Other_Code | Lines: 39 | Source: gmm_cleaned_v6.py

Example 11: sample_reinforce_given_class

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
def sample_reinforce_given_class(logits, samp):
    dist = Categorical(logits=logits)
    logprob = dist.log_prob(samp)
    return logprob
Developer: chriscremer | Project: Other_Code | Lines: 6 | Source: plotting_cat_grads_dist.py

Example 12: HLAX

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
    cat_bernoulli = Categorical(probs=probs)

    net_loss = 0
    surr_loss = 0
    for jj in range(k):

        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq_z = cat.log_prob(cluster_S.detach()).view(B,1)
        logq_b = cat_bernoulli.log_prob(cluster_H.detach()).view(B,1)


        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f_z = logpxz - logq_z - 1.
        f_b = logpxz - logq_b - 1.

        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        # surr_pred, alpha = surrogate.net(surr_input)
        surr_pred = surrogate.net(surr_input)
        alpha = torch.sigmoid(surrogate2.net(x))

        net_loss += - torch.mean(     alpha.detach()*(f_z.detach()  - surr_pred.detach()) * logq_z  
                                    + alpha.detach()*surr_pred 
                                    + (1-alpha.detach())*(f_b.detach()  ) * logq_b)

        # surr_loss += torch.mean(torch.abs(f_z.detach() - surr_pred))

        grad_logq_z = torch.mean( torch.autograd.grad([torch.mean(logq_z)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq_b =  torch.mean( torch.autograd.grad([torch.mean(logq_b)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        surr_loss += torch.mean(
                                    (alpha*(f_z.detach() - surr_pred) * grad_logq_z 
                                    + alpha*grad_surr
                                    + (1-alpha)*(f_b.detach()) * grad_logq_b )**2
                                    )

        surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f_z.detach() - surr_pred.detach()) * logq_z)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))


    net_loss = net_loss/ k
    surr_loss = surr_loss/ k

    return net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif, grad_path, grad_score, torch.mean(alpha)
Developer: chriscremer | Project: Other_Code | Lines: 65 | Source: gmm_cleaned_v3.py

Example 13: LogitRelaxedBernoulli

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
    # Snippet from the middle of a larger script: bern_param, f, and n are defined earlier.
    # dist = LogitRelaxedBernoulli(torch.Tensor([1.]), bern_param)
    # dist_bernoulli = Bernoulli(bern_param)
    C = 2
    n_components = C
    B = 1
    probs = torch.ones(B,C)
    bern_param = bern_param.view(B,1)
    aa = 1 - bern_param
    probs = torch.cat([aa, bern_param], dim=1)

    cat = Categorical(probs= probs)

    grads = []
    for i in range(n):
        b = cat.sample()
        logprob = cat.log_prob(b.detach())
        # b_ = torch.argmax(z, dim=1)

        logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]
        grad = f(b) * logprobgrad

        grads.append(grad[0][0].data.numpy())

    print ('Grad Estimator: Reinforce categorical')
    print ('Grad mean', np.mean(grads))
    print ('Grad std', np.std(grads))
    print ()

    reinforce_cat_grad_means.append(np.mean(grads))
    reinforce_cat_grad_stds.append(np.std(grads))
Developer: chriscremer | Project: Other_Code | Lines: 32 | Source: is_pz_grad_dependent_on_theta_2.py

Example 14: simplax

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
def simplax():

    def show_surr_preds():

        batch_size = 1

        rows = 3
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        for i in range(rows):

            x = sample_true(1).cuda() #.view(1,1)
            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1)
            cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
            cluster_S = cat.rsample()
            cluster_H = H(cluster_S)
            logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            check_nan(logprob_cluster)

            z = cluster_S

            n_evals = 40
            x1 = np.linspace(-9,205, n_evals)
            x = torch.from_numpy(x1).view(n_evals,1).float().cuda()
            z = z.repeat(n_evals,1)
            cluster_H = cluster_H.repeat(n_evals,1)
            xz = torch.cat([z,x], dim=1) 

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz #- logprob_cluster

            surr_pred = surrogate.net(xz)
            surr_pred = surr_pred.data.cpu().numpy()
            f = f.data.cpu().numpy()

            col =0
            row = i
            # print (row)
            ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

            ax.plot(x1,surr_pred, label='Surr')
            ax.plot(x1,f, label='f')
            ax.set_title(str(cluster_H[0]))
            ax.legend()


        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'gmm_surr.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()


    def plot_dist():

        mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

        rows = 1
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)


        xs = np.linspace(-9,205, 300)
        sum_ = np.zeros(len(xs))

        # C = 20
        for c in range(n_components):
            m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
            ys = []
            for x in xs:
                # component_i = (torch.exp(m.log_prob(x) )* ((c+5.) / 290.)).numpy()
                component_i = (torch.exp(m.log_prob(x) )* mixture_weights[c]).detach().cpu().numpy()


                ys.append(component_i)

            ys = np.reshape(np.array(ys), [-1])
            sum_ += ys
            ax.plot(xs, ys, label='')

        ax.plot(xs, sum_, label='')

        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'gmm_plot_dist.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()
        

#......... remainder of the code omitted .........
Developer: chriscremer | Project: Other_Code | Lines: 103 | Source: gmm_batch_fewerclasses.py

Example 15: range

# Required import: from torch.distributions.categorical import Categorical
# Method demonstrated: Categorical.log_prob
    steps_list = []
    for step in range(n_steps):

        optim.zero_grad()

        loss = 0
        net_loss = 0
        for i in range(batch_size):
            x = sample_true()
            logits = encoder.net(x)
            cat = Categorical(probs=torch.softmax(logits, dim=0))
            cluster = cat.sample()
            logprob_cluster = cat.log_prob(cluster.detach())
            # print (logprob_cluster)
            pxz = logprob_undercomponent(x, component=cluster, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=False)
            f = pxz - logprob_cluster
            # print (f)
            # logprob = logprob_givenmixtureeweights(x, needsoftmax_mixtureweight)
            net_loss += -f.detach() * logprob_cluster
            loss += -f
        loss = loss / batch_size
        net_loss = net_loss / batch_size

        # print (loss, net_loss)

        loss.backward(retain_graph=True)  
        optim.step()
Developer: chriscremer | Project: Other_Code | Lines: 32 | Source: gmm.py


Note: The torch.distributions.categorical.Categorical.log_prob examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects, and the source code copyright belongs to the original authors. Consult the corresponding project's License before redistributing or using the code; do not reproduce without permission.