当前位置: 首页>>代码示例>>Python>>正文


Python Model.get_param_values方法代码示例

本文整理汇总了Python中blocks.model.Model.get_param_values方法的典型用法代码示例。如果您正苦于以下问题:Python Model.get_param_values方法的具体用法?Python Model.get_param_values怎么用?Python Model.get_param_values使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在blocks.model.Model的用法示例。


在下文中一共展示了Model.get_param_values方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_model

# 需要导入模块: from blocks.model import Model [as 别名]
# 或者: from blocks.model.Model import get_param_values [as 别名]
def test_model():
    x = tensor.matrix("x")
    mlp1 = MLP([Tanh(), Tanh()], [10, 20, 30], name="mlp1")
    mlp2 = MLP([Tanh()], [30, 40], name="mlp2")
    h1 = mlp1.apply(x)
    h2 = mlp2.apply(h1)

    model = Model(h2.sum())
    assert model.get_top_bricks() == [mlp1, mlp2]
    # The order of parameters returned is deterministic but
    # not sensible.
    assert list(model.get_params().items()) == [
        ("/mlp2/linear_0.b", mlp2.linear_transformations[0].b),
        ("/mlp1/linear_1.b", mlp1.linear_transformations[1].b),
        ("/mlp1/linear_0.b", mlp1.linear_transformations[0].b),
        ("/mlp1/linear_0.W", mlp1.linear_transformations[0].W),
        ("/mlp1/linear_1.W", mlp1.linear_transformations[1].W),
        ("/mlp2/linear_0.W", mlp2.linear_transformations[0].W),
    ]

    # Test getting and setting parameter values
    mlp3 = MLP([Tanh()], [10, 10])
    mlp3.allocate()
    model3 = Model(mlp3.apply(x))
    param_values = {
        "/mlp/linear_0.W": 2 * numpy.ones((10, 10), dtype=theano.config.floatX),
        "/mlp/linear_0.b": 3 * numpy.ones(10, dtype=theano.config.floatX),
    }
    model3.set_param_values(param_values)
    assert numpy.all(mlp3.linear_transformations[0].params[0].get_value() == 2)
    assert numpy.all(mlp3.linear_transformations[0].params[1].get_value() == 3)
    got_param_values = model3.get_param_values()
    assert len(got_param_values) == len(param_values)
    for name, value in param_values.items():
        assert_allclose(value, got_param_values[name])

    # Test name conflict handling
    mlp4 = MLP([Tanh()], [10, 10])

    def helper():
        Model(mlp4.apply(mlp3.apply(x)))

    assert_raises(ValueError, helper)
开发者ID:gilbertoIglesias,项目名称:blocks,代码行数:45,代码来源:test_model.py


注:本文中的blocks.model.Model.get_param_values方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。