当前位置: 首页>>代码示例>>Python>>正文


Python pyro.sample方法代码示例

本文整理汇总了Python中pyro.sample方法的典型用法代码示例。如果您想了解Python pyro.sample方法的具体用法、调用方式或使用示例,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块pyro的其他用法示例。


在下文中一共展示了pyro.sample方法的15个代码示例,这些例子默认根据受欢迎程度排序。其中部分示例(如示例4、5、10、15)并不直接调用pyro.sample,而是展示了与之配合使用的相关方法。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: sample_latent

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def sample_latent(self, input, input_latent_mu, input_latent_sigma, pred_latent_mu,
                  pred_latent_sigma, initial_pose_mu, initial_pose_sigma, sample=True):
    '''
    Sample the latent variables (pose and content) for a video.

    The pose is obtained by accumulating the transition variables beta
    (self.get_transitions / self.accumulate_pose), adding a sampled initial
    pose and passing the sum through self.constrain_pose. Objects are then
    cropped from the input frames at their poses (self.get_objects), encoded
    with self.object_encoder and turned into a content latent by
    self.sample_content.
    param sample: forwarded to every sampling helper (True when called from the guide).
    Return latent: a defaultdict (missing keys -> None) holding 'pose' and 'content'.
    '''
    # --- pose ---------------------------------------------------------------
    transitions = self.get_transitions(input_latent_mu, input_latent_sigma,
                                       pred_latent_mu, pred_latent_sigma, sample)
    pose = self.accumulate_pose(transitions)
    start_pose = self.pyro_sample('initial_pose', dist.Normal, initial_pose_mu,
                                  initial_pose_sigma, sample)
    # The initial pose gets a singleton frame axis so it broadcasts over all frames.
    # (In-place add kept on purpose: same tensor semantics as before.)
    pose += start_pose.view(-1, 1, self.n_components, self.pose_latent_size)
    pose = self.constrain_pose(pose)

    # --- content ------------------------------------------------------------
    # Only the first n_frames_input frames are observed, so crop objects there.
    observed_pose = pose[:, :self.n_frames_input, :, :]
    objects = self.get_objects(input, observed_pose)
    encoded = self.object_encoder(objects)
    content = self.sample_content(encoded, sample)

    latent = defaultdict(lambda: None)
    latent.update({'pose': pose, 'content': content})
    return latent
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:27,代码来源:DDPAE.py

示例2: get_transitions

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def get_transitions(self, input_latent_mu, input_latent_sigma, pred_latent_mu,
                    pred_latent_sigma, sample=True):
    '''
    Sample the transition variables beta for the input and the prediction frames.

    Both groups are sampled from Normals via self.pyro_sample, reshaped to
    (-1, n_frames, n_components, pose_latent_size) and concatenated along the
    frame axis (dim 1): input frames first, then prediction frames.
    '''
    # Sample order (input first, prediction second) is kept as in the original.
    # NOTE(review): the sampled shapes are presumably
    # (batch*n_frames_input*n_components) x pose_latent_size for the input and
    # (batch*n_frames_output) x n_components x pose_latent_size for the prediction.
    sampled_input = self.pyro_sample('input_beta', dist.Normal, input_latent_mu,
                                     input_latent_sigma, sample)
    input_part = sampled_input.view(-1, self.n_frames_input, self.n_components,
                                    self.pose_latent_size)

    sampled_pred = self.pyro_sample('pred_beta', dist.Normal, pred_latent_mu,
                                    pred_latent_sigma, sample)
    pred_part = sampled_pred.view(-1, self.n_frames_output, self.n_components,
                                  self.pose_latent_size)

    return torch.cat([input_part, pred_part], dim=1)
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:20,代码来源:DDPAE.py

示例3: sample_content

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def sample_content(self, content, sample):
    '''
    Summarise each component's per-frame content codes with content_lstm and sample.

    `content` is reshaped to (-1, n_frames_input, total_components, content_latent_size).
    Each component's frame sequence goes through self.content_lstm. The outputs
    are stacked and split into a mean (first content_latent_size columns) and a
    scale (softplus of the remaining columns), and the content latent is drawn
    with self.pyro_sample.
    '''
    per_component = content.view(-1, self.n_frames_input, self.total_components,
                                 self.content_latent_size)
    # Each LSTM summary: batch_size x 1 x (content_latent_size * 2)
    summaries = [self.content_lstm(per_component[:, :, idx, :]).unsqueeze(1)
                 for idx in range(self.total_components)]
    stats = torch.cat(summaries, dim=1).view(-1, self.content_latent_size * 2)

    split = self.content_latent_size
    content_mu = stats[:, :split]
    # softplus keeps the scale strictly positive.
    content_sigma = F.softplus(stats[:, split:])
    return self.pyro_sample('content', dist.Normal, content_mu, content_sigma, sample)
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:19,代码来源:DDPAE.py

示例4: encode

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def encode(self, input, sample=True):
    '''
    Encode a video and sample the latent variables for reconstruction and prediction.

    self.pose_model(input) must yield six values: the mean/sigma for the input
    transitions, for the prediction transitions and for the initial pose. They
    are passed, together with `input` and `sample`, to self.sample_latent(),
    which is where pyro.sample is called.
    param input: video of size (batch_size, n_frames_input, C, H, W)
    param sample: True if this is called by guide(), and sample with pyro.sample.
    Return latent: a dictionary {'pose': pose, 'content': content, ...}
    '''
    (in_mu, in_sigma,
     pred_mu, pred_sigma,
     pose0_mu, pose0_sigma) = self.pose_model(input)
    return self.sample_latent(input, in_mu, in_sigma, pred_mu, pred_sigma,
                              pose0_mu, pose0_sigma, sample)
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:18,代码来源:DDPAE.py

示例5: test

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def test(self, input, output):
    '''
    Reconstruct and predict a video without sampling; return the decoded frames.

    `input` and `output` are moved to the GPU and concatenated along dim 1 to
    form the ground truth. `input` is encoded with sample=False and decoded.
    The decoded frames are reshaped to the ground-truth shape and clamped to
    [0, 1]. The per-component renderings are stored in the latent dict under
    'components', and visuals are saved via self.save_visuals().
    Return: (decoded frames moved to CPU, latent dict)
    '''
    input = Variable(input.cuda())
    batch_size, _, _, H, W = input.size()
    output = Variable(output.cuda())
    ground_truth = torch.cat([input, output], dim=1)

    latent = self.encode(input, sample=False)
    frames, components = self.decode(latent, batch_size)
    frames = frames.view(*ground_truth.size())
    components = components.view(batch_size, self.n_frames_total, self.total_components,
                                 self.n_channels, H, W)
    latent['components'] = components
    frames = frames.clamp(0, 1)

    self.save_visuals(ground_truth, frames, components, latent)
    return frames.cpu(), latent
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:21,代码来源:DDPAE.py

示例6: model_classify

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def model_classify(self, xs, ys=None):
    """
    Auxiliary supervised term of the semi-supervised VAE model.

    Implements the extra classification loss from Kingma et al.,
    "Semi-Supervised Learning with Deep Generative Models" (NIPS 2014).
    When labels ``ys`` are given, they are observed at site "y_aux" under a
    OneHotCategorical built from ``self.encoder_y``'s output, with that site's
    log-likelihood scaled by ``self.aux_loss_multiplier``. With ``ys=None``
    nothing is observed.
    """
    # Register every PyTorch (sub)module of this object with Pyro.
    pyro.module("ss_vae", self)

    # Items in the batch are conditionally independent.
    # NOTE(review): pyro.iarange is the legacy name of pyro.plate; switch once the
    # pinned Pyro version allows it.
    with pyro.iarange("independent"):
        if ys is None:
            return
        alpha = self.encoder_y.forward(xs)
        aux_dist = dist.OneHotCategorical(alpha)
        with pyro.poutine.scale(None, self.aux_loss_multiplier):
            pyro.sample("y_aux", aux_dist, obs=ys)
开发者ID:jinserk,项目名称:pytorch-asr,代码行数:19,代码来源:model.py

示例7: model

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def model(self, x, y):
    """
    Pyro model for the clustered multitask GP.

    Registers the GP with Pyro, obtains the prior over the latent functions
    from ``self.pyro_model``, samples a one-hot cluster assignment per task
    (uniform prior: all-zero logits of shape (num_tasks, num_functions)), and
    observes ``y`` under a Normal whose mean picks each task's assigned function
    and whose scale is ``sqrt(self.noise)``. The observation log-likelihood is
    scaled by ``self.num_data / y.size(-2)``.
    Returns the value of the ``.y`` sample site.
    """
    pyro.module(self.name_prefix + ".gp", self)

    # Prior over f: pyro_model samples u from p(u) and returns p(f).
    function_dist = self.pyro_model(x, name_prefix=self.name_prefix)

    # Uniform prior over each task's cluster (function) assignment.
    uniform_logits = torch.zeros(self.num_tasks, self.num_functions)
    assignment_prior = pyro.distributions.OneHotCategorical(logits=uniform_logits).to_event(1)
    cluster_assignment_samples = pyro.sample(self.name_prefix + ".cluster_logits", assignment_prior)

    num_outputs = function_dist.batch_shape[-1]
    with pyro.plate(self.name_prefix + ".output_values_plate", num_outputs, dim=-1):
        function_samples = pyro.sample(self.name_prefix + ".f", function_dist)
        # The one-hot assignment selects one function per task (sum over the last dim).
        task_means = (function_samples.unsqueeze(-2) * cluster_assignment_samples).sum(-1)
        obs_dist = pyro.distributions.Normal(loc=task_means, scale=self.noise.sqrt()).to_event(1)
        # Rescale the minibatch likelihood to the full dataset size.
        with pyro.poutine.scale(scale=(self.num_data / y.size(-2))):
            return pyro.sample(self.name_prefix + ".y", obs_dist, obs=y)
开发者ID:cornellius-gp,项目名称:gpytorch,代码行数:24,代码来源:test_pyro_integration.py

示例8: pyro_model

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def pyro_model(self, input, beta=1.0, name_prefix=""):
    """
    Pyro model for this GP: prior over the inducing values plus extra factor terms.

    Samples the inducing values ``u`` from
    ``self.variational_strategy.prior_distribution``, with the site's
    log-probability scaled by ``beta``. Adds a ``pyro.factor`` site for the
    module's GPyTorch priors (each prior's log-density summed and divided by
    ``self.num_data``) and one for its added loss terms. Returns a distribution
    over ``f`` built from the GP prior ``self(input, prior=True)``.

    :param input: inputs at which the GP prior over ``f`` is evaluated.
    :param float beta: scale applied to the ``u`` site's log-probability.
    :param str name_prefix: prefix of every Pyro site created here
        (``.u``, ``.log_prior``, ``.added_loss``).
    :return: a ``Normal`` with the prior's mean and stddev. Its rightmost
        ``len(event_shape) - 1`` dims are event dims, so the first event dim of
        the prior stays a batch dim. The Normal is wrapped with ``mask(False)``,
        so sampling it adds no log-probability term.
    """
    # Inducing values p(u). The site name uses the ``name_prefix`` argument, like
    # the factor sites below, instead of ``self.name_prefix``. This keeps all
    # sites under one prefix and removes the dependency on that attribute.
    with pyro.poutine.scale(scale=beta):
        u_samples = pyro.sample(name_prefix + ".u", self.variational_strategy.prior_distribution)

    # Include term for GPyTorch priors
    log_prior = torch.tensor(0.0, dtype=u_samples.dtype, device=u_samples.device)
    for _, prior, closure, _ in self.named_priors():
        log_prior.add_(prior.log_prob(closure()).sum().div(self.num_data))
    pyro.factor(name_prefix + ".log_prior", log_prior)

    # Include factor for added loss terms
    added_loss = torch.tensor(0.0, dtype=u_samples.dtype, device=u_samples.device)
    for added_loss_term in self.added_loss_terms():
        added_loss.add_(added_loss_term.loss())
    pyro.factor(name_prefix + ".added_loss", added_loss)

    # Draw samples from p(f)
    function_dist = self(input, prior=True)
    function_dist = pyro.distributions.Normal(loc=function_dist.mean, scale=function_dist.stddev).to_event(
        len(function_dist.event_shape) - 1
    )
    return function_dist.mask(False)
开发者ID:cornellius-gp,项目名称:gpytorch,代码行数:25,代码来源:_pyro_mixin.py

示例9: _draw_likelihood_samples

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def _draw_likelihood_samples(self, function_dist, *args, sample_shape=None, **kwargs):
    """
    Draw samples of the latent function and return ``self.forward`` applied to them.

    In training mode, ``function_dist`` is first replaced by independent Normals
    with the same mean and standard deviation (``variance.sqrt()``); the rightmost
    ``len(event_shape) - 1`` dims are reinterpreted as event dims.

    Sampling happens inside a Pyro plate named
    ``<name_prefix>.num_particles_vectorized`` of size
    ``settings.num_likelihood_samples.value()``, placed at dim
    ``-(max_plate_nesting + 1)``:

    * ``sample_shape is None``: sample via ``pyro.sample`` at site
      ``self.name_prefix`` (masked, so no log-probability is added), then
      squeeze one singleton dim.
    * otherwise: drop the trailing ``len(function_dist.batch_shape)`` entries of
      ``sample_shape`` and draw a reparameterized sample with ``rsample``.

    In eval mode one more singleton dim is squeezed before calling ``forward``.
    """
    if self.training:
        num_event_dims = len(function_dist.event_shape)
        function_dist = base_distributions.Normal(function_dist.mean, function_dist.variance.sqrt())
        function_dist = base_distributions.Independent(function_dist, num_event_dims - 1)

    plate_name = self.name_prefix + ".num_particles_vectorized"
    num_samples = settings.num_likelihood_samples.value()
    # The particle plate sits one dim to the left of all existing plate/batch dims.
    max_plate_nesting = max(self.max_plate_nesting, len(function_dist.batch_shape))
    with pyro.plate(plate_name, size=num_samples, dim=(-max_plate_nesting - 1)):
        if sample_shape is None:
            function_samples = pyro.sample(self.name_prefix, function_dist.mask(False))
            # Deal with the fact that we're not assuming conditional independence over data points here
            function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1)
        else:
            # Strip the trailing batch dims. Slicing with len(...) - k instead of
            # -k keeps the whole shape when batch_shape is empty ([:-0] is [:0]).
            num_batch_dims = len(function_dist.batch_shape)
            sample_shape = sample_shape[: len(sample_shape) - num_batch_dims]
            # Distributions are not callable; draw a reparameterized sample.
            function_samples = function_dist.rsample(sample_shape)

        if not self.training:
            function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1)
        return self.forward(function_samples, *args, **kwargs)
开发者ID:cornellius-gp,项目名称:gpytorch,代码行数:23,代码来源:likelihood.py

示例10: forward

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def forward(self, function_samples, *args, data=None, **kwargs):
    r"""
    Computes the conditional distribution :math:`p(\mathbf y \mid
    \mathbf f, \ldots)` that defines the likelihood.

    Abstract: subclasses must override this; the base implementation always
    raises :class:`NotImplementedError`.

    :param torch.Tensor function_samples: Samples from the function (:math:`\mathbf f`)
    :param dict data: (Optional, Pyro integration only) Additional
        variables (:math:`\ldots`) that the likelihood needs to condition
        on. The keys of the dictionary will correspond to Pyro sample sites
        in the likelihood's model/guide. Defaults to ``None`` rather than
        ``{}`` so that no mutable default is shared between calls.
    :param args: Additional args
    :param kwargs: Additional kwargs
    :return: Distribution object (with same shape as :attr:`function_samples`)
    :rtype: :obj:`Distribution`
    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
开发者ID:cornellius-gp,项目名称:gpytorch,代码行数:18,代码来源:likelihood.py

示例11: pyro_model

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def pyro_model(self, function_dist, target, *args, **kwargs):
    r"""
    (For Pyro integration only).

    Part of the model function for the likelihood. Inside a data plate
    (``<name_prefix>.data_plate``, dim -1), it samples the latent function values
    from ``function_dist`` at site ``<name_prefix>.f``, builds the output
    distribution by calling the likelihood on those samples, and returns
    ``self.sample_target(output_dist, target)``.
    This should be re-defined if the likelihood contains any latent variables that need to be inferred.

    :param ~gpytorch.distributions.MultivariateNormal function_dist: Distribution of latent function
        :math:`p(\mathbf f)`.
    :param torch.Tensor target: Observed :math:`\mathbf y`.
    :param args: Additional args (for :meth:`~forward`).
    :param kwargs: Additional kwargs (for :meth:`~forward`).
    """
    data_plate = pyro.plate(self.name_prefix + ".data_plate", dim=-1)
    with data_plate:
        f_samples = pyro.sample(self.name_prefix + ".f", function_dist)
        likelihood_dist = self(f_samples, *args, **kwargs)
        return self.sample_target(likelihood_dist, target)
开发者ID:cornellius-gp,项目名称:gpytorch,代码行数:20,代码来源:likelihood.py

示例12: _pyro_sample_from_prior

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def _pyro_sample_from_prior(module, memo=None, prefix=""):
    try:
        import pyro
    except ImportError:
        raise RuntimeError("Cannot call pyro_sample_from_prior without pyro installed!")
    if memo is None:
        memo = set()
    if hasattr(module, "_priors"):
        for prior_name, (prior, closure, setting_closure) in module._priors.items():
            if prior is not None and prior not in memo:
                if setting_closure is None:
                    raise RuntimeError(
                        "Cannot use Pyro for sampling without a setting_closure for each prior,"
                        f" but the following prior had none: {prior_name}, {prior}."
                    )
                memo.add(prior)
                prior = prior.expand(closure().shape)
                value = pyro.sample(prefix + ("." if prefix else "") + prior_name, prior)
                setting_closure(value)

    for mname, module_ in module.named_children():
        submodule_prefix = prefix + ("." if prefix else "") + mname
        _pyro_sample_from_prior(module=module_, memo=memo, prefix=submodule_prefix) 
开发者ID:cornellius-gp,项目名称:gpytorch,代码行数:25,代码来源:module.py

示例13: model

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def model(x, y):
    """
    Pyro model: a Bayesian version of the global ``det_net`` classifier.

    Each weight and bias of layers fc1, fc2 and fc3 gets an independent
    Normal(0, 1) prior of the same shape. ``pyro.random_module`` lifts
    ``det_net`` so that those parameters are sampled. The logits of the sampled
    network parameterize a Categorical likelihood observed at ``y``.
    Returns the value of the "obs" sample site.
    """
    def standard_normal_like(param):
        # Normal(0, 1) with the same shape (and dtype/device) as ``param``.
        return pyro.distributions.Normal(loc=torch.zeros_like(param), scale=torch.ones_like(param))

    params_by_name = {
        "fc1.weight": det_net.fc1.weight, "fc1.bias": det_net.fc1.bias,
        "fc2.weight": det_net.fc2.weight, "fc2.bias": det_net.fc2.bias,
        "fc3.weight": det_net.fc3.weight, "fc3.bias": det_net.fc3.bias,
    }
    priors = {name: standard_normal_like(param) for name, param in params_by_name.items()}

    sampled_net = pyro.random_module("module", det_net, priors)()
    logits = sampled_net(x)
    return pyro.sample("obs", pyro.distributions.Categorical(logits=logits), obs=y)
开发者ID:fregu856,项目名称:evaluating_bdl,代码行数:23,代码来源:model_pyro.py

示例14: pyro_guide

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def pyro_guide(self, function_dist, target):
    """
    Guide: sample the variational cluster assignment, then defer to the parent guide.

    The assignment is drawn at site ``<name_prefix>.cluster_logits`` from
    ``self._cluster_dist(self.variational_cluster_logits)``. The result of the
    parent class's ``pyro_guide(function_dist, target)`` is returned.
    """
    cluster_site = self.name_prefix + ".cluster_logits"
    cluster_posterior = self._cluster_dist(self.variational_cluster_logits)
    pyro.sample(cluster_site, cluster_posterior)
    return super().pyro_guide(function_dist, target)
开发者ID:cornellius-gp,项目名称:gpytorch,代码行数:5,代码来源:test_pyro_integration.py

示例15: guide

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import sample [as 别名]
def guide(self, input, output):
    '''
    Posterior model (guide): register the guide networks and encode the input.

    Every (name, network) pair in self.guide_modules is registered with
    pyro.module. The input is then encoded with sampling enabled
    (self.encode(input, sample=True)). Returns None.
    param input: video of size (batch_size, n_frames_input, C, H, W).
    param output: not used.
    '''
    for module_name, network in self.guide_modules.items():
        pyro.module(module_name, network)

    self.encode(input, sample=True)
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:13,代码来源:DDPAE.py


注:本文中的pyro.sample方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。