当前位置: 首页>>代码示例>>Python>>正文


Python pyro.sample函数代码示例

本文整理汇总了Python中pyro.sample函数的典型用法代码示例。如果您正苦于以下问题：pyro.sample函数的具体用法是什么？pyro.sample怎么用？有哪些pyro.sample的使用例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了pyro.sample函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: model

 def model():
     """Sample two one-hot categorical sites and assert their sample shapes.

     Draws ``x2`` from a uniform 2-way and ``x3`` from a uniform 3-way
     ``OneHotCategorical`` and checks the resulting shapes.

     NOTE(review): ``iarange_shape`` is not defined in this snippet; it is
     presumably provided by the enclosing test function -- confirm.
     """
     # Build the uniform probability vectors directly. Wrapping an existing
     # tensor in ``torch.tensor(...)`` only creates a redundant copy, and
     # newer PyTorch releases emit a UserWarning for that pattern.
     p2 = torch.ones(2) / 2
     p3 = torch.ones(3) / 3
     x2 = pyro.sample("x2", dist.OneHotCategorical(p2))
     x3 = pyro.sample("x3", dist.OneHotCategorical(p3))
     # Expected layout: leading dims, then iarange_shape, then the shape of
     # the probability vector itself.
     assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape
     assert x3.shape == torch.Size([3, 1]) + iarange_shape + p3.shape
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_valid_models.py

示例2: bernoulli_normal_model

def bernoulli_normal_model():
    """Sample a Bernoulli site, then observe a Normal site conditioned on it.

    The Normal site ``normal_0`` is centred at +1 when the Bernoulli draw is
    truthy and at -1 otherwise, and is observed at a tensor of ones.

    :return: ``[bern_0, normal_0]`` -- the Bernoulli sample and the observed
        value.
    """
    observed = torch.ones(1)
    bern_0 = pyro.sample('bern_0', dist.Bernoulli(torch.zeros(1) * 1e-2))

    # Pick the mean of the observation distribution from the discrete draw.
    if bern_0.item():
        center = torch.ones(1)
    else:
        center = -torch.ones(1)

    likelihood = dist.Normal(center, torch.ones(1) * 1e-2)
    pyro.sample('normal_0', likelihood, obs=observed)
    return [bern_0, observed]
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_properties.py

示例3: guide

 def guide():
     """Guide that samples ``mu_latent`` from a Normal built from ``pt_guide``.

     ``pt_guide`` and ``reparameterized`` are read from the enclosing scope.
     The module stores log-parameters; they are exponentiated here, and the
     standard deviation is computed as ``tau_q ** -0.5``.
     """
     pyro.module("mymodule", pt_guide)

     mu_q = torch.exp(pt_guide.mu_q_log)
     tau_q = torch.exp(pt_guide.tau_q_log)
     sigma = tau_q.pow(-0.5)

     latent_dist = dist.Normal(mu_q, sigma, reparameterized=reparameterized)
     # The site is sampled with a decaying-average baseline requested.
     pyro.sample("mu_latent", latent_dist,
                 baseline={"use_decaying_avg_baseline": True})
开发者ID:Magica-Chen,项目名称:pyro,代码行数:7,代码来源:test_tracegraph_elbo.py

示例4: _register_param

    def _register_param(self, param, mode="model"):
        """
        Register ``param`` with Pyro and cache the result in
        ``self._registered_params``.

        Acts as a thin wrapper around the :func:`pyro.param` and
        :func:`pyro.sample` primitives:

        * a fixed parameter is copied straight from ``self._fixed_params``;
        * a parameter without a prior becomes a :func:`pyro.param`
          (constrained if a constraint is recorded for it);
        * a parameter with a prior is sampled from that prior when ``mode``
          is ``"model"``, otherwise from a ``Delta`` centred on a learnable
          ``"<name>_MAP"`` parameter.

        :param str param: Name of the parameter.
        :param str mode: Either "model" or "guide".
        """
        # Fixed parameters never touch the Pyro primitives.
        if param in self._fixed_params:
            self._registered_params[param] = self._fixed_params[param]
            return

        prior = self._priors.get(param)
        if self.name is None:
            site_name = param
        else:
            site_name = param_with_module_name(self.name, param)

        if prior is not None:
            if mode == "model":
                value = pyro.sample(site_name, prior)
            else:
                # Guide with a prior: point estimate through a Delta.
                # TODO: consider to init parameter from a prior call instead of mean
                map_value = pyro.param(site_name + "_MAP", prior.mean.detach())
                value = pyro.sample(site_name, dist.Delta(map_value))
        else:
            constraint = self._constraints.get(param)
            initial = getattr(self, param)
            # Only forward ``constraint`` when one was actually registered.
            extra = {} if constraint is None else {"constraint": constraint}
            value = pyro.param(site_name, initial, **extra)

        self._registered_params[param] = value
开发者ID:lewisKit,项目名称:pyro,代码行数:33,代码来源:util.py

示例5: iarange_model

def iarange_model(subsample_size):
    """Sample ``x`` inside a subsampled ``iarange`` over 20 elements.

    :param subsample_size: forwarded to ``pyro.iarange`` as the subsample
        size.
    :return: the indices yielded by the ``iarange`` context, as a list.
    """
    loc, scale = torch.zeros(20), torch.ones(20)
    with pyro.iarange('iarange', 20, subsample_size) as batch:
        batch_dist = dist.Normal(loc[batch], scale[batch])
        pyro.sample("x", batch_dist)
        indices = list(batch.data)
    return indices
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_mapdata.py

示例6: guide

 def guide(self, x):
     """Guide q(z|x): encode ``x`` and sample the ``"latent"`` site.

     :param x: input passed to ``self.encoder``.
     """
     # Make the encoder's parameters visible to Pyro.
     pyro.module("encoder", self.encoder)
     # The encoder output is unpacked into the two arguments of q(z|x).
     loc, scale = self.encoder.forward(x)
     # Draw the latent code.
     pyro.sample("latent", dist.normal, loc, scale)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:7,代码来源:vae.py

示例7: guide

    def guide(self, xs, ys=None):
        """
        Variational guide:

            q(y|x)   = categorical(alpha(x))           # digit given image
            q(z|x,y) = normal(mu(x,y), sigma(x,y))     # style given image and digit

        ``alpha`` comes from the network ``encoder_y``; ``mu`` and ``sigma``
        come from the network ``encoder_z``.

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels, i.e. the digit
                   corresponding to the image(s); sampled from q(y|x) when
                   omitted
        :return: None
        """
        # Tell Pyro the items in the batch are conditionally independent.
        with pyro.iarange("independent"):
            # Unsupervised case: the digit label is sampled (and scored)
            # under q(y|x).
            if ys is None:
                ys = pyro.sample("y", dist.categorical, self.encoder_y.forward(xs))

            # Sample (and score) the handwriting style under q(z|x,y).
            style_params = self.encoder_z.forward([xs, ys])
            mu, sigma = style_params
            pyro.sample("z", dist.normal, mu, sigma)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:26,代码来源:ss_vae_M2.py

示例8: guide

 def guide(subsample_size):
     """Guide sampling ``z`` from a Normal over a subsampled ``"data"`` iarange.

     ``data`` and ``reparameterized`` are read from the enclosing scope.

     :param subsample_size: subsample size for the iarange; also the length
         the scalar ``sigma`` parameter is expanded to.
     """
     all_mu = pyro.param("mu", lambda: Variable(torch.zeros(len(data)), requires_grad=True))
     shared_sigma = pyro.param("sigma", lambda: Variable(torch.ones(1), requires_grad=True))
     with pyro.iarange("data", len(data), subsample_size) as ind:
         # Restrict the mean to the subsample and broadcast the single sigma.
         batch_dist = dist.Normal(all_mu[ind],
                                  shared_sigma.expand(subsample_size),
                                  reparameterized=reparameterized)
         pyro.sample("z", batch_dist)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:7,代码来源:test_gradient.py

示例9: sample_zs

 def sample_zs(name, width):
     """Sample the Gamma-distributed site ``z_<name>``.

     ``self``, ``x_size``, ``rand_tensor`` and ``Gamma`` are read from the
     enclosing scope. Two parameters are registered per ``name``; both go
     through ``self.softplus`` before being used as the Gamma concentration
     and (via ``alpha / mean``) rate.

     :param name: suffix used in the parameter and sample site names.
     :param width: second dimension of the parameter tensors.
     """
     def _init(center):
         # Shared initialiser; only the centre differs between the two params.
         return rand_tensor((x_size, width), center, self.sigma_init)

     raw_alpha = pyro.param("log_alpha_z_q_%s" % name,
                            lambda: _init(self.alpha_init))
     raw_mean = pyro.param("log_mean_z_q_%s" % name,
                           lambda: _init(self.mean_init))
     alpha = self.softplus(raw_alpha)
     mean = self.softplus(raw_mean)
     z_dist = Gamma(alpha, alpha / mean).independent(1)
     pyro.sample("z_%s" % name, z_dist)
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:sparse_gamma_def.py

示例10: model

 def model(num_particles):
     """Model with a Normal latent ``z`` and a Bernoulli site ``y``.

     Inside the ``"particles"`` iarange, registers the parameters ``q3`` and
     ``q4``, samples ``z ~ Normal(q3, 1)`` expanded by ``num_particles``, and
     samples ``y ~ Bernoulli(q4 * sigmoid(z))``.

     ``pi1``, ``pi2`` and ``pi3`` are read from the enclosing scope.

     :param num_particles: size of the ``"particles"`` iarange and of the
         expanded Normal.
     """
     with pyro.iarange("particles", num_particles):
         q3 = pyro.param("q3", torch.tensor(pi3, requires_grad=True))
         q4 = pyro.param("q4", torch.tensor(0.5 * (pi1 + pi2), requires_grad=True))
         z = pyro.sample("z", dist.Normal(q3, 1.0).expand_by([num_particles]))
         # Logistic squashing. torch.sigmoid is the numerically stable form of
         # exp(z) / (1 + exp(z)), which overflows to inf / inf = NaN for
         # large positive z.
         zz = torch.sigmoid(z)
         pyro.sample("y", dist.Bernoulli(q4 * zz))
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_enum.py

示例11: guide

 def guide(num_particles):
     """Guide with a Normal latent ``z`` and a Bernoulli site ``y``.

     Registers the parameters ``q1`` and ``q2``, then, inside the
     ``"particles"`` iarange, samples ``z ~ Normal(q2, 1)`` expanded by
     ``num_particles`` and ``y ~ Bernoulli(q1 * sigmoid(z))``.

     ``pi1`` and ``pi2`` are read from the enclosing scope.

     :param num_particles: size of the ``"particles"`` iarange and of the
         expanded Normal.
     """
     q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
     q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
     with pyro.iarange("particles", num_particles):
         z = pyro.sample("z", dist.Normal(q2, 1.0).expand_by([num_particles]))
         # Logistic squashing. torch.sigmoid is the numerically stable form of
         # exp(z) / (1 + exp(z)), which overflows to inf / inf = NaN for
         # large positive z.
         zz = torch.sigmoid(z)
         pyro.sample("y", dist.Bernoulli(q1 * zz))
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_enum.py

示例12: guide

 def guide(subsample):
     """Guide sampling ``z`` under nested ``"particles"`` / ``"data"`` iaranges.

     ``data``, ``num_particles``, ``subsample_size`` and ``Normal`` are read
     from the enclosing scope.

     :param subsample: forwarded to the ``"data"`` iarange as its fourth
         positional argument.
     """
     loc = pyro.param("loc", lambda: torch.zeros(len(data), requires_grad=True))
     scale = pyro.param("scale", lambda: torch.tensor([1.0], requires_grad=True))
     # "particles" is entered first (outer), "data" second (inner).
     with pyro.iarange("particles", num_particles), \
             pyro.iarange("data", len(data), subsample_size, subsample) as ind:
         picked = loc[ind]
         # Append a trailing dim and broadcast it across the particles.
         loc_ind = picked.unsqueeze(-1).expand(-1, num_particles)
         pyro.sample("z", Normal(loc_ind, scale))
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_gradient.py

示例13: guide_step

    def guide_step(self, t, n, prev, inputs):
        """Run one step of the guide and return the new ``GuideState``.

        Samples, in order, the sites ``z_pres_<t>``, ``z_where_<t>`` and
        ``z_what_<t>``.

        :param t: step index, used to name the sample sites.
        :param n: not used in this method.
        :param prev: previous state; its ``z_where``, ``z_what``, ``z_pres``,
            ``h`` and ``c`` attributes are read here.
        :param inputs: mapping read at the keys ``'embed'`` and ``'raw'``.
        :return: ``GuideState`` holding the new recurrent state, baseline
            state and the three sampled values.
        """
        # Advance the recurrent network on the embedding plus previous latents.
        prev_latents = (inputs['embed'], prev.z_where, prev.z_what, prev.z_pres)
        rnn_input = torch.cat(prev_latents, 1)
        h, c = self.rnn(rnn_input, (prev.h, prev.c))
        z_pres_p, z_where_loc, z_where_scale = self.predict(h)

        # Compute baseline estimates for discrete choice z_pres.
        bl_value, bl_h, bl_c = self.baseline_step(prev, inputs)

        # Sample presence.
        pres_dist = dist.Bernoulli(z_pres_p * prev.z_pres).independent(1)
        pres_infer = dict(baseline=dict(baseline_value=bl_value.squeeze(-1)))
        z_pres = pyro.sample('z_pres_{}'.format(t), pres_dist, infer=pres_infer)

        # With masking enabled the sampled presence is used as the mask;
        # otherwise a constant 1.0 is used.
        if self.use_masking:
            sample_mask = z_pres
        else:
            sample_mask = torch.tensor(1.0)

        where_loc = z_where_loc + self.z_where_loc_prior
        where_scale = z_where_scale * self.z_where_scale_prior
        where_dist = dist.Normal(where_loc, where_scale).mask(sample_mask).independent(1)
        z_where = pyro.sample('z_where_{}'.format(t), where_dist)

        # Figure 2 of [1] shows x_att depending on z_where and h,
        # rather than z_where and x as here, but I think this is
        # correct.
        x_att = image_to_window(z_where, self.window_size, self.x_size, inputs['raw'])

        # Encode attention windows.
        z_what_loc, z_what_scale = self.encode(x_att)

        what_dist = dist.Normal(z_what_loc, z_what_scale).mask(sample_mask).independent(1)
        z_what = pyro.sample('z_what_{}'.format(t), what_dist)

        return GuideState(h=h, c=c, bl_h=bl_h, bl_c=bl_c,
                          z_pres=z_pres, z_where=z_where, z_what=z_what)
开发者ID:lewisKit,项目名称:pyro,代码行数:35,代码来源:air.py

示例14: model

 def model(self, data):
     """Chain of Normal sites ending in an observed Normal site.

     Starting from ``self.loc_0``, samples ``loc_1`` ... ``loc_<chain_len>``,
     each centred on the previous draw, then observes ``data`` at the site
     ``obs`` centred on the last draw. ``self.lambda_prec`` is passed as the
     scale of every Normal.

     :param data: observed value for the ``obs`` site.
     """
     current = self.loc_0
     scale = self.lambda_prec
     for step in range(self.chain_len):
         site = 'loc_{}'.format(step + 1)  # sites are numbered from 1
         current = pyro.sample(site, dist.Normal(loc=current, scale=scale))
     pyro.sample('obs', dist.Normal(current, scale), obs=data)
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_hmc.py

示例15: guide

 def guide():
     """Sample one Bernoulli site per (i, j) pair from two nested iranges.

     ``subsample_size`` is read from the enclosing scope. Each site is named
     ``x_<i>_<j>`` and shares the single parameter ``p``.
     """
     p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
     # Both iranges are created up front, before either is iterated.
     outer_irange = pyro.irange("irange_0", 3, subsample_size)
     inner_irange = pyro.irange("irange_1", 3, subsample_size)
     # NOTE(review): despite the variable names, ``inner_irange`` drives the
     # outer loop and ``outer_irange`` the inner loop -- presumably
     # intentional for the test this snippet comes from; confirm.
     for j in inner_irange:
         for i in outer_irange:
             pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(p))
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_valid_models.py


注:本文中的pyro.sample函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。