本文整理汇总了Python中jax.numpy.zeros方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.zeros方法的具体用法？jax.numpy.zeros怎么用？jax.numpy.zeros使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。
在下文中一共展示了jax.numpy.zeros方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: GRUCell
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def GRUCell(carry_size, param_init):
    """Build a GRU cell.

    Returns a pair ``(gru_cell, carry_init)``:
      * ``gru_cell(carry, x)`` is a parametrized function returning
        ``(new_carry, new_carry)``. It owns three kernels, 'update_kernel',
        'reset_kernel' and 'compute_kernel', each of shape
        ``(x.shape[1] + carry_size, carry_size)`` and initialized with
        ``param_init``.
      * ``carry_init(batch_size)`` returns a zero carry of shape
        ``(batch_size, carry_size)``.
    """
    def carry_init(batch_size):
        return np.zeros((batch_size, carry_size))

    @parametrized
    def gru_cell(carry, x):
        kernel_shape = (x.shape[1] + carry_size, carry_size)

        def gate(features, kernel_name, activation):
            # activation(features @ kernel) with a named, lazily created kernel.
            return activation(np.dot(features, parameter(kernel_shape, param_init, kernel_name)))

        x_and_carry = np.concatenate((x, carry), axis=1)
        update = gate(x_and_carry, 'update_kernel', sigmoid)
        reset = gate(x_and_carry, 'reset_kernel', sigmoid)
        # The candidate state sees the carry only through the reset gate.
        candidate = gate(np.concatenate((x, reset * carry), axis=1), 'compute_kernel', np.tanh)
        new_carry = update * candidate + (1 - update) * carry
        return new_carry, new_carry

    return gru_cell, carry_init
示例2: BatchNorm
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
    """Layer construction function for a batch normalization layer.

    Args:
      axis: int or tuple of ints; the axes of ``x`` over which the mean and
        variance are computed. Negative values count from the last axis.
      epsilon: added to the variance before taking the square root.
      center: if True, add a learned 'beta' offset.
      scale: if True, multiply by a learned 'gamma' factor.
      beta_init, gamma_init: initializers passed to ``parameter`` for
        'beta' and 'gamma'. Both parameters have the shape of ``x`` with the
        normalized axes removed.

    Returns:
      A parametrized function ``batch_norm(x)``.
    """
    axis = (axis,) if np.isscalar(axis) else axis

    @parametrized
    def batch_norm(x):
        ndim = np.ndim(x)
        # Map negative axes to non-negative ones. Otherwise the membership
        # tests below (against range/enumerate indices) never match them.
        axes = tuple(a + ndim if a < 0 else a for a in axis)
        # Index that re-inserts the normalized axes as size-1 dims, so the
        # parameters broadcast against x.
        ed = tuple(None if i in axes else slice(None) for i in range(ndim))
        mean, var = np.mean(x, axes, keepdims=True), fastvar(x, axes, keepdims=True)
        z = (x - mean) / np.sqrt(var + epsilon)
        shape = tuple(d for i, d in enumerate(x.shape) if i not in axes)
        scaled = z * parameter(shape, gamma_init, 'gamma')[ed] if scale else z
        return scaled + parameter(shape, beta_init, 'beta')[ed] if center else scaled

    return batch_norm
示例3: Wavenet
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def Wavenet(dilations, filter_width, initial_filter_width, out_width,
            residual_channels, dilation_channels, skip_channels, nr_mix):
    """Build a WaveNet-style model as a parametrized function.

    :param dilations: dilations for each layer (one ResLayer per entry)
    :param filter_width: for the resblock convs
    :param initial_filter_width: for the pre processing conv
    :param out_width: size of axis 1 of the accumulated per-layer output;
        also passed to every ResLayer
    :param residual_channels: 1x1 conv output channels
    :param dilation_channels: gate and filter output channels
    :param skip_channels: channels before the final output
    :param nr_mix: the final 1x1 conv produces 3 * nr_mix channels
        (presumably parameters of an nr_mix-component mixture — TODO confirm)
    """
    @parametrized
    def wavenet(inputs):
        hidden = Conv1D(residual_channels, (initial_filter_width,))(inputs)
        # Running sum of the second output of every ResLayer.
        out = np.zeros((hidden.shape[0], out_width, residual_channels), 'float32')
        for dilation in dilations:
            hidden, out_partial = ResLayer(dilation_channels, residual_channels,
                                           filter_width, dilation, out_width)(hidden)
            out += out_partial
        return Sequential(relu, Conv1D(skip_channels, (1,)),
                          relu, Conv1D(3 * nr_mix, (1,)))(out)

    return wavenet
示例4: ConvOrConvTranspose
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def ConvOrConvTranspose(out_chan, filter_shape=(3, 3), strides=None, padding='SAME', init_scale=1.,
                        transpose=False):
    """Weight-normalized convolution (or transposed convolution when
    ``transpose`` is True) with data-dependent initialization.

    The layer computes ``conv(inputs, g * _l2_normalize(V, (0, 1, 2))) - b``.
    ``g`` and ``b`` are initialized from an example output produced with
    ``g = 1, b = 0``: ``g = init_scale / sqrt(var + 1e-10)`` and
    ``b = mean * g``, where mean and var are taken over axes (0, 1, 2). The
    initial output therefore equals ``init_scale * (out - mean) / std``.

    :param out_chan: number of output channels (size of g and b)
    :param filter_shape: spatial filter shape; V has shape
        ``filter_shape + (in_channels, out_chan)``
    :param strides: defaults to all ones
    :param padding: padding mode passed to the convolution
    :param init_scale: target scale of the initial output
    :param transpose: use ``lax.conv_transpose`` instead of ``_conv``
    """
    strides = strides or (1,) * len(filter_shape)

    def apply(inputs, V, g, b):
        # Presumably _l2_normalize divides V by its norm over axes (0, 1, 2),
        # i.e. per output channel — confirm. g then rescales each channel.
        V = g * _l2_normalize(V, (0, 1, 2))
        return (lax.conv_transpose if transpose else _conv)(inputs, V, strides, padding) - b

    @parametrized
    def conv_or_conv_transpose(inputs):
        V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')
        # Output with unit scale and zero offset, used only to derive initial g and b.
        example_out = apply(inputs, V=V, g=jnp.ones(out_chan), b=jnp.zeros(out_chan))
        # TODO remove need for `.aval.val` when capturing variables in initializer function:
        g = Parameter(lambda key: init_scale /
                      jnp.sqrt(jnp.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
        b = Parameter(lambda key: jnp.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val, 'b')()
        # Keyword arguments: g is the scale and b the offset. A positional
        # call with the two swapped would silently misuse them, since both
        # have shape (out_chan,).
        return apply(inputs, V=V, g=g, b=b)

    return conv_or_conv_transpose
示例5: test_Parameter_dense
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_Parameter_dense():
    """A hand-written Dense layer made of two unnamed parameters exposes them
    as parameter0 / parameter1 and applies under jit."""
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            w = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias_vec = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, w) + bias_vec

        return dense

    net = Dense(2)
    x = jnp.zeros((1, 3))
    params = net.init_parameters(x, key=PRNGKey(0))
    assert params.parameter0.shape == (3, 2)
    assert params.parameter1.shape == (2,)
    assert net.apply(params, x, jit=True).shape == (1, 2)
示例6: test_mnist_classifier
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_mnist_classifier():
    """Smoke-test the example MNIST classifier.

    Ten Momentum updates on all-zero batches must finish within a wall-clock
    budget. The accuracy and prediction functions must then return the
    expected shapes.
    """
    from examples.mnist_classifier import predict, loss, accuracy

    def next_batch():
        return jnp.zeros((3, 784)), jnp.zeros((3, 10))

    opt = optimizers.Momentum(0.001, mass=0.9)
    state = opt.init(loss.init_parameters(*next_batch(), key=PRNGKey(0)))
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump when the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(10):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)
    elapsed = time.perf_counter() - start
    # NOTE(review): JAX dispatch may be asynchronous, so `elapsed` may not
    # cover all device work; it probably also includes jit compilation.
    assert 5 > elapsed
    params = opt.get_parameters(state)
    train_acc = accuracy.apply_from({loss: params}, *next_batch(), jit=True)
    assert () == train_acc.shape
    predict_params = predict.parameters_from({loss.shaped(*next_batch()): params}, next_batch()[0])
    predictions = predict.apply(predict_params, next_batch()[0], jit=True)
    assert (3, 10) == predictions.shape
示例7: test_mnist_vae
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_mnist_vae():
    """Initialize a small VAE ELBO and check the nested name and shape of one
    encoder kernel."""
    @parametrized
    def encode(inputs):
        # `inputs`, not `input`: avoid shadowing the builtin.
        hidden = Sequential(Dense(5), relu, Dense(5), relu)(inputs)
        mean = Dense(10)(hidden)
        variance = Sequential(Dense(10), softplus)(hidden)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        logits_x = decode(gaussian_sample(key, mu_z, sigmasq_z))
        return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

    params = elbo.init_parameters(PRNGKey(0), jnp.zeros((32, 5 * 5)), key=PRNGKey(0))
    # The second Sequential in `encode` maps 5 hidden units to 10.
    assert (5, 10) == params.encode.sequential1.dense.kernel.shape
示例8: test_submodule_order
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_submodule_order():
    """Parameter order is decided by each submodule's first call, even when a
    submodule (`p`) is called again after other parameters were created."""
    @parametrized
    def net():
        p = Parameter(lambda key: jnp.zeros((1,)))
        a = p()
        b = parameter((2,), zeros)
        c = parameter((3,), zeros)
        d = parameter((4,), zeros)
        e = parameter((5,), zeros)
        f = parameter((6,), zeros)
        # must not mess up order (decided by first submodule call):
        k = p()
        # Each concatenation has length 7; k (length 1) broadcasts.
        return jnp.concatenate([a, f]) + jnp.concatenate([b, e]) + jnp.concatenate([c, d]) + k

    params = net.init_parameters(key=PRNGKey(0))
    # array_equal also checks the shape and does not depend on the
    # truthiness of an elementwise comparison result.
    assert jnp.array_equal(jnp.zeros((1,)), params.parameter0)
    out = net.apply(params)
    assert (7,) == out.shape
示例9: test_external_submodule2
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_external_submodule2():
    """A zero-initialized Dense created outside the parametrized function is
    registered as its submodule; eager and jitted outputs agree."""
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    x = jnp.zeros((1, 2))
    params = net.init_parameters(x, key=PRNGKey(0))
    expected_params = ((jnp.zeros((2, 2)), jnp.zeros(2)),)
    assert_parameters_equal(expected_params, params)
    eager_out = net.apply(params, x)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager_out)
    jitted_out = net.apply(params, x, jit=True)
    assert jnp.array_equal(eager_out, jitted_out)
示例10: test_external_sequential_submodule
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_external_sequential_submodule():
    """Check the parameter names and shapes of a Sequential that contains a
    nested Sequential; eager and jitted outputs agree."""
    layer = Sequential(Conv(4, (2, 2)), flatten, relu, Dense(3), relu, Dense(2),
                       Sequential(Dense(2), relu))
    x = jnp.zeros((1, 5, 5, 2))
    params = layer.init_parameters(x, key=PRNGKey(0))

    assert params.conv.bias.shape == (4,)
    # Repeated Dense layers are numbered in order of appearance.
    assert params.dense0.bias.shape == (3,)
    assert params.dense1.kernel.shape == (3, 2)
    assert params.dense1.bias.shape == (2,)
    assert params.sequential.dense.bias.shape == (2,)

    eager_out = layer.apply(params, x)
    assert eager_out.shape == (1, 2)
    jitted_out = layer.apply(params, x, jit=True)
    assert jnp.allclose(eager_out, jitted_out)
示例11: test_submodule_reuse
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_submodule_reuse():
    """Parameters of a shared layer can be reused when initializing two
    networks, and combined with another network's parameters via
    parameters_from."""
    x = jnp.zeros((1, 2))
    shared = Dense(5)
    net1 = Sequential(shared, Dense(2))
    net2 = Sequential(shared, Dense(3))

    shared_params = shared.init_parameters(x, key=PRNGKey(0))
    net1_params = net1.init_parameters(x, key=PRNGKey(1), reuse={shared: shared_params})
    net2_params = net2.init_parameters(x, key=PRNGKey(2), reuse={shared: shared_params})

    assert net1.apply(net1_params, x).shape == (1, 2)
    assert net2.apply(net2_params, x).shape == (1, 3)
    for net_params in (net1_params, net2_params):
        assert_dense_parameters_equal(shared_params, net_params[0])

    replacement = shared.init_parameters(x, key=PRNGKey(3))
    combined = net1.parameters_from({net1: net1_params, shared: replacement}, x)
    assert_dense_parameters_equal(replacement, combined.dense0)
    assert_dense_parameters_equal(net1_params.dense1, combined.dense1)
示例12: test_scan_parametrized_cell
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_scan_parametrized_cell():
    """A parametrized cell can be driven by lax.scan inside another
    parametrized function."""
    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        # The same value serves as both the new carry and the per-step output.
        step = scale * jnp.array([2]) * carry * x
        return step, step

    @parametrized
    def rnn(inputs):
        _, stacked = lax.scan(cell, jnp.zeros((2,)), inputs)
        return stacked

    xs = jnp.zeros((3,))
    rnn_params = rnn.init_parameters(xs, key=PRNGKey(0))
    assert rnn_params.cell.parameter.shape == (2,)
    assert rnn.apply(rnn_params, xs).shape == (3, 2)
示例13: test_scan_parametrized_nonflat_cell
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_scan_parametrized_nonflat_cell():
    """Same as the flat scan test, but the carry is a dict (a pytree)."""
    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        step = scale * jnp.array([2]) * carry['a'] * x
        return {'a': step}, step

    @parametrized
    def rnn(inputs):
        _, stacked = lax.scan(cell, {'a': jnp.zeros((2,))}, inputs)
        return stacked

    xs = jnp.zeros((3,))
    rnn_params = rnn.init_parameters(xs, key=PRNGKey(0))
    assert rnn_params.cell.parameter.shape == (2,)
    assert rnn.apply(rnn_params, xs).shape == (3, 2)
示例14: test_param_and_submodule_mixed
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_param_and_submodule_mixed():
    """A module may combine a parametrized submodule with its own named
    parameter; both appear in the parameter tree."""
    @parametrized
    def linear_map(inputs):
        return jnp.dot(inputs, parameter((inputs.shape[-1], 2), zeros, 'kernel'))

    @parametrized
    def dense(inputs):
        projected = linear_map(inputs)
        return projected + parameter((2,), zeros, 'bias')

    x = jnp.zeros((1, 3))
    params = dense.init_parameters(x, key=PRNGKey(0))
    assert params.bias.shape == (2,)
    assert params.linear_map.kernel.shape == (3, 2)
    eager_out = dense.apply(params, x)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager_out)
    jitted_out = dense.apply(params, x, jit=True)
    assert jnp.array_equal(eager_out, jitted_out)
示例15: test_mixed_up_execution_order
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros [as 别名]
def test_mixed_up_execution_order():
    """Create 'bias' before 'kernel' (the reverse of their use in the
    output) and check both named parameters and the outputs."""
    @parametrized
    def dense(inputs):
        # Deliberately created in the "wrong" order: bias first.
        b = parameter((2,), zeros, 'bias')
        w = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, w) + b

    x = jnp.zeros((1, 3))
    params = dense.init_parameters(x, key=PRNGKey(0))
    assert params.bias.shape == (2,)
    assert params.kernel.shape == (3, 2)
    eager_out = dense.apply(params, x)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager_out)
    jitted_out = dense.apply(params, x, jit=True)
    assert jnp.array_equal(eager_out, jitted_out)