本文整理汇总了Python中pyro.param函数的典型用法代码示例。如果您正苦于以下问题：Python param函数的具体用法？Python param怎么用？Python param使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了pyro.param函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_dynamic_lr
def test_dynamic_lr(scheduler, num_steps):
    """Check that ``scheduler`` updates the per-parameter optimizers each epoch.

    SVI is run for two epochs of ``num_steps`` steps, with
    ``scheduler.set_epoch`` called at the start of each epoch. After the second
    epoch, the optimizer behind both guide parameters (``loc`` and ``scale``)
    must report ``lr == 0.02`` and ``initial_lr == 0.01``.
    """
    pyro.clear_param_store()

    def model():
        latent = pyro.sample('latent', Normal(torch.tensor(0.), torch.tensor(0.3)))
        return pyro.sample('obs', Normal(latent, torch.tensor(0.2)), obs=torch.tensor(0.1))

    def guide():
        loc = pyro.param('loc', torch.tensor(0.))
        scale = pyro.param('scale', torch.tensor(0.5))
        pyro.sample('latent', Normal(loc, scale))

    svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
    for epoch in range(2):
        scheduler.set_epoch(epoch)
        for _ in range(num_steps):
            svi.step()
        if epoch != 1:
            continue
        # Only the final epoch is inspected; optimizers are keyed by the
        # parameter tensor itself.
        for param in (pyro.param('loc'), pyro.param('scale')):
            group = scheduler.optim_objs[param].optimizer.state_dict()['param_groups'][0]
            assert group['lr'] == 0.02
            assert group['initial_lr'] == 0.01
示例2: guide
def guide(num_particles):
    """Vectorized guide drawing ``num_particles`` particles at once.

    Registers ``q1`` (initialized from ``pi1``) and ``q2`` (initialized from
    ``pi2``), both taken from the enclosing scope. Then draws
    ``z ~ Normal(q2, 1)`` per particle and ``y ~ Bernoulli(q1 * sigmoid(z))``.

    :param int num_particles: size of the ``"particles"`` iarange.
    """
    q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
    q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
    with pyro.iarange("particles", num_particles):
        z = pyro.sample("z", dist.Normal(q2, 1.0).expand_by([num_particles]))
        # torch.sigmoid(z) equals exp(z) / (1 + exp(z)) but, unlike the explicit
        # form, does not overflow to inf / inf = nan for large z.
        zz = torch.sigmoid(z)
        pyro.sample("y", dist.Bernoulli(q1 * zz))
示例3: guide
def guide(subsample):
    """Normal guide over ``data``, replicated across ``num_particles`` particles.

    ``loc`` holds one location per data point and ``scale`` a single shared
    scale. Both are initialized lazily through callables. ``data``,
    ``num_particles`` and ``subsample_size`` come from the enclosing scope.

    :param subsample: explicit indices forwarded to the ``"data"`` iarange.
    """
    loc = pyro.param("loc", lambda: torch.zeros(len(data), requires_grad=True))
    scale = pyro.param("scale", lambda: torch.tensor([1.0], requires_grad=True))
    with pyro.iarange("particles", num_particles), \
            pyro.iarange("data", len(data), subsample_size, subsample) as ind:
        # Pick the subsampled locations, add a trailing axis and broadcast it
        # to one column per particle.
        picked = loc[ind]
        loc_ind = picked.unsqueeze(-1).expand(-1, num_particles)
        pyro.sample("z", Normal(loc_ind, scale))
示例4: sample_ws
def sample_ws(name, width):
    """Sample the Gamma-distributed weights ``w_<name>``.

    Two raw parameters are stored as ``log_alpha_w_q_<name>`` and
    ``log_mean_w_q_<name>``. Both are initialized lazily by ``rand_tensor``
    with shape argument ``width``. Despite the ``log_`` prefix, they are mapped
    through ``self.softplus``, not ``exp``, to obtain a shape ``alpha`` and a
    mean. The Gamma rate is ``alpha / mean``, so the distribution's mean
    equals ``mean``.
    """
    raw_alpha = pyro.param("log_alpha_w_q_%s" % name,
                           lambda: rand_tensor(width, self.alpha_init, self.sigma_init))
    raw_mean = pyro.param("log_mean_w_q_%s" % name,
                          lambda: rand_tensor(width, self.mean_init, self.sigma_init))
    alpha = self.softplus(raw_alpha)
    mean = self.softplus(raw_mean)
    pyro.sample("w_%s" % name, Gamma(alpha, alpha / mean))
示例5: guide
def guide(*args, **kwargs):
    """Guide sampling every node of ``self.q_topo_sort`` as a Normal (legacy Pyro API).

    Positional and keyword arguments are accepted and ignored. ``self``,
    ``difficulty`` and ``reparameterized`` are read from the enclosing scope.
    Each node's mean is a learned constant plus ``kappa``-weighted values of
    its predecessors in ``self.q_dag``. Its scale is ``exp`` of a learned
    log-sigma. ``difficulty`` perturbs the initial parameter values.

    :returns: the sampled value of node ``'mu_latent_1'``.
    """
    latents_dict = {}
    n_nodes = len(self.q_topo_sort)
    # Presumably q_topo_sort is a topological order of q_dag, so every
    # predecessor is already in latents_dict when its children are visited.
    for i, node in enumerate(self.q_topo_sort):
        deps = self.q_dag.predecessors(node)
        # Drop the first 10 characters of the node name
        # (presumably the "mu_latent_" prefix -- TODO confirm).
        node_suffix = node[10:]
        # NOTE(review): this initial value, including the torch.randn draw, is
        # built on every call, even though pyro.param presumably uses it only
        # when the parameter is first created. Each call therefore advances the
        # global RNG.
        log_sig_node = pyro.param("log_sig_" + node_suffix,
                                  Variable(-0.5 * torch.log(self.target_lambdas[node_suffix]).data +
                                           difficulty * (torch.Tensor([-0.3]) -
                                                         0.3 * (torch.randn(1) ** 2)),
                                           requires_grad=True))
        # The parameter key here uses the full node name, not node_suffix.
        mean_function_node = pyro.param("constant_term_" + node,
                                        Variable(self.mu0.data +
                                                 torch.Tensor([difficulty * i / n_nodes]),
                                                 requires_grad=True))
        # Add a learned linear contribution from each predecessor's sample.
        for dep in deps:
            kappa_dep = pyro.param("kappa_" + node_suffix + '_' + dep[10:],
                                   Variable(torch.Tensor([0.5 + difficulty * i / n_nodes]),
                                            requires_grad=True))
            mean_function_node = mean_function_node + kappa_dep * latents_dict[dep]
        # Reparameterize when requested globally or when this node is flagged.
        node_flagged = True if self.which_nodes_reparam[i] == 1.0 else False
        repa = True if reparameterized else node_flagged
        latent_dist_node = dist.Normal(mean_function_node, torch.exp(log_sig_node),
                                       reparameterized=repa)
        # Decaying-average baseline (beta 0.96), presumably for the
        # score-function gradient of non-reparameterized nodes.
        latent_node = pyro.sample(node, latent_dist_node,
                                  baseline=dict(use_decaying_avg_baseline=True,
                                                baseline_beta=0.96))
        latents_dict[node] = latent_node
    return latents_dict['mu_latent_1']
示例6: guide
def guide(self, reparameterized, model_permutation, difficulty=0.0):
    """Guide sampling every node of ``self.q_topo_sort`` as a Normal.

    Nodes are visited in the order of ``self.q_topo_sort`` (presumably a
    topological order of ``self.q_dag``, so every predecessor has already been
    sampled). Each node's mean is a learned constant plus ``kappa``-weighted
    sampled values of its predecessors. Its scale is ``exp`` of a learned
    log-sigma.

    :param reparameterized: if truthy, every node uses ``dist.Normal``.
        Otherwise only nodes with ``self.which_nodes_reparam[i] == 1.0`` do,
        and the rest use ``fakes.NonreparameterizedNormal``.
    :param model_permutation: accepted for interface compatibility; unused.
    :param float difficulty: scales the offsets added to the initial
        parameter values.
    :returns: the sampled value of node ``'loc_latent_1'``.
    """
    latents_dict = {}
    n_nodes = len(self.q_topo_sort)
    for i, node in enumerate(self.q_topo_sort):
        deps = self.q_dag.predecessors(node)
        # Drop the 11-character node-name prefix (presumably "loc_latent_").
        node_suffix = node[11:]
        offset = difficulty * i / n_nodes

        # Initial values are built on every call (including the torch.randn
        # draw). clone().detach() copies them into fresh leaf tensors, which is
        # what PyTorch recommends instead of torch.tensor(<tensor>); that form
        # emits a UserWarning.
        log_sig_init = (-0.5 * torch.log(self.target_lambdas[node_suffix]).data +
                        difficulty * (torch.Tensor([-0.3]) - 0.3 * (torch.randn(1) ** 2)))
        log_sig_node = pyro.param("log_sig_" + node_suffix,
                                  log_sig_init.clone().detach().requires_grad_(True))
        mean_init = self.loc0.data + torch.Tensor([offset])
        # The parameter key uses the full node name, not node_suffix.
        mean_function_node = pyro.param("constant_term_" + node,
                                        mean_init.clone().detach().requires_grad_(True))
        # Add a learned linear contribution from each predecessor's sample.
        for dep in deps:
            kappa_dep = pyro.param("kappa_" + node_suffix + '_' + dep[11:],
                                   torch.tensor([0.5 + offset], requires_grad=True))
            mean_function_node = mean_function_node + kappa_dep * latents_dict[dep]

        node_flagged = bool(self.which_nodes_reparam[i] == 1.0)
        # Named normal_cls so the module-level ``Normal`` is not shadowed.
        normal_cls = dist.Normal if reparameterized or node_flagged else fakes.NonreparameterizedNormal
        latent_node = pyro.sample(node, normal_cls(mean_function_node, torch.exp(log_sig_node)),
                                  infer=dict(baseline=dict(use_decaying_avg_baseline=True,
                                                           baseline_beta=0.96)))
        latents_dict[node] = latent_node
    return latents_dict['loc_latent_1']
示例7: guide
def guide():
    """Two-level Normal guide (legacy Pyro / Variable API).

    Draws ``mu_latent ~ Normal(mu_q, exp(log_sig_q))``. Then, conditioned on
    it, draws ``mu_latent_prime ~ Normal(kappa_q * mu_latent + mu_q_prime,
    exp(log_sig_q_prime))``. Initial values are offsets from
    ``self.analytic_mu_n``, ``self.analytic_log_sig_n`` and ``self.lam0``.
    ``repa1``, ``repa2``, ``use_decaying_avg_baseline`` and
    ``mu_prime_baseline`` come from the enclosing scope.

    :returns: the sampled ``mu_latent``.
    """
    mu_q = pyro.param("mu_q",
                      Variable(self.analytic_mu_n.data + 0.334 * torch.ones(2),
                               requires_grad=True))
    log_sig_q = pyro.param("log_sig_q",
                           Variable(self.analytic_log_sig_n.data - 0.29 * torch.ones(2),
                                    requires_grad=True))
    mu_q_prime = pyro.param("mu_q_prime",
                            Variable(torch.Tensor([-0.34, 0.52]), requires_grad=True))
    kappa_q = pyro.param("kappa_q",
                         Variable(torch.Tensor([0.74]), requires_grad=True))
    log_sig_q_prime = pyro.param("log_sig_q_prime",
                                 Variable(-0.5 * torch.log(1.2 * self.lam0.data),
                                          requires_grad=True))

    sig_q = torch.exp(log_sig_q)
    sig_q_prime = torch.exp(log_sig_q_prime)

    # First level, optionally with a decaying-average baseline.
    mu_latent = pyro.sample(
        "mu_latent",
        dist.Normal(mu_q, sig_q, reparameterized=repa2),
        baseline=dict(use_decaying_avg_baseline=use_decaying_avg_baseline))

    # Second level is linear in mu_latent; its baseline network is fed mu_latent.
    prime_loc = kappa_q.expand_as(mu_latent) * mu_latent + mu_q_prime
    pyro.sample(
        "mu_latent_prime",
        dist.Normal(prime_loc, sig_q_prime, reparameterized=repa1),
        baseline=dict(nn_baseline=mu_prime_baseline,
                      nn_baseline_input=mu_latent,
                      use_decaying_avg_baseline=use_decaying_avg_baseline))
    return mu_latent
示例8: model
def model(num_particles):
    """Vectorized model over ``num_particles`` particles.

    Inside the ``"particles"`` iarange, it registers ``q3`` (initialized from
    ``pi3``) and ``q4`` (initialized to the mean of ``pi1`` and ``pi2``). All
    three ``pi`` values come from the enclosing scope. It then draws
    ``z ~ Normal(q3, 1)`` per particle and ``y ~ Bernoulli(q4 * sigmoid(z))``.

    :param int num_particles: size of the ``"particles"`` iarange.
    """
    with pyro.iarange("particles", num_particles):
        q3 = pyro.param("q3", torch.tensor(pi3, requires_grad=True))
        q4 = pyro.param("q4", torch.tensor(0.5 * (pi1 + pi2), requires_grad=True))
        z = pyro.sample("z", dist.Normal(q3, 1.0).expand_by([num_particles]))
        # torch.sigmoid(z) equals exp(z) / (1 + exp(z)) but, unlike the explicit
        # form, does not overflow to inf / inf = nan for large z.
        zz = torch.sigmoid(z)
        pyro.sample("y", dist.Bernoulli(q4 * zz))
示例9: _register_param
def _register_param(self, param, mode="model"):
    """
    Registers a parameter to Pyro. It can be seen as a wrapper for
    :func:`pyro.param` and :func:`pyro.sample` primitives.

    Resolution order:

    1. A value in ``self._fixed_params`` is used as-is.
    2. Without a prior, ``getattr(self, param)`` initializes a
       :func:`pyro.param`, with the constraint from ``self._constraints`` if any.
    3. With a prior in ``"model"`` mode, the value is sampled from the prior.
    4. With a prior in any other mode (i.e. ``"guide"``), a ``<name>_MAP``
       parameter initialized at the prior mean is sampled through a Delta.

    The result is stored in ``self._registered_params[param]``.

    :param str param: Name of the parameter.
    :param str mode: Either "model" or "guide".
    """
    if param in self._fixed_params:
        self._registered_params[param] = self._fixed_params[param]
        return

    prior = self._priors.get(param)
    param_name = param if self.name is None else param_with_module_name(self.name, param)

    if prior is None:
        constraint = self._constraints.get(param)
        default_value = getattr(self, param)
        if constraint is not None:
            p = pyro.param(param_name, default_value, constraint=constraint)
        else:
            p = pyro.param(param_name, default_value)
    elif mode == "model":
        p = pyro.sample(param_name, prior)
    else:
        # A prior exists and mode is anything other than "model" (i.e. "guide").
        # TODO: consider initializing the MAP parameter from a prior sample instead of the mean.
        map_value = pyro.param(param_name + "_MAP", prior.mean.detach())
        p = pyro.sample(param_name, dist.Delta(map_value))

    self._registered_params[param] = p
示例10: test_iter_discrete_traces_vector
def test_iter_discrete_traces_vector(graph_type):
    """Enumerate the joint discrete choices of a batched Bernoulli/Categorical model.

    ``x`` has 2 values and ``y`` has ``ps.size(-1)`` (4) values, so enumeration
    must yield ``2 * 4`` traces. Each trace's scale must be the joint
    probability of its choices: ``exp(log p(x) + log p(y))``.
    """
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
                                                     [0.4, 0.3, 0.2, 0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    assert len(traces) == 2 * ps.size(-1)

    for scale, trace in traces:
        # Presumably every batch row holds the same enumerated value, so the
        # first element identifies the choice -- TODO confirm.
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        # Joint probability = product of probabilities = exp(sum of
        # log-probabilities). Multiplying the log-probabilities would not give
        # a probability (it can exceed 1).
        log_joint = (dist.Bernoulli(p).log_pdf(x) +
                     dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = torch.exp(log_joint).data.view(-1)[0]
        assert_equal(scale, expected_scale)
示例11: test_optimizers
def test_optimizers(factory):
    """Each optimizer built by ``factory`` must drive ``x``, ``y``, ``z`` to ``loc``.

    ``x`` (2,), ``y`` (3, 2) and ``z`` (4, 2; constrained greater than -1) are
    all observed under one ``MultivariateNormal(loc, cov)``. The loss is the
    negative trace log-density. After 100 optimizer steps on the unconstrained
    parameter values, every row must match ``loc`` to within 1e-2.
    """
    optim = factory()
    param_names = ("x", "y", "z")

    def model(loc, cov):
        x = pyro.param("x", torch.randn(2))
        y = pyro.param("y", torch.randn(3, 2))
        z = pyro.param("z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1))
        pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x)
        with pyro.iarange("y_iarange", 3):
            pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y)
        with pyro.iarange("z_iarange", 4):
            pyro.sample("obs_z", dist.MultivariateNormal(loc, cov), obs=z)

    loc = torch.tensor([-0.5, 0.5])
    cov = torch.tensor([[1.0, 0.09], [0.09, 0.1]])
    for _ in range(100):
        tr = poutine.trace(model).get_trace(loc, cov)
        loss = -tr.log_prob_sum()
        # The optimizer updates the unconstrained representation, so the
        # constrained ``z`` stays inside its support.
        params = {name: pyro.param(name).unconstrained() for name in param_names}
        optim.step(loss, params)

    for name in param_names:
        actual = pyro.param(name)
        expected = loc.expand(actual.shape)
        assert_equal(actual, expected, prec=1e-2,
                     msg='{} incorrect: {} vs {}'.format(name, actual, expected))
示例12: sample_zs
def sample_zs(name, width):
    """Sample the Gamma-distributed latent ``z_<name>``.

    Two raw parameters are stored as ``log_alpha_z_q_<name>`` and
    ``log_mean_z_q_<name>``. Both are initialized lazily by ``rand_tensor``
    with shape argument ``(x_size, width)``. Despite the ``log_`` prefix, they
    pass through ``self.softplus`` to give a shape ``alpha`` and a mean. The
    Gamma rate is ``alpha / mean``. ``.independent(1)`` treats the rightmost
    dimension as a single event.
    """
    raw_alpha = pyro.param("log_alpha_z_q_%s" % name,
                           lambda: rand_tensor((x_size, width), self.alpha_init, self.sigma_init))
    raw_mean = pyro.param("log_mean_z_q_%s" % name,
                          lambda: rand_tensor((x_size, width), self.mean_init, self.sigma_init))
    alpha = self.softplus(raw_alpha)
    mean = self.softplus(raw_mean)
    z_dist = Gamma(alpha, alpha / mean).independent(1)
    pyro.sample("z_%s" % name, z_dist)
示例13: guide
def guide(subsample_size):
    """Normal guide over ``data`` with optional subsampling.

    Learns one location per data point (``mu``) and a single shared scale
    (``sigma``), both initialized lazily. Inside the ``"data"`` iarange, only
    the subsampled locations are used. ``data`` and ``reparameterized`` come
    from the enclosing scope.

    :param subsample_size: number of points to subsample, forwarded to
        ``pyro.iarange``; ``None`` means no subsampling.
    """
    mu = pyro.param("mu", lambda: Variable(torch.zeros(len(data)), requires_grad=True))
    sigma = pyro.param("sigma", lambda: Variable(torch.ones(1), requires_grad=True))
    with pyro.iarange("data", len(data), subsample_size) as ind:
        mu_ind = mu[ind]
        # Size the shared scale by the indices actually drawn, not by
        # ``subsample_size``. With ``subsample_size=None`` (no subsampling),
        # expand(None) would fail.
        sigma_ind = sigma.expand(len(ind))
        pyro.sample("z", dist.Normal(mu_ind, sigma_ind, reparameterized=reparameterized))
示例14: guide
def guide():
    """Beta guide for ``p_latent`` (early Pyro API).

    The Beta parameters are stored in log space, initialized at offsets from
    ``self.log_alpha_n`` and ``self.log_beta_n``, and exponentiated before
    sampling. The trailing ``pyro.map_data`` call runs over ``self.data`` in
    batches of ``self.batch_size`` with a body that returns ``None``
    (presumably to mirror the model's data loop -- TODO confirm).
    """
    log_alpha_init = Variable(self.log_alpha_n.data + 0.17, requires_grad=True)
    alpha_q_log = pyro.param("alpha_q_log", log_alpha_init)
    log_beta_init = Variable(self.log_beta_n.data - 0.143, requires_grad=True)
    beta_q_log = pyro.param("beta_q_log", log_beta_init)

    alpha_q = torch.exp(alpha_q_log)
    beta_q = torch.exp(beta_q_log)
    pyro.sample("p_latent", dist.beta, alpha_q, beta_q)

    def _ignore_batch(i, x):
        return None

    pyro.map_data("aaa", self.data, _ignore_batch, batch_size=self.batch_size)
示例15: guide
def guide():
loc1 = pyro.param("loc1", torch.randn(2, requires_grad=True))
scale1 = pyro.param("scale1", torch.ones(2, requires_grad=True))
pyro.sample("latent1", Normal(loc1, scale1))
loc2 = pyro.param("loc2", torch.randn(2, requires_grad=True))
scale2 = pyro.param("scale2", torch.ones(2, requires_grad=True))
latent2 = pyro.sample("latent2", Normal(loc2, scale2))
return latent2