當前位置: 首頁>>代碼示例>>Python>>正文


Python categorical.Categorical類代碼示例

本文整理匯總了Python中torch.distributions.categorical.Categorical的典型用法代碼示例。如果您正苦於以下問題:Python Categorical類的具體用法?Python Categorical怎麼用?Python Categorical使用的例子?那麼,這裡精選的類代碼示例或許可以為您提供幫助。


在下文中一共展示了Categorical類的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: sample_relax

    def sample_relax(logits, surrogate):
        """Draw a categorical sample via the Gumbel-max trick and score it with the surrogate.

        ``B``, ``C`` and ``x`` are read from the enclosing scope, which is not
        visible here (presumably batch size, number of classes and the
        observed batch -- TODO confirm).

        Args:
            logits: unnormalised log-probabilities; added elementwise to a
                ``(B, C)`` noise tensor.
            surrogate: object whose ``net`` attribute is called on the
                concatenation ``[z, x, logits.detach()]`` along dim 1.

        Returns:
            Tuple ``(b, logprob, cz, cz_tilde)``: the argmax indices, their
            log-probability reshaped to ``(B, 1)``, the surrogate output for
            ``z`` and the surrogate output for the conditional relaxation.
        """
        dist = Categorical(logits=logits)

        # Gumbel-max trick: add Gumbel(0, 1) noise to the logits, then take the argmax.
        # The clamp keeps both logarithms finite.
        uniform = torch.rand(B,C).clamp(1e-10, 1.-1e-10).cuda()
        z = logits + (-torch.log(-torch.log(uniform)))
        b = torch.argmax(z, dim=1)
        logprob = dist.log_prob(b).view(B,1)

        # Surrogate evaluated on the unconditional relaxed sample.
        cz = surrogate.net(torch.cat([z, x, logits.detach()], dim=1))

        # Surrogate evaluated on the relaxed sample conditioned on b.  A single
        # draw is taken; the stack/mean keeps the original averaging structure.
        conditional_outputs = [
            surrogate.net(
                torch.cat([sample_relax_given_b(logits, b), x, logits.detach()], dim=1)
            )
            for _ in range(1)
        ]
        cz_tilde = torch.mean(torch.stack(conditional_outputs), dim=0)

        return b, logprob, cz, cz_tilde
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:31,代碼來源:gmm_cleaned_v5.py

示例2: sample_relax

def sample_relax(logits):
    """Sample ``z``, the hard class ``b`` and a conditional relaxation ``z_tilde``.

    ``B`` and ``C`` are module-level names not visible in this snippet
    (presumably batch size and number of classes -- TODO confirm).

    Args:
        logits: unnormalised log-probabilities.  The softmax is tiled ``B``
            times along dim 0, so this presumably expects a single row --
            TODO confirm against the caller.

    Returns:
        Tuple ``(z, b, logprob, z_tilde)`` where ``logprob`` has shape
        ``(B, 1)``.
    """
    lower, upper = 1e-12, 1.-1e-12

    # Gumbel-max trick: perturb the logits and take the argmax.
    u = torch.rand(B,C).clamp(lower, upper)
    z = logits + (-torch.log(-torch.log(u)))
    b = torch.argmax(z, dim=1)
    chosen = b.view(B,1)

    logprob = Categorical(logits=logits).log_prob(b).view(B,1)

    # Fresh noise for the selected coordinate.  (The original author noted that
    # reusing the gathered entry of `u` here "seems biased even though it
    # shouldn't be", hence the independent draw.)
    v_k = torch.rand(B,1).clamp(lower, upper)
    z_tilde_b = -torch.log(-torch.log(v_k))

    # Remaining coordinates, conditioned on the selected one.
    v = torch.rand(B,C).clamp(lower, upper)
    probs = torch.softmax(logits,dim=1).repeat(B,1)
    z_tilde = -torch.log((- torch.log(v) / probs) - torch.log(v_k))

    # Overwrite the selected coordinate of every row with its own value.
    z_tilde.scatter_(dim=1, index=chosen, src=z_tilde_b)

    return z, b, logprob, z_tilde
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:34,代碼來源:plotting_cat_grads_dist_4.py

示例3: relax_grad2

def relax_grad2(x, logits, b, surrogate, mixtureweights):
    """Return the gradient of a surrogate-corrected loss w.r.t. ``logits`` and q(b).

    Args:
        x: observed batch, concatenated with the relaxed samples along dim 1.
        logits: ``(B, C)`` unnormalised log-probabilities; must take part in
            autograd because the loss is differentiated with respect to it.
        b: class indices, used both for ``log_prob`` and to index
            ``mixtureweights``.
        surrogate: object exposing a callable ``net``.
        mixtureweights: tensor indexed by ``b`` to obtain prior weights.

    Returns:
        Tuple ``(grad, pb)``: the ``[B, C]`` gradient of the loss with respect
        to ``logits`` and ``exp(log q(b))``.
    """
    B, C = logits.shape[0], logits.shape[1]

    q = Categorical(logits=logits)

    # Relaxed sample z = logits + Gumbel noise (noise is created then moved to the GPU).
    noise = myclamp(torch.rand(B,C).cuda())
    z = logits + (-torch.log(-torch.log(noise)))
    logq = q.log_prob(b).view(B,1)

    # Surrogate on the unconditional and on the b-conditioned relaxation.
    cz = surrogate.net(torch.cat([z, x, logits.detach()], dim=1))
    z_tilde = sample_relax_given_b(logits, b)
    cz_tilde = surrogate.net(torch.cat([z_tilde, x, logits.detach()], dim=1))

    # log p(x, z) = log p(x | z) + log p(z), shape [B, 1].
    logpxz = logprob_undercomponent(x, component=b) + torch.log(mixtureweights[b]).view(B,1)
    f = logpxz - logq

    per_example = (f.detach() - cz_tilde.detach()) * logq - logq +  cz - cz_tilde
    net_loss = - torch.mean(per_example)

    # create_graph=True keeps the gradient itself differentiable.
    grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0]

    return grad, torch.exp(logq)
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:30,代碼來源:gmm_cleaned_v6.py

示例4: sample_relax_given_class

def sample_relax_given_class(logits, samp):
    """Build two relaxed samples conditioned on the given class ``samp``.

    ``B`` and ``C`` are module-level names not visible in this snippet
    (presumably batch size and number of classes -- TODO confirm).

    Args:
        logits: unnormalised log-probabilities (softmax is taken over dim 1).
        samp: class indices; viewed as ``(B, 1)`` for gather/scatter.

    Returns:
        Tuple ``(z, z_tilde, logprob)``.  Both relaxations share the noise of
        the selected coordinate; ``z_tilde`` uses fresh noise elsewhere.
    """
    cat = Categorical(logits=logits)
    low, high = 1e-8, 1.-1e-8
    index = samp.view(B,1)

    first_noise = torch.rand(B,C).clamp(low, high)
    logprob = cat.log_prob(samp).view(B,1)

    # Noise belonging to the selected class and its Gumbel transform.
    u_b = torch.gather(input=first_noise, dim=1, index=index)
    selected_value = -torch.log(-torch.log(u_b))

    def _conditional(noise):
        # Every coordinate is conditioned on the selected one; the selected
        # coordinate itself is then overwritten with its own value.
        out = -torch.log((- torch.log(noise) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        out.scatter_(dim=1, index=index, src=selected_value)
        return out

    z = _conditional(first_noise)
    z_tilde = _conditional(torch.rand(B,C).clamp(low, high))

    return z, z_tilde, logprob
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:29,代碼來源:plotting_cat_grads_dist.py

示例5: sample_true2

def sample_true2():
    """Draw one ``(sample, cluster)`` pair from the reference mixture.

    The cluster index is sampled from ``true_mixture_weights`` (a module-level
    name not visible here); the observation is then drawn from a normal with
    mean ``cluster * 10`` and standard deviation 5.
    """
    weights = torch.tensor(true_mixture_weights)
    cluster = Categorical(probs=weights).sample()

    loc = torch.tensor([cluster*10.]).float()
    scale = torch.tensor([5.0]).float()
    samp = Normal(loc, scale).sample()

    return samp,cluster
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:9,代碼來源:gmm_batch_v2.py

示例6: sample_gmm

def sample_gmm(batch_size, mixture_weights, device="cuda"):
    """Sample ``batch_size`` points from a 1-D Gaussian mixture.

    Component ``i`` has mean ``10 * i`` and standard deviation 5.

    Args:
        batch_size: number of samples to draw.
        mixture_weights: tensor of component probabilities passed to
            ``Categorical(probs=...)``.
        device: device on which the samples are produced.  Defaults to
            ``"cuda"``, matching the previous hard-coded behaviour; pass
            ``"cpu"`` to run without a GPU.

    Returns:
        Float tensor of shape ``(batch_size, 1)`` on ``device``.
    """
    cluster = Categorical(probs=mixture_weights).sample([batch_size])  # [B]
    mean = (cluster * 10.).float().to(device)
    std = torch.full((batch_size,), 5., device=device)
    samp = Normal(mean, std).sample()
    return samp.view(batch_size, 1)
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:9,代碼來源:gmm_cleaned_v5.py

示例7: OneHotCategorical

class OneHotCategorical(Distribution):
    r"""
    One-hot categorical distribution backed by a :class:`Categorical`.

    Each sample is a vector of length ``probs.size(-1)`` holding a single 1
    at the drawn category and 0 elsewhere.

    See also: :func:`torch.distributions.Categorical`

    Example::

        >>> m = OneHotCategorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
         0
         0
         1
         0
        [torch.FloatTensor of size 4]

    Args:
        probs (Tensor or Variable): event probabilities
        logits (Tensor or Variable): forwarded unchanged to ``Categorical``
    """
    params = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None):
        self._categorical = Categorical(probs, logits)
        # All leading dims are batch dims; the last dim enumerates categories.
        param_size = self._categorical.probs.size()
        super(OneHotCategorical, self).__init__(param_size[:-1], param_size[-1:])

    def sample(self, sample_shape=torch.Size()):
        """Draw one-hot samples of shape ``sample_shape + batch_shape + event_shape``."""
        sample_shape = torch.Size(sample_shape)
        template = self._categorical.probs
        encoded = template.new(self._extended_shape(sample_shape)).zero_()
        drawn = self._categorical.sample(sample_shape)
        # The category indices lack the trailing event dim that scatter_ needs.
        if drawn.dim() < encoded.dim():
            drawn = drawn.unsqueeze(-1)
        encoded.scatter_(-1, drawn, 1)
        return encoded

    def log_prob(self, value):
        """Log-probability of ``value``; its argmax over the last dim is the category."""
        _, category = value.max(-1)
        return self._categorical.log_prob(category)

    def entropy(self):
        """Entropy of the underlying categorical distribution."""
        return self._categorical.entropy()

    def enumerate_support(self):
        """Return every one-hot vector, expanded over the batch shape."""
        n = self.event_shape[0]
        probs = self._categorical.probs
        if isinstance(probs, Variable):
            identity = Variable(torch.eye(n, out=probs.data.new(n, n)))
        else:
            identity = torch.eye(n, out=probs.new(n, n))
        edge = (n,)
        # Insert singleton batch dims so the identity broadcasts over batch_shape.
        identity = identity.view(edge + (1,) * len(self.batch_shape) + edge)
        return identity.expand(edge + self.batch_shape + edge)
開發者ID:lxlhh,項目名稱:pytorch,代碼行數:56,代碼來源:one_hot_categorical.py

示例8: test_gmm_loss

    def test_gmm_loss(self):
        """Fit a 3-component 2-D Gaussian mixture by minimising ``gmm_loss``.

        Draws synthetic samples from a known mixture, then optimises free
        means / log-stds / mixture logits with Adam and prints the current
        estimates ten times during training.

        NOTE(review): this test contains no assertions; it only exercises the
        training loop and prints the learned parameters for inspection.
        """
        n_samples = 10000

        means = torch.Tensor([[0., 0.],
                              [1., 1.],
                              [-1., 1.]])
        stds = torch.Tensor([[.03, .05],
                             [.02, .1],
                             [.1, .03]])
        pi = torch.Tensor([.2, .3, .5])

        # Ground-truth data: pick a component per sample, then add scaled noise.
        cat_dist = Categorical(pi)
        indices = cat_dist.sample((n_samples,)).long()
        rands = torch.randn(n_samples, 2)

        samples = means[indices] + rands * stds[indices]

        class _model(nn.Module):
            """Holds the mixture parameters as free, randomly initialised tensors."""

            def __init__(self, gaussians):
                super().__init__()
                self.means = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
                self.pre_stds = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
                self.pi = nn.Parameter(torch.Tensor(1, gaussians).normal_())

            def forward(self, *inputs):
                # exp keeps the stds positive; softmax normalises the weights.
                return self.means, torch.exp(self.pre_stds), f.softmax(self.pi, dim=1)

        model = _model(3)
        optimizer = torch.optim.Adam(model.parameters())

        iterations = 100000
        log_step = iterations // 10
        cum_loss = 0
        # The context manager closes the progress bar even if a step raises.
        with tqdm(total=iterations) as pbar:
            for i in range(iterations):
                batch = samples[torch.LongTensor(128).random_(0, n_samples)]
                m, s, p = model.forward()
                loss = gmm_loss(batch, m, s, p)
                cum_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix_str("avg_loss={:10.6f}".format(
                    cum_loss / (i + 1)))
                pbar.update(1)
                if i % log_step == log_step - 1:
                    print(m)
                    print(s)
                    print(p)
開發者ID:hbcbh1999,項目名稱:world-models,代碼行數:50,代碼來源:test_gmm.py

示例9: sample_true

def sample_true(batch_size):
    """Draw ``batch_size`` observations from the reference Gaussian mixture.

    Cluster indices are sampled from ``true_mixture_weights`` (a module-level
    name not visible here); each observation is normal with mean
    ``cluster * 10`` and standard deviation 5.

    Returns:
        Tensor of shape ``(batch_size, 1)``.
    """
    weights = torch.tensor(true_mixture_weights)
    cluster = Categorical(probs=weights).sample([batch_size])  # [B]

    loc = (cluster*10.).float()
    scale = torch.ones([batch_size]) *5.
    draws = Normal(loc, scale).sample()

    return draws.view(batch_size, 1)
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:15,代碼來源:gmm_batch_v2.py

示例10: reinforce_baseline

def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    """REINFORCE-style loss with a learned baseline taken from ``surrogate.net(x)``.

    Args:
        surrogate: object exposing a callable ``net``; its output on ``x`` is
            subtracted from the learning signal as a baseline.
        x: observed batch passed to ``logprob_undercomponent`` and the surrogate.
        logits: unnormalised log-probabilities with batch on dim 0; must take
            part in autograd because gradients are taken with respect to it.
        mixtureweights: tensor indexed by the sampled cluster ids.
        k: number of independent samples to draw; must be at least 1.
        get_grad: when true, also collect the gradient of the loss with
            respect to ``logits`` for every sample.

    Returns:
        Dict with keys ``logq``, ``logpx_given_z``, ``logpz``, ``f``,
        ``net_loss`` and ``surr_loss`` (all from the LAST of the ``k``
        samples), plus ``grad_avg`` and ``grad_std`` when ``get_grad`` is set.

    Raises:
        ValueError: if ``k`` is smaller than 1.  Previously this surfaced as
            an unbound-local error on ``surr_loss`` or a failure in
            ``torch.stack``.
    """
    if k < 1:
        raise ValueError("k must be at least 1, got {!r}".format(k))

    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    outputs = {}

    cat = Categorical(probs=probs)

    grads = []
    for jj in range(k):

        cluster_H = cat.sample()
        outputs['logq'] = logq = cat.log_prob(cluster_H).view(B,1)
        outputs['logpx_given_z'] = logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        outputs['logpz'] = logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]

        surr_pred = surrogate.net(x)

        outputs['f'] = f = logpxz - logq - 1.
        # The baseline is detached here so this loss does not train the surrogate.
        outputs['net_loss'] = net_loss = - torch.mean((f.detach() - surr_pred.detach()) * logq)

        # The surrogate is trained on the squared, baseline-corrected score term.
        grad_logq =  torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq )**2)

        if get_grad:
            grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0]
            grads.append(grad)

    if get_grad:
        grads = torch.stack(grads)
        # Average over the k samples, then over dim 0 of each gradient.
        outputs['grad_avg'] = torch.mean(torch.mean(grads, dim=0),dim=0)
        # Std across the k samples, keeping only the first row.
        # NOTE(review): with k == 1 the default std estimator yields NaN.
        outputs['grad_std'] = torch.std(grads, dim=0)[0]

    outputs['surr_loss'] = surr_loss
    return outputs
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:44,代碼來源:gmm_cleaned_v5.py

示例11: sample_relax

    def sample_relax(probs):
        """Sample ``z``, the hard class ``b`` and the conditional relaxation ``z_tilde``.

        ``B`` and ``C`` are read from the enclosing scope, which is not
        visible here (presumably batch size and number of classes -- TODO
        confirm).

        Args:
            probs: class probabilities; ``log(probs)`` is added to ``(B, C)``
                Gumbel noise.

        Returns:
            Tuple ``(z, b, logprob, z_tilde, gumbels)`` where ``logprob`` has
            shape ``(B, 1)``.
        """
        cat = Categorical(probs=probs)

        # Sample z with the Gumbel-max trick; the clamp keeps the logs finite.
        u = torch.rand(B,C).cuda()
        u = u.clamp(1e-8, 1.-1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = torch.log(probs) + gumbels

        b = torch.argmax(z, dim=1)
        logprob = cat.log_prob(b).view(B,1)

        # Sample z_tilde conditioned on b.
        u_b = torch.rand(B,1).cuda()
        u_b = u_b.clamp(1e-8, 1.-1e-8)
        z_tilde_b = -torch.log(-torch.log(u_b))
        u = torch.rand(B,C).cuda()
        u = u.clamp(1e-8, 1.-1e-8)
        z_tilde = -torch.log((- torch.log(u) / probs) - torch.log(u_b))
        # Overwrite ONLY each row's own selected column.  The previous
        # `z_tilde[:, b] = z_tilde_b` indexed the columns in `b` for every row,
        # so rows were also overwritten at columns selected by other rows.
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)
        return z, b, logprob, z_tilde, gumbels
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:20,代碼來源:gmm_cleaned_v3.py

示例12: reinforce

def reinforce(x, logits, mixtureweights, k=1):
    """REINFORCE loss for the categorical posterior, averaged over ``k`` samples.

    Args:
        x: observed batch passed to ``logprob_undercomponent``.
        logits: unnormalised log-probabilities with batch on dim 0.
        mixtureweights: tensor indexed by the sampled cluster ids.
        k: number of samples to average over; must be at least 1.

    Returns:
        Tuple ``(net_loss, f, logpx_given_z, logpz, logq)``.  ``net_loss`` is
        the average over the ``k`` samples; the other values come from the
        LAST sample only.

    Raises:
        ValueError: if ``k`` is smaller than 1.  Previously this surfaced as a
            ``ZeroDivisionError`` or an unbound-local error.
    """
    if k < 1:
        raise ValueError("k must be at least 1, got {!r}".format(k))

    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = Categorical(probs=probs)

    net_loss = 0
    for jj in range(k):

        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(B,1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f = logpxz - logq
        # f is detached so the gradient flows only through log q.
        net_loss += - torch.mean((f.detach() - 1.) * logq)

    net_loss = net_loss/ k

    return net_loss, f, logpx_given_z, logpz, logq
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:22,代碼來源:gmm_cleaned_v3.py

示例13: sample_relax_given_class_k

def sample_relax_given_class_k(logits, samp, k):
    """Average ``k`` pairs of relaxed samples conditioned on the class ``samp``.

    ``B`` and ``C`` are module-level names not visible in this snippet
    (presumably batch size and number of classes -- TODO confirm).

    Args:
        logits: unnormalised log-probabilities (softmax is taken over dim 1).
        samp: class indices; viewed as ``(B, 1)`` for gather/scatter.
        k: number of sample pairs to draw and average.

    Returns:
        Tuple ``(z, z_tilde, logprob)`` where ``z`` and ``z_tilde`` are means
        over the ``k`` draws.
    """
    index = samp.view(B,1)
    logprob = Categorical(logits=logits).log_prob(samp).view(B,1)

    def _draw_pair():
        # First relaxation: all coordinates conditioned on the selected one.
        u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        u_b = torch.gather(input=u, dim=1, index=index)
        selected_value = -torch.log(-torch.log(u_b))

        first = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        first.scatter_(dim=1, index=index, src=selected_value)

        # Second relaxation: same selected-coordinate noise, fresh noise elsewhere.
        fresh = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        second = -torch.log((- torch.log(fresh) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        second.scatter_(dim=1, index=index, src=selected_value)
        return first, second

    pairs = [_draw_pair() for _ in range(k)]

    z = torch.mean(torch.stack([pair[0] for pair in pairs]), dim=0)
    z_tilde = torch.mean(torch.stack([pair[1] for pair in pairs]), dim=0)

    return z, z_tilde, logprob
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:39,代碼來源:plotting_cat_grads_dist.py

示例14: relax_grad

def relax_grad(x, logits, b, surrogate, mixtureweights):
    """Return the elementwise squared gradient estimate used as surrogate loss, and q(b).

    Args:
        x: observed batch, concatenated with the relaxed samples along dim 1.
        logits: ``(B, C)`` unnormalised log-probabilities; must take part in
            autograd because several gradients are taken with respect to it.
        b: class indices, used for ``log_prob`` and to index ``mixtureweights``.
        surrogate: object exposing a callable ``net``.
        mixtureweights: tensor indexed by ``b`` to obtain prior weights.

    Returns:
        Tuple ``(surr_loss, pb)``: the unreduced squared estimator and
        ``exp(log q(b))``.
    """
    B, C = logits.shape[0], logits.shape[1]

    def _grad_wrt_logits(scalar):
        # create_graph=True keeps the result differentiable for surrogate training.
        return torch.autograd.grad([scalar], [logits], create_graph=True, retain_graph=True)[0]

    q = Categorical(logits=logits)

    # Relaxed sample z = logits + Gumbel noise (noise is created then moved to the GPU).
    noise = myclamp(torch.rand(B,C).cuda())
    z = logits + (-torch.log(-torch.log(noise)))
    logq = q.log_prob(b).view(B,1)

    # Surrogate on the unconditional and on the b-conditioned relaxation.
    cz = surrogate.net(torch.cat([z, x, logits.detach()], dim=1))
    z_tilde = sample_relax_given_b(logits, b)
    cz_tilde = surrogate.net(torch.cat([z_tilde, x, logits.detach()], dim=1))

    # log p(x, z) = log p(x | z) + log p(z), shape [B, 1].
    logpxz = logprob_undercomponent(x, component=b) + torch.log(mixtureweights[b]).view(B,1)
    f = logpxz - logq

    grad_logq = _grad_wrt_logits(torch.mean(logq))
    grad_surr_z = _grad_wrt_logits(torch.mean(cz))
    grad_surr_z_tilde = _grad_wrt_logits(torch.mean(cz_tilde))

    estimator = (f.detach() - cz_tilde) * grad_logq - grad_logq + grad_surr_z - grad_surr_z_tilde
    surr_loss = estimator**2

    return surr_loss, torch.exp(logq)
開發者ID:chriscremer,項目名稱:Other_Code,代碼行數:37,代碼來源:gmm_cleaned_v6.py

示例15: OneHotCategorical

class OneHotCategorical(Distribution):
    r"""
    One-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`, implemented as a thin wrapper around :class:`Categorical`.

    Each sample is a one-hot vector of size ``probs.size(-1)``.

    .. note:: :attr:`probs` will be normalized to be summing to 1.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities
        validate_args: forwarded to ``Distribution.__init__``
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        self._categorical = Categorical(probs, logits)
        inner = self._categorical
        # The last parameter dim enumerates categories and becomes the event shape.
        super().__init__(inner.batch_shape, inner.param_shape[-1:], validate_args=validate_args)

    def _new(self, *args, **kwargs):
        """Allocate a tensor via the wrapped categorical's ``_new``."""
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        return self._categorical.probs

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def mean(self):
        return self._categorical.probs

    @property
    def variance(self):
        # Per-coordinate variance p * (1 - p).
        p = self._categorical.probs
        return p * (1 - p)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """Draw one-hot samples of shape ``sample_shape + batch_shape + event_shape``."""
        sample_shape = torch.Size(sample_shape)
        template = self._categorical.probs
        encoded = template.new(self._extended_shape(sample_shape)).zero_()
        drawn = self._categorical.sample(sample_shape)
        # The category indices lack the trailing event dim that scatter_ needs.
        if drawn.dim() < encoded.dim():
            drawn = drawn.unsqueeze(-1)
        encoded.scatter_(-1, drawn, 1)
        return encoded

    def log_prob(self, value):
        """Log-probability of ``value``; its argmax over the last dim is the category."""
        if self._validate_args:
            self._validate_sample(value)
        _, category = value.max(-1)
        return self._categorical.log_prob(category)

    def entropy(self):
        """Entropy of the underlying categorical distribution."""
        return self._categorical.entropy()

    def enumerate_support(self):
        """Return every one-hot vector, expanded over the batch shape."""
        n = self.event_shape[0]
        identity = self._new((n, n))
        torch.eye(n, out=identity)
        edge = (n,)
        # Insert singleton batch dims so the identity broadcasts over batch_shape.
        identity = identity.view(edge + (1,) * len(self.batch_shape) + edge)
        return identity.expand(edge + self.batch_shape + edge)
開發者ID:inkawhich,項目名稱:pytorch,代碼行數:79,代碼來源:one_hot_categorical.py


注:本文中的torch.distributions.categorical.Categorical類示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。