当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.zeros方法代码示例

本文整理汇总了Python中jax.numpy.zeros方法的典型用法代码示例。如果您想了解jax.numpy.zeros的具体用法、调用方式或使用示例,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的其他用法示例。


在下文中一共展示了jax.numpy.zeros方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: GRUCell

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def GRUCell(carry_size, param_init):
    """Build a parametrized GRU cell together with a zero-carry initializer.

    :param carry_size: width of the hidden state (carry).
    :param param_init: initializer handed to ``parameter`` for every kernel.
    :return: ``(gru_cell, carry_init)``. ``gru_cell(carry, x)`` returns the new
        carry twice, as ``(new_carry, new_carry)``. ``carry_init(batch_size)``
        returns zeros of shape ``(batch_size, carry_size)``.
    """

    @parametrized
    def gru_cell(carry, x):
        # Every kernel maps the feature-wise concatenation [x, carry]
        # (along axis 1) to carry_size outputs.
        kernel_shape = (x.shape[1] + carry_size, carry_size)

        def kernel(name):
            return parameter(kernel_shape, param_init, name)

        x_and_carry = np.concatenate((x, carry), axis=1)
        update_gate = sigmoid(np.dot(x_and_carry, kernel('update_kernel')))
        reset_gate = sigmoid(np.dot(x_and_carry, kernel('reset_kernel')))
        x_and_reset_carry = np.concatenate((x, reset_gate * carry), axis=1)
        candidate = np.tanh(np.dot(x_and_reset_carry, kernel('compute_kernel')))
        # Interpolate between the candidate state and the previous carry.
        new_carry = update_gate * candidate + (1 - update_gate) * carry
        return new_carry, new_carry

    def carry_init(batch_size):
        return np.zeros((batch_size, carry_size))

    return gru_cell, carry_init
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:modules.py

示例2: BatchNorm

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
    """Layer construction function for a batch normalization layer.

    :param axis: axis or tuple of axes to normalize over. A scalar is wrapped
        into a 1-tuple.
    :param epsilon: added to the variance before taking the square root.
    :param center: if true, add a learned offset ``beta``.
    :param scale: if true, multiply by a learned factor ``gamma``.
    :param beta_init: initializer for ``beta``.
    :param gamma_init: initializer for ``gamma``.

    ``gamma`` and ``beta`` have the input shape with the reduced axes removed.
    They are broadcast back by inserting new axes at the reduced positions.
    """
    if np.isscalar(axis):
        axis = (axis,)

    @parametrized
    def batch_norm(x):
        # Index that inserts a new axis at every reduced position and keeps the others.
        expand = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
        mean = np.mean(x, axis, keepdims=True)
        var = fastvar(x, axis, keepdims=True)
        normalized = (x - mean) / np.sqrt(var + epsilon)
        param_shape = tuple(d for i, d in enumerate(x.shape) if i not in axis)

        out = normalized
        if scale:
            out = out * parameter(param_shape, gamma_init, 'gamma')[expand]
        if center:
            out = out + parameter(param_shape, beta_init, 'beta')[expand]
        return out

    return batch_norm
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:19,代码来源:modules.py

示例3: Wavenet

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def Wavenet(dilations, filter_width, initial_filter_width, out_width,
            residual_channels, dilation_channels, skip_channels, nr_mix):
    """
    :param dilations: dilations for each layer
    :param filter_width: for the resblock convs
    :param initial_filter_width: for the pre processing conv
    :param out_width: passed to every ResLayer. It is also the second (time)
        dimension of the skip-connection accumulator.
    :param residual_channels: 1x1 conv output channels
    :param dilation_channels: gate and filter output channels
    :param skip_channels: channels before the final output
    :param nr_mix: the final 1x1 conv emits ``3 * nr_mix`` channels
        (presumably 3 values per mixture component; confirm against the loss)
    """

    @parametrized
    def wavenet(inputs):
        hidden = Conv1D(residual_channels, (initial_filter_width,))(inputs)
        # Running sum of the skip outputs of every residual layer.
        skip_sum = np.zeros((hidden.shape[0], out_width, residual_channels), 'float32')
        for dilation in dilations:
            layer = ResLayer(dilation_channels, residual_channels,
                             filter_width, dilation, out_width)
            hidden, skip = layer(hidden)
            skip_sum += skip
        head = Sequential(relu, Conv1D(skip_channels, (1,)),
                          relu, Conv1D(3 * nr_mix, (1,)))
        return head(skip_sum)

    return wavenet
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:26,代码来源:wavenet.py

示例4: ConvOrConvTranspose

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def ConvOrConvTranspose(out_chan, filter_shape=(3, 3), strides=None, padding='SAME', init_scale=1.,
                        transpose=False):
    """Weight-normalized convolution (or transposed convolution) layer.

    The kernel ``V`` is l2-normalized over axes (0, 1, 2) and rescaled by ``g``.
    Then ``b`` is subtracted from the convolution result. ``g`` and ``b`` are
    initialized from the statistics of a first forward pass run with
    ``g = 1, b = 0``. With ``g = init_scale / std`` and ``b = mean * g``, the
    initial output is ``init_scale * (out - mean) / std`` (data-dependent init).
    NOTE(review): this assumes axes (0, 1, 2) of the output are batch/spatial
    and the last axis is channels; confirm.

    :param out_chan: number of output channels.
    :param filter_shape: spatial kernel shape.
    :param strides: convolution strides. Defaults to 1 along every spatial axis.
    :param padding: padding mode passed to the convolution.
    :param init_scale: target output scale at initialization.
    :param transpose: use ``lax.conv_transpose`` instead of ``_conv``.
    """
    strides = strides or (1,) * len(filter_shape)

    def apply(inputs, V, g, b):
        V = g * _l2_normalize(V, (0, 1, 2))
        return (lax.conv_transpose if transpose else _conv)(inputs, V, strides, padding) - b

    @parametrized
    def conv_or_conv_transpose(inputs):
        V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')

        # Forward pass with unit scale and zero bias. Its statistics drive the init below.
        example_out = apply(inputs, V=V, g=jnp.ones(out_chan), b=jnp.zeros(out_chan))

        # TODO remove need for `.aval.val` when capturing variables in initializer function:
        g = Parameter(lambda key: init_scale /
                                  jnp.sqrt(jnp.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
        b = Parameter(lambda key: jnp.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val, 'b')()

        # Pass scale and bias by keyword so they cannot be swapped against
        # apply's (g, b) parameter order.
        return apply(inputs, V, g=g, b=b)

    return conv_or_conv_transpose
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:24,代码来源:pixelcnn.py

示例5: test_Parameter_dense

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_Parameter_dense():
    """Initialize and apply a dense layer built from two unnamed ``parameter`` calls."""

    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            in_dim = inputs.shape[-1]
            kernel = parameter((in_dim, out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias

        return dense

    net = Dense(2)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    # The kernel is exposed as parameter0 and the bias as parameter1.
    assert params.parameter0.shape == (3, 2)
    assert params.parameter1.shape == (2,)

    out = net.apply(params, inputs, jit=True)
    assert out.shape == (1, 2)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_examples.py

示例6: test_mnist_classifier

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_mnist_classifier():
    """Smoke-test the MNIST classifier example: training speed and output shapes."""
    from examples.mnist_classifier import predict, loss, accuracy

    def next_batch():
        # Constant all-zeros batch: inputs of shape (3, 784), targets of shape (3, 10).
        return jnp.zeros((3, 784)), jnp.zeros((3, 10))

    opt = optimizers.Momentum(0.001, mass=0.9)
    state = opt.init(loss.init_parameters(*next_batch(), key=PRNGKey(0)))

    # perf_counter is monotonic, unlike time.time(). System clock adjustments
    # therefore cannot distort the measured duration.
    start = time.perf_counter()
    for _ in range(10):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)
    elapsed = time.perf_counter() - start
    assert elapsed < 5

    params = opt.get_parameters(state)
    train_acc = accuracy.apply_from({loss: params}, *next_batch(), jit=True)
    assert () == train_acc.shape

    predict_params = predict.parameters_from({loss.shaped(*next_batch()): params}, next_batch()[0])
    predictions = predict.apply(predict_params, next_batch()[0], jit=True)
    assert (3, 10) == predictions.shape
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:23,代码来源:test_examples.py

示例7: test_mnist_vae

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_mnist_vae():
    """Check the encoder parameter layout of a small VAE built from ``Sequential``/``Dense``."""

    @parametrized
    def encode(x):
        features = Sequential(Dense(5), relu, Dense(5), relu)(x)
        mean = Dense(10)(features)
        variance = Sequential(Dense(10), softplus)(features)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        z = gaussian_sample(key, mu_z, sigmasq_z)
        logits_x = decode(z)
        return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

    images = jnp.zeros((32, 5 * 5))
    params = elbo.init_parameters(PRNGKey(0), images, key=PRNGKey(0))
    # sequential1 is the second Sequential created in encode (the variance head).
    # Its Dense maps 5 features to 10.
    assert params.encode.sequential1.dense.kernel.shape == (5, 10)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_examples.py

示例8: test_submodule_order

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_submodule_order():
    """Calling a submodule again must not change the parameter order fixed by its first call."""

    @parametrized
    def net():
        p = Parameter(lambda key: jnp.zeros((1,)))
        a = p()
        b = parameter((2,), zeros)
        c = parameter((3,), zeros)
        d = parameter((4,), zeros)
        e = parameter((5,), zeros)
        f = parameter((6,), zeros)

        # must not mess up order (decided by first submodule call):
        k = p()

        return jnp.concatenate([a, f]) + jnp.concatenate([b, e]) + jnp.concatenate([c, d]) + k

    params = net.init_parameters(key=PRNGKey(0))

    # array_equal also checks shape. `==` alone relies on the truthiness of a size-1 array.
    assert jnp.array_equal(jnp.zeros((1,)), params.parameter0)
    out = net.apply(params)
    assert (7,) == out.shape
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:23,代码来源:test_core.py

示例9: test_external_submodule2

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_external_submodule2():
    """A layer created outside a ``@parametrized`` function is used as its submodule."""
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    inputs = jnp.zeros((1, 2))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    expected_params = ((jnp.zeros((2, 2)), jnp.zeros(2)),)
    assert_parameters_equal(expected_params, params)

    eager_out = net.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager_out)

    jitted_out = net.apply(params, inputs, jit=True)
    assert jnp.array_equal(eager_out, jitted_out)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:19,代码来源:test_core.py

示例10: test_external_sequential_submodule

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_external_sequential_submodule():
    """Parameter naming and shapes of a nested ``Sequential``; eager vs. jitted output."""
    layer = Sequential(Conv(4, (2, 2)), flatten, relu, Dense(3), relu, Dense(2),
                       Sequential(Dense(2), relu))
    inputs = jnp.zeros((1, 5, 5, 2))

    params = layer.init_parameters(inputs, key=PRNGKey(0))
    assert params.conv.bias.shape == (4,)
    assert params.dense0.bias.shape == (3,)
    assert params.dense1.kernel.shape == (3, 2)
    assert params.dense1.bias.shape == (2,)
    assert params.sequential.dense.bias.shape == (2,)

    eager_out = layer.apply(params, inputs)
    assert eager_out.shape == (1, 2)

    jitted_out = layer.apply(params, inputs, jit=True)
    assert jnp.allclose(eager_out, jitted_out)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:19,代码来源:test_core.py

示例11: test_submodule_reuse

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_submodule_reuse():
    """Share one layer's parameters between two networks via ``reuse`` / ``parameters_from``."""
    inputs = jnp.zeros((1, 2))

    shared = Dense(5)
    net1 = Sequential(shared, Dense(2))
    net2 = Sequential(shared, Dense(3))

    shared_params = shared.init_parameters(inputs, key=PRNGKey(0))
    net1_params = net1.init_parameters(inputs, key=PRNGKey(1), reuse={shared: shared_params})
    net2_params = net2.init_parameters(inputs, key=PRNGKey(2), reuse={shared: shared_params})

    assert net1.apply(net1_params, inputs).shape == (1, 2)
    assert net2.apply(net2_params, inputs).shape == (1, 3)

    # The shared layer is the first entry of both networks' parameters.
    for net_params in (net1_params, net2_params):
        assert_dense_parameters_equal(shared_params, net_params[0])

    new_shared_params = shared.init_parameters(inputs, key=PRNGKey(3))
    combined = net1.parameters_from({net1: net1_params, shared: new_shared_params}, inputs)
    assert_dense_parameters_equal(new_shared_params, combined.dense0)
    assert_dense_parameters_equal(net1_params.dense1, combined.dense1)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:26,代码来源:test_core.py

示例12: test_scan_parametrized_cell

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_scan_parametrized_cell():
    """``lax.scan`` over a parametrized cell. The cell's parameter appears once in rnn's parameters."""

    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        step = scale * jnp.array([2]) * carry * x
        # The same value serves as the new carry and the per-step output.
        return step, step

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, jnp.zeros((2,)), inputs)
        return outs

    inputs = jnp.zeros((3,))

    rnn_params = rnn.init_parameters(inputs, key=PRNGKey(0))
    assert rnn_params.cell.parameter.shape == (2,)

    outs = rnn.apply(rnn_params, inputs)
    assert outs.shape == (3, 2)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_core.py

示例13: test_scan_parametrized_nonflat_cell

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_scan_parametrized_nonflat_cell():
    """``lax.scan`` over a parametrized cell whose carry is a dict (a non-flat pytree)."""

    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        step = scale * jnp.array([2]) * carry['a'] * x
        return {'a': step}, step

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, {'a': jnp.zeros((2,))}, inputs)
        return outs

    inputs = jnp.zeros((3,))

    rnn_params = rnn.init_parameters(inputs, key=PRNGKey(0))
    assert rnn_params.cell.parameter.shape == (2,)

    outs = rnn.apply(rnn_params, inputs)
    assert outs.shape == (3, 2)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:21,代码来源:test_core.py

示例14: test_param_and_submodule_mixed

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_param_and_submodule_mixed():
    """A module can both call a parametrized submodule and create its own parameter."""

    @parametrized
    def linear_map(inputs):
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel)

    @parametrized
    def dense(inputs):
        # The submodule is called before the bias is created, as in `linear_map(...) + bias`.
        projected = linear_map(inputs)
        bias = parameter((2,), zeros, 'bias')
        return projected + bias

    inputs = jnp.zeros((1, 3))

    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert params.bias.shape == (2,)
    assert params.linear_map.kernel.shape == (3, 2)

    eager_out = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager_out)

    jitted_out = dense.apply(params, inputs, jit=True)
    assert jnp.array_equal(eager_out, jitted_out)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:23,代码来源:test_core.py

示例15: test_mixed_up_execution_order

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_mixed_up_execution_order():
    """Named parameters are created out of use order (bias before kernel)."""

    @parametrized
    def dense(inputs):
        # Deliberately create the bias first; the test is about this order.
        bias = parameter((2,), zeros, 'bias')
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel) + bias

    inputs = jnp.zeros((1, 3))

    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert params.bias.shape == (2,)
    assert params.kernel.shape == (3, 2)

    eager_out = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager_out)

    jitted_out = dense.apply(params, inputs, jit=True)
    assert jnp.array_equal(eager_out, jitted_out)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_core.py


注:本文中的jax.numpy.zeros方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。