当前位置: 首页>>代码示例>>Python>>正文


Python pyro.infer方法代码示例

本文整理汇总了Python中pyro.infer方法的典型用法代码示例。如果您正苦于以下问题：Python pyro.infer方法的具体用法？Python pyro.infer怎么用？Python pyro.infer使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块pyro的用法示例。


在下文中一共展示了pyro.infer方法的4个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: pyro_noncentered_schools

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import infer [as 别名]
def pyro_noncentered_schools(data, draws, chains):
    """Fit the non-centered eight schools model with Pyro's NUTS sampler.

    :param data: mapping with keys ``"J"``, ``"sigma"`` and ``"y"``; ``"sigma"``
        and ``"y"`` must be NumPy arrays (they go through ``torch.from_numpy``).
    :param draws: number of posterior samples, also used as the number of
        warmup steps.
    :param chains: number of MCMC chains.
    :return: the fitted ``pyro.infer.MCMC`` object, with its sampler and
        potential function cleared so that it can be pickled.
    """
    import torch
    from pyro.infer import MCMC, NUTS

    def _as_float_tensor(key):
        # NumPy array -> float32 torch tensor
        return torch.from_numpy(data[key]).float()

    y = _as_float_tensor("y")
    sigma = _as_float_tensor("sigma")

    kernel = NUTS(
        _pyro_noncentered_model,
        jit_compile=True,
        ignore_jit_warnings=True,
    )
    posterior = MCMC(
        kernel,
        num_samples=draws,
        warmup_steps=draws,
        num_chains=chains,
    )
    posterior.run(data["J"], sigma, y)

    # Clear these attributes so the returned posterior can be pickled.
    posterior.sampler = None
    posterior.kernel.potential_fn = None
    return posterior


# pylint:disable=no-member,no-value-for-parameter,invalid-name 
开发者ID:arviz-devs,项目名称:arviz,代码行数:21,代码来源:helpers.py

示例2: numpyro_schools_model

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import infer [as 别名]
def numpyro_schools_model(data, draws, chains):
    """Non-centered eight schools implementation in NumPyro.

    Runs NUTS on ``_numpyro_noncentered_model`` with ``draws`` warmup steps and
    ``draws`` posterior samples per chain. The chains run sequentially, and the
    sampler is seeded with the fixed key ``PRNGKey(0)``. The ``num_steps`` and
    ``energy`` sampler statistics are recorded as extra fields.

    :param data: keyword arguments forwarded to the model via ``mcmc.run(**data)``.
    :param draws: number of warmup steps and of posterior samples.
    :param chains: number of MCMC chains.
    :return: the fitted ``numpyro.infer.MCMC`` object, with internal functions
        and its cache cleared so that it can be pickled.
    """
    from jax.random import PRNGKey
    from numpyro.infer import MCMC, NUTS

    mcmc = MCMC(
        NUTS(_numpyro_noncentered_model),
        num_warmup=draws,
        num_samples=draws,
        num_chains=chains,
        chain_method="sequential",
    )
    mcmc.run(PRNGKey(0), extra_fields=("num_steps", "energy"), **data)

    # This block lets the posterior be pickled
    mcmc.sampler._sample_fn = None  # pylint: disable=protected-access
    mcmc.sampler._init_fn = None  # pylint: disable=protected-access
    mcmc.sampler._postprocess_fn = None  # pylint: disable=protected-access
    mcmc.sampler._potential_fn = None  # pylint: disable=protected-access
    mcmc._cache = {}  # pylint: disable=protected-access
    return mcmc
开发者ID:arviz-devs,项目名称:arviz,代码行数:23,代码来源:helpers.py

示例3: guide

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import infer [as 别名]
def guide(self, xs, ys=None):
        """
        Variational guide over the digit label y and the style latent z.

        The guide corresponds to the following:
        q(y|x) = categorical(alpha(x))              # infer digit from an image
        q(z|x,y) = normal(mu(x,y),sigma(x,y))       # infer handwriting style from an image and the digit
        mu, sigma are given by a neural network `encoder_z`
        alpha is given by a neural network `encoder_y`

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s); when None the
                   labels are sampled (and scored) from q(y|x)
        :return: None
        """
        # NOTE(review): `pyro.iarange` and `Distribution.reshape(extra_event_dims=...)`
        # are an old Pyro API (later superseded by `pyro.plate` and `.to_event`);
        # confirm the pinned Pyro version before upgrading these calls.

        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.iarange("independent"):
            # if the class label (the digit) is not supervised, sample
            # (and score) the digit with the variational distribution
            # q(y|x) = categorical(alpha(x))
            if ys is None:
                alpha = self.encoder_y.forward(xs)
                ys = pyro.sample("y", dist.OneHotCategorical(alpha))

            # sample (and score) the latent handwriting-style with the variational
            # distribution q(z|x,y) = normal(mu(x,y),sigma(x,y)).
            # Only the sample site matters to Pyro; the drawn value is not used here.
            mu, sigma = self.encoder_z.forward(xs, ys)
            pyro.sample("z", dist.Normal(mu, sigma).reshape(extra_event_dims=1))
开发者ID:jinserk,项目名称:pytorch-asr,代码行数:28,代码来源:model.py

示例4: train_gp

# 需要导入模块: import pyro [as 别名]
# 或者: from pyro import infer [as 别名]
def train_gp(args, dataset, gp_class):
	"""Build, train with SVI and save a sparse variational GP for ``gp_class``.

	When ``gp_class.name == 'GpOdoFog'`` the GP has 6 outputs and its kernel
	inputs are warped by an ``FNET`` network. Otherwise it has 9 outputs and
	uses an ``HNET`` network. Training runs for ``args.epochs`` epochs with
	ClippedAdam and Trace_ELBO. At epoch 10 the GP jitter is set to 1e-4. The
	GP is then saved together with its warping network via ``save_gp``.

	:param args: run configuration. Reads ``nclt``, ``kernel_dim``,
	             ``num_inducing_point``, ``lr``, ``lr_decay`` and ``epochs``,
	             and writes ``args.mate``.
	:param dataset: provides ``get_train_data``/``get_test_data``,
	                ``num_data`` and ``name``.
	:param gp_class: GP wrapper class; its ``name`` selects the configuration.
	"""
	is_fog = gp_class.name == 'GpOdoFog'

	# The data is loaded only to get correctly-shaped inputs; targets are unused.
	if args.nclt:
		u, _ = dataset.get_train_data(0, gp_class.name)
	else:
		u, _ = dataset.get_test_data(1, gp_class.name)

	if is_fog:
		net = FNET(args, u.shape[2], args.kernel_dim)
		gp_model = _build_sparse_gp(args, dataset, gp_class, u, net,
			net_name="FNET", lik_name='lik_f', out_dim=6, jitter=1e-3)
	else:
		net = HNET(args, u.shape[2], args.kernel_dim)
		gp_model = _build_sparse_gp(args, dataset, gp_class, u, net,
			net_name="HNET", lik_name='lik_h', out_dim=9, jitter=1e-4)

	gp_instance = gp_class(args, gp_model, dataset)
	args.mate = preprocessing(args, dataset, gp_instance)

	optimizer = optim.ClippedAdam({"lr": args.lr, "lrd": args.lr_decay})
	svi = infer.SVI(gp_instance.model, gp_instance.guide, optimizer, infer.Trace_ELBO())

	print("Start of training " + dataset.name + ", " + gp_class.name)
	for epoch in range(1, args.epochs + 1):
		train_loop(dataset, gp_instance, svi, epoch)
		if epoch == 10:
			# After 10 epochs the jitter is set to 1e-4. This only changes the
			# GpOdoFog GP, which starts at 1e-3; the other GP already uses 1e-4.
			if is_fog:
				gp_instance.gp_f.jitter = 1e-4
			else:
				gp_instance.gp_h.jitter = 1e-4

	save_gp(args, gp_instance, net)


def _build_sparse_gp(args, dataset, gp_class, u, net, *, net_name, lik_name, out_dim, jitter):
	"""Return a VariationalSparseGP whose Matern52 kernel inputs are warped by ``net``."""
	def warp_fn(x):
		# pyro.module registers ``net``'s parameters with Pyro under ``net_name``.
		return pyro.module(net_name, net)(x)

	lik = gp.likelihoods.Gaussian(name=lik_name, variance=0.1*torch.ones(out_dim, 1))
	kernel = gp.kernels.Matern52(input_dim=args.kernel_dim,
		lengthscale=torch.ones(args.kernel_dim)).warp(iwarping_fn=warp_fn)
	# Inducing points are an evenly strided subset of the inputs. Clamp the
	# stride to at least 1: with fewer samples than requested inducing points
	# the stride would be 0 and torch.arange would raise.
	step = max(1, int(u.shape[0] / args.num_inducing_point))
	Xu = u[torch.arange(0, u.shape[0], step=step).long()]
	return gp.models.VariationalSparseGP(u, torch.zeros(out_dim, u.shape[0]), kernel, Xu,
		num_data=dataset.num_data, likelihood=lik, mean_function=None,
		name=gp_class.name, whiten=True, jitter=jitter)
开发者ID:CAOR-MINES-ParisTech,项目名称:lwoi,代码行数:48,代码来源:train.py


注：本文中的pyro.infer方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。