本文整理汇总了 Python 中 pyro.infer 模块的典型用法代码示例。如果您正苦于以下问题：Python pyro.infer 的具体用法？Python pyro.infer 怎么用？Python pyro.infer 使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在包 pyro 的用法示例。
在下文中一共展示了 pyro.infer 模块的 4 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: pyro_noncentered_schools
# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import infer [as 别名]
def pyro_noncentered_schools(data, draws, chains):
    """Run NUTS on the non-centered eight schools model with Pyro.

    :param data: mapping with the entries ``"J"``, ``"y"`` and ``"sigma"``;
        ``"y"`` and ``"sigma"`` are converted with ``torch.from_numpy`` and
        cast to float before being handed to the model.
    :param draws: used both as the number of kept samples and as the number
        of warmup steps.
    :param chains: number of chains passed to ``MCMC``.
    :return: the ``MCMC`` object after ``run`` has completed, with its
        ``sampler`` and the kernel's ``potential_fn`` reset to ``None``.
    """
    import torch
    from pyro.infer import MCMC, NUTS

    observed, noise = (
        torch.from_numpy(data[key]).float() for key in ("y", "sigma")
    )

    mcmc = MCMC(
        NUTS(_pyro_noncentered_model, jit_compile=True, ignore_jit_warnings=True),
        num_samples=draws,
        warmup_steps=draws,
        num_chains=chains,
    )
    mcmc.run(data["J"], noise, observed)

    # Clear these two references so that the returned object can be pickled.
    mcmc.sampler = None
    mcmc.kernel.potential_fn = None
    return mcmc
# pylint:disable=no-member,no-value-for-parameter,invalid-name
示例2: numpyro_schools_model
# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import infer [as 别名]
def numpyro_schools_model(data, draws, chains):
    """Run NUTS on the eight schools model with NumPyro.

    NOTE(review): the kernel is built from ``_numpyro_noncentered_model``;
    the function name itself does not say which parameterization is meant,
    so confirm that the non-centered model is the intended one.

    :param data: mapping that is expanded into keyword arguments of
        ``mcmc.run`` (and therefore forwarded to the model).
    :param draws: used both as ``num_warmup`` and as ``num_samples``.
    :param chains: number of chains, run with ``chain_method="sequential"``.
    :return: the ``MCMC`` object after sampling, with its cached callables
        cleared.
    """
    from jax.random import PRNGKey
    from numpyro.infer import MCMC, NUTS

    kernel = NUTS(_numpyro_noncentered_model)
    mcmc = MCMC(
        kernel,
        num_warmup=draws,
        num_samples=draws,
        num_chains=chains,
        chain_method="sequential",
    )

    rng_key = PRNGKey(0)  # fixed seed
    mcmc.run(rng_key, extra_fields=("num_steps", "energy"), **data)

    # Clear the sampler's function attributes and the cache so that the
    # returned object can be pickled.
    for private_attr in ("_sample_fn", "_init_fn", "_postprocess_fn", "_potential_fn"):
        setattr(mcmc.sampler, private_attr, None)
    mcmc._cache = {}  # pylint: disable=protected-access
    return mcmc
示例3: guide
# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import infer [as 别名]
def guide(self, xs, ys=None):
    """
    Variational guide, factorised as:
    q(y|x) = categorical(alpha(x))          # digit inferred from the image
    q(z|x,y) = normal(mu(x,y),sigma(x,y))   # handwriting style inferred from
                                            # the image and the digit
    where alpha comes from the network `encoder_y` and mu, sigma come from
    the network `encoder_z`.

    :param xs: a batch of scaled vectors of pixels from an image
    :param ys: (optional) a batch of the class labels, i.e. the digit
               corresponding to the image(s); when omitted the labels are
               sampled from q(y|x)
    :return: None
    """
    # the entries of the batch are declared conditionally independent
    with pyro.iarange("independent"):
        if ys is None:
            # unsupervised case: draw (and score) the digit from
            # q(y|x) = categorical(alpha(x))
            digit_probs = self.encoder_y.forward(xs)
            ys = pyro.sample("y", dist.OneHotCategorical(digit_probs))

        # draw (and score) the latent handwriting style from
        # q(z|x,y) = normal(mu(x,y),sigma(x,y))
        loc, scale = self.encoder_z.forward(xs, ys)
        style_dist = dist.Normal(loc, scale).reshape(extra_event_dims=1)
        pyro.sample("z", style_dist)
示例4: train_gp
# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import infer [as 别名]
def train_gp(args, dataset, gp_class):
    """Build, train with SVI, and save the sparse variational GP for ``gp_class``.

    A network (``FNET`` when ``gp_class.name == 'GpOdoFog'``, ``HNET``
    otherwise) warps the input of a Matern-5/2 kernel, which is wrapped in a
    ``VariationalSparseGP`` with a Gaussian likelihood. The model and guide of
    ``gp_class(args, gp_model, dataset)`` are then optimised with
    ``ClippedAdam`` and ``Trace_ELBO`` for ``args.epochs`` epochs, and the
    result is passed to ``save_gp`` together with the network.

    :param args: options object; ``nclt``, ``kernel_dim``,
        ``num_inducing_point``, ``lr``, ``lr_decay`` and ``epochs`` are read,
        and ``mate`` is overwritten with the result of ``preprocessing``.
    :param dataset: object providing ``get_train_data``, ``get_test_data``,
        ``num_data`` and ``name``.
    :param gp_class: callable with a ``name`` attribute, invoked as
        ``gp_class(args, gp_model, dataset)``.
    :return: None
    """
    # This is only to have a correct dimension: just the shape of ``u`` and a
    # subset of its rows are used below, the second returned value is ignored.
    if args.nclt:
        u, _ = dataset.get_train_data(0, gp_class.name)
    else:
        u, _ = dataset.get_test_data(1, gp_class.name)

    # The two variants differ only in the network, the names, the output
    # dimension and the initial jitter.
    is_odo_fog = gp_class.name == 'GpOdoFog'
    if is_odo_fog:
        net = FNET(args, u.shape[2], args.kernel_dim)
        module_name, lik_name, output_dim, initial_jitter = "FNET", 'lik_f', 6, 1e-3
    else:
        net = HNET(args, u.shape[2], args.kernel_dim)
        module_name, lik_name, output_dim, initial_jitter = "HNET", 'lik_h', 9, 1e-4

    def warping_fn(x):
        # Go through pyro.module on every call, exactly as the kernel warp
        # function did before the two branches were merged.
        return pyro.module(module_name, net)(x)

    # NOTE: a MultiVariateGaussian likelihood (name=lik_name, dim=output_dim)
    # could be used here if lower_triangular_constraint is implemented.
    lik = gp.likelihoods.Gaussian(name=lik_name, variance=0.1 * torch.ones(output_dim, 1))
    kernel = gp.kernels.Matern52(
        input_dim=args.kernel_dim,
        lengthscale=torch.ones(args.kernel_dim),
    ).warp(iwarping_fn=warping_fn)

    # Inducing points: evenly spaced rows of ``u``. The step is clamped to 1
    # so that asking for more inducing points than there are rows does not
    # produce a zero step, which torch.arange rejects.
    num_rows = u.shape[0]
    step = max(1, int(num_rows / args.num_inducing_point))
    Xu = u[torch.arange(0, num_rows, step=step).long()]

    gp_model = gp.models.VariationalSparseGP(
        u, torch.zeros(output_dim, num_rows), kernel, Xu,
        num_data=dataset.num_data, likelihood=lik, mean_function=None,
        name=gp_class.name, whiten=True, jitter=initial_jitter,
    )

    gp_instante = gp_class(args, gp_model, dataset)
    args.mate = preprocessing(args, dataset, gp_instante)

    optimizer = optim.ClippedAdam({"lr": args.lr, "lrd": args.lr_decay})
    svi = infer.SVI(gp_instante.model, gp_instante.guide, optimizer, infer.Trace_ELBO())

    print("Start of training " + dataset.name + ", " + gp_class.name)
    for epoch in range(1, args.epochs + 1):
        train_loop(dataset, gp_instante, svi, epoch)
        if epoch == 10:
            # After the tenth epoch the jitter of the trained GP is set to 1e-4.
            trained_gp = gp_instante.gp_f if is_odo_fog else gp_instante.gp_h
            trained_gp.jitter = 1e-4

    save_gp(args, gp_instante, net)