本文整理汇总了Python中pyro.poutine.trace函数的典型用法代码示例。如果您正苦于以下问题：Python trace函数的具体用法？Python trace怎么用？Python trace使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了trace函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_compute_downstream_costs_irange_in_iarange
def test_compute_downstream_costs_irange_in_iarange(dim1, dim2):
    """Check downstream costs for ``nested_model_guide2``.

    The guide is traced without observations and the model is replayed
    against that trace with observations.  The result of
    ``_compute_downstream_costs`` must match the brute-force computation
    node for node, and must match hand-built expectations for the sites
    ``b1``, ``c`` and ``a1``.

    :param dim1: forwarded to ``nested_model_guide2`` as ``dim1``.
    :param dim2: forwarded to ``nested_model_guide2`` as ``dim2``; also the
        number of ``b{i}`` / ``obs{i}`` sites folded into the cost of ``c``.
    """
    dims = dict(dim1=dim1, dim2=dim2)

    guide_trace = poutine.trace(
        nested_model_guide2, graph_type="dense"
    ).get_trace(include_obs=False, **dims)
    replayed_model = poutine.replay(nested_model_guide2, trace=guide_trace)
    model_trace = poutine.trace(
        replayed_model, graph_type="dense"
    ).get_trace(include_obs=True, **dims)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    def model_log_prob(site):
        # log_prob recorded at ``site`` in the model trace.
        return model_trace.nodes[site]['log_prob']

    def log_ratio(site):
        # Model log_prob minus guide log_prob at ``site`` (a fresh tensor).
        return model_log_prob(site) - guide_trace.nodes[site]['log_prob']

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    for site in dc:
        assert guide_trace.nodes[site]['log_prob'].size() == dc[site].size()
        assert_equal(dc[site], dc_brute[site])

    expected_b1 = log_ratio('b1')
    expected_b1 += model_log_prob('obs1')
    assert_equal(expected_b1, dc['b1'])

    expected_c = log_ratio('c')
    for idx in range(dim2):
        expected_c += log_ratio('b{}'.format(idx))
        expected_c += model_log_prob('obs{}'.format(idx))
    assert_equal(expected_c, dc['c'])

    expected_a1 = log_ratio('a1')
    expected_a1 += expected_c.sum()
    assert_equal(expected_a1, dc['a1'])
示例2: test_block_full
def test_block_full(self):
    """Trace the fully blocked model and guide and check the node types.

    With ``poutine.block`` applied without arguments, every node recorded
    in either trace must have type ``"args"`` or ``"return"``.
    """
    permitted_types = ("args", "return")
    blocked_traces = [
        poutine.trace(poutine.block(fn)).get_trace()
        for fn in (self.model, self.guide)
    ]
    for blocked in blocked_traces:
        for site in blocked.nodes:
            assert blocked.nodes[site]["type"] in permitted_types
示例3: test_replay_enumerate_poutine
def test_replay_enumerate_poutine(depth, first_available_dim):
    """Replay an enumerated model against an enumerated guide trace.

    The guide samples only ``"y"`` and is enumerated starting at
    ``depth + first_available_dim``.  The model samples ``"x"``, then
    ``depth`` sites ``a_i``, then ``"y"``, then ``depth`` sites ``b_i``,
    and is enumerated starting at ``first_available_dim`` before being
    replayed against the guide trace.  On each of two runs the test checks
    that the model's ``"y"`` value is the very object stored in the guide
    trace, and that the summed log-probability has the expected shape.

    :param depth: number of ``a_i`` sites and of ``b_i`` sites in the model.
    :param first_available_dim: passed to ``poutine.enum`` for the model and,
        offset by ``depth``, for the guide.
    """
    num_particles = 2
    y_dist = Categorical(torch.tensor([0.5, 0.25, 0.25]))
    parallel = {"enumerate": "parallel"}

    def guide_fn():
        pyro.sample("y", y_dist, infer={"enumerate": "parallel"})

    guide = poutine.trace(
        poutine.enum(guide_fn, first_available_dim=depth + first_available_dim))
    guide_trace = guide.get_trace()

    def model_fn():
        pyro.sample("x", Bernoulli(0.5))
        for k in range(depth):
            pyro.sample("a_{}".format(k), Bernoulli(0.5), infer=dict(parallel))
        pyro.sample("y", y_dist, infer=dict(parallel))
        for k in range(depth):
            pyro.sample("b_{}".format(k), Bernoulli(0.5), infer=dict(parallel))

    model = poutine.enum(model_fn, first_available_dim=first_available_dim)
    model = poutine.replay(model, trace=guide_trace)
    model = poutine.trace(model)

    # Depends only on the test parameters, so it is built once up front.
    expected_shape = (2,) * depth + (3,) + (2,) * depth + (1,) * first_available_dim

    for particle in range(num_particles):
        tr = model.get_trace()
        assert tr.nodes["y"]["value"] is guide_trace.nodes["y"]["value"]
        tr.compute_log_prob()
        total_log_prob = sum(
            site["log_prob"] for _, site in tr.iter_stochastic_nodes())
        assert total_log_prob.shape == expected_shape, \
            'error on iteration {}'.format(particle)
示例4: _get_traces
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Run the guide, then run the model replayed against the guide, and
    yield the resulting pair of dense-graph traces once per particle.

    :param model: the model callable; traced while replayed against the
        guide trace.
    :param guide: the guide callable; traced first.
    :param args: positional arguments forwarded to both ``get_trace`` calls.
    :param kwargs: keyword arguments forwarded to both ``get_trace`` calls.
    :returns: a generator of ``(weight, model_trace, guide_trace)`` tuples,
        one per particle, where ``weight`` is ``1.0 / self.num_particles``
        and both traces have been passed through ``prune_subsample_sites``.

    When validation is enabled, the model and guide traces are checked
    against each other and a warning is issued if any guide sample site is
    configured for enumeration.
    """
    for _ in range(self.num_particles):
        guide_trace = poutine.trace(
            guide, graph_type="dense").get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(model, trace=guide_trace),
            graph_type="dense").get_trace(*args, **kwargs)

        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace)
            enumerated_sites = [
                name for name, site in guide_trace.nodes.items()
                if site["type"] == "sample" and site["infer"].get("enumerate")
            ]
            if enumerated_sites:
                # The comma after the first element is essential: without it
                # Python fuses that literal with ', ' and the fused string
                # becomes the separator passed to join(), so the header line
                # is dropped or repeated between the site names.
                warnings.warn('\n'.join([
                    'TraceGraph_ELBO found sample sites configured for enumeration:',
                    ', '.join(enumerated_sites),
                    'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.',
                ]))

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        weight = 1.0 / self.num_particles
        yield weight, model_trace, guide_trace
示例5: test_compute_downstream_costs_iarange_reuse
def test_compute_downstream_costs_iarange_reuse(dim1, dim2):
    """Check downstream costs for ``iarange_reuse_model_guide``.

    The guide is traced without observations and the model is replayed
    against that trace with observations.  The result of
    ``_compute_downstream_costs`` must match the brute-force computation
    node for node, and the cost at ``c1`` must match a hand-built
    expectation assembled from the ``c1``, ``b1``, ``c2`` and ``obs`` sites.

    :param dim1: forwarded to ``iarange_reuse_model_guide`` as ``dim1``.
    :param dim2: forwarded to ``iarange_reuse_model_guide`` as ``dim2``.
    """
    dims = dict(dim1=dim1, dim2=dim2)

    guide_trace = poutine.trace(
        iarange_reuse_model_guide, graph_type="dense"
    ).get_trace(include_obs=False, **dims)
    replayed_model = poutine.replay(iarange_reuse_model_guide, trace=guide_trace)
    model_trace = poutine.trace(
        replayed_model, graph_type="dense"
    ).get_trace(include_obs=True, **dims)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    def log_ratio(site):
        # Model log_prob minus guide log_prob at ``site`` (a fresh tensor).
        return model_trace.nodes[site]['log_prob'] - guide_trace.nodes[site]['log_prob']

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    for site in dc:
        assert guide_trace.nodes[site]['log_prob'].size() == dc[site].size()
        assert_equal(dc[site], dc_brute[site])

    expected_c1 = log_ratio('c1')
    expected_c1 += log_ratio('b1').sum()
    expected_c1 += log_ratio('c2')
    expected_c1 += model_trace.nodes['obs']['log_prob']
    assert_equal(expected_c1, dc['c1'])
示例6: test_splice
def test_splice(self):
    """Compare the plain guide trace with the trace of the lifted guide.

    For every site of the plain guide trace: the sites ``loc1``, ``loc2``,
    ``scale1`` and ``scale2`` must be absent from the lifted trace, and
    every other site must be present in it.
    """
    absent_after_lift = ('loc1', 'loc2', 'scale1', 'scale2')

    plain_trace = poutine.trace(self.guide).get_trace()
    lifted_guide = poutine.lift(self.guide, prior=self.prior)
    lifted_tr = poutine.trace(lifted_guide).get_trace()

    for site in plain_trace.nodes:
        if site not in absent_after_lift:
            assert site in lifted_tr
        else:
            assert site not in lifted_tr
示例7: test_splice
def test_splice(self):
    """Compare the plain guide trace with the trace of the lifted guide.

    For every site of the plain guide trace: the sites ``mu1``, ``mu2``,
    ``sigma1`` and ``sigma2`` must be absent from the lifted trace, and
    every other site must be present in it.
    """
    absent_after_lift = ('mu1', 'mu2', 'sigma1', 'sigma2')

    plain_trace = poutine.trace(self.guide).get_trace()
    lifted_guide = poutine.lift(self.guide, prior=self.prior)
    lifted_tr = poutine.trace(lifted_guide).get_trace()

    for site in plain_trace.nodes:
        if site not in absent_after_lift:
            self.assertTrue(site in lifted_tr)
        else:
            self.assertFalse(site in lifted_tr)
示例8: test_trace_data
def test_trace_data(self):
    """Condition the model on a previously recorded trace.

    A first trace is taken from the model with only ``"sample"`` sites
    exposed.  The model is then conditioned on that trace and traced again.
    In the second trace, ``latent2`` must be an observed sample site whose
    value is the very object recorded in the first trace.
    """
    sample_only_model = poutine.block(self.model, expose_types=["sample"])
    tr1 = poutine.trace(sample_only_model).get_trace()

    conditioned_model = poutine.condition(self.model, data=tr1)
    tr2 = poutine.trace(conditioned_model).get_trace()

    conditioned_site = tr2.nodes["latent2"]
    assert conditioned_site["type"] == "sample"
    assert conditioned_site["is_observed"]
    assert conditioned_site["value"] is tr1.nodes["latent2"]["value"]
示例9: test_iarange_error_on_enter
def test_iarange_error_on_enter():
    """Check that a failing ``iarange`` entry leaves the dim stack empty.

    Running the traced model, whose only statement enters
    ``pyro.iarange('foo', 0)``, is expected to raise ``ZeroDivisionError``.
    ``_DIM_ALLOCATOR._stack`` must be empty both before and after the call.
    """
    def enter_zero_size_iarange():
        with pyro.iarange('foo', 0):
            pass

    assert len(_DIM_ALLOCATOR._stack) == 0

    with pytest.raises(ZeroDivisionError):
        traced = poutine.trace(enter_zero_size_iarange)
        traced()

    assert len(_DIM_ALLOCATOR._stack) == 0, 'stack was not cleaned on error'
示例10: test_block_full_expose
def test_block_full_expose(self):
    """Block the model and guide while exposing all their listed sites.

    Every name in ``self.model_sites`` must appear in the trace of the
    blocked model, and every name in ``self.guide_sites`` must appear in
    the trace of the blocked guide.
    """
    exposed_model = poutine.block(self.model, expose=self.model_sites)
    model_trace = poutine.trace(exposed_model).get_trace()

    exposed_guide = poutine.block(self.guide, expose=self.guide_sites)
    guide_trace = poutine.trace(exposed_guide).get_trace()

    for site in self.model_sites:
        assert site in model_trace
    for site in self.guide_sites:
        assert site in guide_trace
示例11: test_block_tutorial_case
def test_block_tutorial_case(self):
    """Hide ``"observe"``-type sites in the guide only.

    The model is traced as-is and the guide is traced under
    ``poutine.block(..., hide_types=["observe"])``.  ``latent1`` must be in
    both traces, while ``obs`` must be in the model trace only.
    """
    model_trace = poutine.trace(self.model).get_trace()

    guide_without_observe = poutine.block(self.guide, hide_types=["observe"])
    guide_trace = poutine.trace(guide_without_observe).get_trace()

    # Sites present in the model trace.
    assert "latent1" in model_trace
    # The latent site is present in the guide trace as well.
    assert "latent1" in guide_trace
    assert "obs" in model_trace
    # The hidden site must not appear in the guide trace.
    assert "obs" not in guide_trace
示例12: test_prior_dict
def test_prior_dict(self):
    """Lift the guide with ``self.prior_dict`` and inspect the lifted trace.

    For every site of the plain guide trace:

    * the site must also be present in the lifted trace;
    * for ``scale1``, ``loc1``, ``scale2`` and ``loc2``, the ``fn`` stored
      at the lifted site must be named ``<site>_prior``;
    * a site of type ``"param"`` in the plain trace must be an unobserved
      ``"sample"`` site in the lifted trace.
    """
    sites_with_named_prior = {'scale1', 'loc1', 'scale2', 'loc2'}

    plain_trace = poutine.trace(self.guide).get_trace()
    lifted_guide = poutine.lift(self.guide, prior=self.prior_dict)
    lifted_tr = poutine.trace(lifted_guide).get_trace()

    for site in plain_trace.nodes:
        assert site in lifted_tr
        if site in sites_with_named_prior:
            assert site + "_prior" == lifted_tr.nodes[site]['fn'].__name__
        if plain_trace.nodes[site]["type"] == "param":
            assert lifted_tr.nodes[site]["type"] == "sample"
            assert not lifted_tr.nodes[site]["is_observed"]
示例13: test_prior_dict
def test_prior_dict(self):
    """Lift the guide with ``self.prior_dict`` and inspect the lifted trace.

    For every site of the plain guide trace:

    * the site must also be present in the lifted trace;
    * for ``sigma1``, ``mu1``, ``sigma2`` and ``mu2``, the ``fn`` stored
      at the lifted site must be named ``<site>_prior``;
    * a site of type ``"param"`` in the plain trace must be an unobserved
      ``"sample"`` site in the lifted trace.
    """
    sites_with_named_prior = {'sigma1', 'mu1', 'sigma2', 'mu2'}

    plain_trace = poutine.trace(self.guide).get_trace()
    lifted_guide = poutine.lift(self.guide, prior=self.prior_dict)
    lifted_tr = poutine.trace(lifted_guide).get_trace()

    for site in plain_trace.nodes:
        self.assertTrue(site in lifted_tr)
        if site in sites_with_named_prior:
            prior_name = lifted_tr.nodes[site]['fn'].__name__
            self.assertTrue(site + "_prior" == prior_name)
        if plain_trace.nodes[site]["type"] == "param":
            self.assertTrue(lifted_tr.nodes[site]["type"] == "sample" and
                            not lifted_tr.nodes[site]["is_observed"])
示例14: _traces
def _traces(self, *args, **kwargs):
    """
    Generate weighted samples from the proposal distribution.

    For each of ``self.num_samples`` iterations, ``self.guide`` is traced
    and ``self.model`` is traced while replayed against that guide trace.

    :param args: positional arguments forwarded to both ``get_trace`` calls.
    :param kwargs: keyword arguments forwarded to both ``get_trace`` calls.
    :returns: a generator of ``(model_trace, log_weight)`` tuples, where
        ``log_weight`` is the model trace's ``log_prob_sum()`` minus the
        guide trace's ``log_prob_sum()``.
    """
    for _ in range(self.num_samples):
        guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
        replayed_model = poutine.replay(self.model, trace=guide_trace)
        model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)
        model_log_prob = model_trace.log_prob_sum()
        log_weight = model_log_prob - guide_trace.log_prob_sum()
        yield (model_trace, log_weight)
示例15: test_replay_full_repeat
def test_replay_full_repeat(self):
    """Replay the model against one recorded trace several times.

    A single traced-and-replayed handler is run twice, and a second,
    separately built handler is run once.  For every site in
    ``self.full_sample_sites`` the values from all three runs must equal
    each other and the value in the originally recorded trace.
    """
    model_trace = poutine.trace(self.model).get_trace()

    reused_handler = poutine.trace(poutine.replay(self.model, trace=model_trace))
    first_run = reused_handler.get_trace()
    second_run = reused_handler.get_trace()
    fresh_run = poutine.trace(
        poutine.replay(self.model, trace=model_trace)).get_trace()

    for site in self.full_sample_sites:
        recorded = model_trace.nodes[site]["value"]
        first_value = first_run.nodes[site]["value"]
        fresh_value = fresh_run.nodes[site]["value"]
        assert_equal(first_value, second_run.nodes[site]["value"])
        assert_equal(first_value, fresh_value)
        assert_equal(recorded, first_value)
        assert_equal(recorded, fresh_value)