当前位置: 首页>>代码示例>>Python>>正文


Python poutine.trace函数代码示例

本文整理汇总了 Python 中 pyro.poutine.trace 函数的典型用法代码示例。如果您想了解 Python trace 函数的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。


下文共展示了 trace 函数的 15 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。

示例1: test_compute_downstream_costs_irange_in_iarange

def test_compute_downstream_costs_irange_in_iarange(dim1, dim2):
    # Trace the guide (no observations), then replay the model against it
    # (with observations) so that both traces share the same latent values.
    run_kwargs = dict(dim1=dim1, dim2=dim2)
    guide_trace = poutine.trace(nested_model_guide2, graph_type="dense").get_trace(
        include_obs=False, **run_kwargs)
    replayed = poutine.replay(nested_model_guide2, trace=guide_trace)
    model_trace = poutine.trace(replayed, graph_type="dense").get_trace(
        include_obs=True, **run_kwargs)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    nonreparam = set(guide_trace.nonreparam_stochastic_nodes)
    costs, cost_nodes = _compute_downstream_costs(model_trace, guide_trace, nonreparam)
    brute_costs, brute_nodes = _brute_force_compute_downstream_costs(model_trace, guide_trace, nonreparam)

    # The efficient implementation must agree with the brute-force reference.
    assert cost_nodes == brute_nodes

    for site in costs:
        assert guide_trace.nodes[site]['log_prob'].size() == costs[site].size()
        assert_equal(costs[site], brute_costs[site])

    def model_lp(name):
        return model_trace.nodes[name]['log_prob']

    def log_ratio(name):
        # model log_prob minus guide log_prob at the same site (a new tensor)
        return model_lp(name) - guide_trace.nodes[name]['log_prob']

    # Hand-computed downstream costs for a few specific sites.
    expected_b1 = log_ratio('b1')
    expected_b1 += model_lp('obs1')
    assert_equal(expected_b1, costs['b1'])

    expected_c = log_ratio('c')
    for idx in range(dim2):
        expected_c += log_ratio('b{}'.format(idx))
        expected_c += model_lp('obs{}'.format(idx))
    assert_equal(expected_c, costs['c'])

    expected_a1 = log_ratio('a1')
    expected_a1 += expected_c.sum()
    assert_equal(expected_a1, costs['a1'])
开发者ID:lewisKit,项目名称:pyro,代码行数:35,代码来源:test_compute_downstream_costs.py

示例2: test_block_full

 def test_block_full(self):
     # With everything blocked, only "args"/"return" nodes may appear in the
     # outer traces of both the model and the guide.
     blocked_traces = (poutine.trace(poutine.block(self.model)).get_trace(),
                       poutine.trace(poutine.block(self.guide)).get_trace())
     for tr in blocked_traces:
         for site_name in tr.nodes.keys():
             assert tr.nodes[site_name]["type"] in ("args", "return")
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_poutines.py

示例3: test_replay_enumerate_poutine

def test_replay_enumerate_poutine(depth, first_available_dim):
    num_particles = 2
    y_dist = Categorical(torch.tensor([0.5, 0.25, 0.25]))

    def guide():
        pyro.sample("y", y_dist, infer={"enumerate": "parallel"})

    # The guide enumerates "y" starting `depth` dims further left than the model.
    enum_guide = poutine.enum(guide, first_available_dim=depth + first_available_dim)
    guide_trace = poutine.trace(enum_guide).get_trace()

    def model_fn():
        pyro.sample("x", Bernoulli(0.5))
        for k in range(depth):
            pyro.sample("a_{}".format(k), Bernoulli(0.5), infer={"enumerate": "parallel"})
        pyro.sample("y", y_dist, infer={"enumerate": "parallel"})
        for k in range(depth):
            pyro.sample("b_{}".format(k), Bernoulli(0.5), infer={"enumerate": "parallel"})

    traced_model = poutine.trace(
        poutine.replay(poutine.enum(model_fn, first_available_dim=first_available_dim),
                       trace=guide_trace))

    expected_shape = (2,) * depth + (3,) + (2,) * depth + (1,) * first_available_dim
    for particle in range(num_particles):
        tr = traced_model.get_trace()
        # Replay must reuse the guide's value object for "y" verbatim.
        assert tr.nodes["y"]["value"] is guide_trace.nodes["y"]["value"]
        tr.compute_log_prob()
        total = sum(site["log_prob"] for _, site in tr.iter_stochastic_nodes())
        assert total.shape == expected_shape, 'error on iteration {}'.format(particle)
开发者ID:lewisKit,项目名称:pyro,代码行数:31,代码来源:test_poutines.py

示例4: _get_traces

    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Run the guide and replay the model against it, once per particle.

        For each of ``self.num_particles`` iterations the guide is traced with
        ``graph_type="dense"``. The model is then replayed against that guide
        trace, also traced with ``graph_type="dense"``. Subsample sites are
        pruned from both traces before they are yielded.

        :param model: the model callable.
        :param guide: the guide callable.
        :param args: positional arguments passed to both ``model`` and ``guide``.
        :param kwargs: keyword arguments passed to both ``model`` and ``guide``.
        :returns: a generator of ``(weight, model_trace, guide_trace)`` tuples,
            with ``weight == 1.0 / self.num_particles``.
        """
        for _ in range(self.num_particles):
            guide_trace = poutine.trace(guide,
                                        graph_type="dense").get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                        graph_type="dense").get_trace(*args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                # Warn when the guide marks any sample site for enumeration.
                enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                    if site["type"] == "sample" and site["infer"].get("enumerate")]
                if enumerated_sites:
                    # The explicit ``+`` is required. Without it, Python would
                    # implicitly concatenate the two adjacent string literals, and
                    # the preamble would become the ``join`` separator instead of
                    # a prefix of the site list.
                    warnings.warn('\n'.join([
                        'TraceGraph_ELBO found sample sites configured for enumeration: ' +
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace
开发者ID:lewisKit,项目名称:pyro,代码行数:26,代码来源:tracegraph_elbo.py

示例5: test_compute_downstream_costs_iarange_reuse

def test_compute_downstream_costs_iarange_reuse(dim1, dim2):
    # Trace the guide (no observations), then replay the model against it
    # (with observations) so that both traces share the same latent values.
    run_kwargs = dict(dim1=dim1, dim2=dim2)
    guide_trace = poutine.trace(iarange_reuse_model_guide, graph_type="dense").get_trace(
        include_obs=False, **run_kwargs)
    replayed = poutine.replay(iarange_reuse_model_guide, trace=guide_trace)
    model_trace = poutine.trace(replayed, graph_type="dense").get_trace(
        include_obs=True, **run_kwargs)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    nonreparam = set(guide_trace.nonreparam_stochastic_nodes)
    costs, cost_nodes = _compute_downstream_costs(model_trace, guide_trace, nonreparam)
    brute_costs, brute_nodes = _brute_force_compute_downstream_costs(model_trace, guide_trace, nonreparam)
    # The efficient implementation must agree with the brute-force reference.
    assert cost_nodes == brute_nodes

    for site in costs:
        assert guide_trace.nodes[site]['log_prob'].size() == costs[site].size()
        assert_equal(costs[site], brute_costs[site])

    def log_ratio(name):
        # model log_prob minus guide log_prob at the same site (a new tensor)
        return model_trace.nodes[name]['log_prob'] - guide_trace.nodes[name]['log_prob']

    expected_c1 = log_ratio('c1')
    expected_c1 += log_ratio('b1').sum()
    expected_c1 += log_ratio('c2')
    expected_c1 += model_trace.nodes['obs']['log_prob']
    assert_equal(expected_c1, costs['c1'])
开发者ID:lewisKit,项目名称:pyro,代码行数:26,代码来源:test_compute_downstream_costs.py

示例6: test_splice

 def test_splice(self):
     # The params named below must be absent from the lifted trace; every
     # other site of the plain guide trace must still be present.
     lifted_params = ('loc1', 'loc2', 'scale1', 'scale2')
     tr = poutine.trace(self.guide).get_trace()
     lifted_tr = poutine.trace(poutine.lift(self.guide, prior=self.prior)).get_trace()
     for site_name in tr.nodes.keys():
         should_remain = site_name not in lifted_params
         assert (site_name in lifted_tr) == should_remain
开发者ID:lewisKit,项目名称:pyro,代码行数:8,代码来源:test_poutines.py

示例7: test_splice

 def test_splice(self):
     # The params named below must be absent from the lifted trace; every
     # other site of the plain guide trace must still be present.
     lifted_params = ('mu1', 'mu2', 'sigma1', 'sigma2')
     tr = poutine.trace(self.guide).get_trace()
     lifted_tr = poutine.trace(poutine.lift(self.guide, prior=self.prior)).get_trace()
     for site_name in tr.nodes.keys():
         check = self.assertFalse if site_name in lifted_params else self.assertTrue
         check(site_name in lifted_tr)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:8,代码来源:test_poutines.py

示例8: test_trace_data

 def test_trace_data(self):
     # Record only sample sites, then condition the model on that trace.
     # "latent2" must come back observed, holding the very same value object.
     sample_only_model = poutine.block(self.model, expose_types=["sample"])
     tr1 = poutine.trace(sample_only_model).get_trace()
     conditioned_model = poutine.condition(self.model, data=tr1)
     tr2 = poutine.trace(conditioned_model).get_trace()
     latent = tr2.nodes["latent2"]
     assert latent["type"] == "sample" and latent["is_observed"]
     assert latent["value"] is tr1.nodes["latent2"]["value"]
开发者ID:lewisKit,项目名称:pyro,代码行数:8,代码来源:test_poutines.py

示例9: test_iarange_error_on_enter

def test_iarange_error_on_enter():
    # Running a size-0 iarange is expected to raise ZeroDivisionError; the
    # dim allocator's stack must be empty both before and after.
    def model():
        with pyro.iarange('foo', 0):
            pass

    def stack_depth():
        return len(_DIM_ALLOCATOR._stack)

    assert stack_depth() == 0
    with pytest.raises(ZeroDivisionError):
        poutine.trace(model)()
    assert stack_depth() == 0, 'stack was not cleaned on error'
开发者ID:lewisKit,项目名称:pyro,代码行数:9,代码来源:test_poutines.py

示例10: test_block_full_expose

 def test_block_full_expose(self):
     model_trace = poutine.trace(
         poutine.block(self.model, expose=self.model_sites)).get_trace()
     guide_trace = poutine.trace(
         poutine.block(self.guide, expose=self.guide_sites)).get_trace()
     # Every explicitly exposed site must still appear despite the block.
     for tr, exposed in ((model_trace, self.model_sites),
                         (guide_trace, self.guide_sites)):
         for site_name in exposed:
             assert site_name in tr
开发者ID:lewisKit,项目名称:pyro,代码行数:9,代码来源:test_poutines.py

示例11: test_block_tutorial_case

    def test_block_tutorial_case(self):
        model_trace = poutine.trace(self.model).get_trace()
        blocked_guide = poutine.block(self.guide, hide_types=["observe"])
        guide_trace = poutine.trace(blocked_guide).get_trace()

        # The model trace records both sites; the guide trace records only the latent.
        for site_name in ("latent1", "obs"):
            assert site_name in model_trace
        assert "latent1" in guide_trace
        assert "obs" not in guide_trace
开发者ID:lewisKit,项目名称:pyro,代码行数:9,代码来源:test_poutines.py

示例12: test_prior_dict

 def test_prior_dict(self):
     prior_named_sites = {'scale1', 'loc1', 'scale2', 'loc2'}
     tr = poutine.trace(self.guide).get_trace()
     lifted_tr = poutine.trace(poutine.lift(self.guide, prior=self.prior_dict)).get_trace()
     for site_name in tr.nodes.keys():
         assert site_name in lifted_tr
         lifted_site = lifted_tr.nodes[site_name]
         # Sites listed in the prior dict must use the matching "<name>_prior" fn.
         if site_name in prior_named_sites:
             assert lifted_site['fn'].__name__ == site_name + "_prior"
         # Former params must now be latent (unobserved) sample sites.
         if tr.nodes[site_name]["type"] == "param":
             assert lifted_site["type"] == "sample"
             assert not lifted_site["is_observed"]
开发者ID:lewisKit,项目名称:pyro,代码行数:10,代码来源:test_poutines.py

示例13: test_prior_dict

 def test_prior_dict(self):
     prior_named_sites = {'sigma1', 'mu1', 'sigma2', 'mu2'}
     tr = poutine.trace(self.guide).get_trace()
     lifted_tr = poutine.trace(poutine.lift(self.guide, prior=self.prior_dict)).get_trace()
     for site_name in tr.nodes.keys():
         self.assertTrue(site_name in lifted_tr)
         # Sites listed in the prior dict must use the matching "<name>_prior" fn.
         if site_name in prior_named_sites:
             expected_fn_name = site_name + "_prior"
             self.assertTrue(expected_fn_name == lifted_tr.nodes[site_name]['fn'].__name__)
         # Former params must now be latent (unobserved) sample sites.
         if tr.nodes[site_name]["type"] == "param":
             lifted_site = lifted_tr.nodes[site_name]
             is_latent_sample = lifted_site["type"] == "sample" and not lifted_site["is_observed"]
             self.assertTrue(is_latent_sample)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:10,代码来源:test_poutines.py

示例14: _traces

 def _traces(self, *args, **kwargs):
     """
     Yield ``(model_trace, log_weight)`` pairs, ``self.num_samples`` times.

     Each iteration traces ``self.guide`` with the given arguments and replays
     ``self.model`` against that guide trace. ``log_weight`` is the model
     trace's ``log_prob_sum()`` minus the guide trace's ``log_prob_sum()``.
     """
     for _ in range(self.num_samples):
         guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
         replayed_model = poutine.replay(self.model, trace=guide_trace)
         model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)
         log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
         yield model_trace, log_weight
开发者ID:lewisKit,项目名称:pyro,代码行数:10,代码来源:importance.py

示例15: test_replay_full_repeat

 def test_replay_full_repeat(self):
     model_trace = poutine.trace(self.model).get_trace()
     replaying = poutine.trace(poutine.replay(self.model, trace=model_trace))
     # Run the same traced replay twice, then once more from a fresh wrapper.
     tr11 = replaying.get_trace()
     tr12 = replaying.get_trace()
     tr2 = poutine.trace(poutine.replay(self.model, trace=model_trace)).get_trace()
     for site_name in self.full_sample_sites.keys():
         first_replay = tr11.nodes[site_name]["value"]
         assert_equal(first_replay, tr12.nodes[site_name]["value"])
         assert_equal(first_replay, tr2.nodes[site_name]["value"])
         recorded = model_trace.nodes[site_name]["value"]
         assert_equal(recorded, first_replay)
         assert_equal(recorded, tr2.nodes[site_name]["value"])
开发者ID:lewisKit,项目名称:pyro,代码行数:11,代码来源:test_poutines.py


注：本文中的 pyro.poutine.trace 函数示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。