This article collects typical usage examples of the Python method jax.random.PRNGKey. If you are wondering what random.PRNGKey does, how to call it, or what real code that uses it looks like, the hand-picked examples below should help. You can also browse further usage examples from the jax.random module that this method belongs to.
The following shows 15 code examples of random.PRNGKey, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps surface better Python code examples.
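Before diving into the examples, a brief note on how the API works: jax.random.PRNGKey(seed) turns an integer seed into an explicit random key, every sampler in jax.random consumes such a key, and the same key always produces the same values, so fresh subkeys are usually derived with jax.random.split. The sketch below is a minimal illustration of that pattern and is not taken from any of the projects quoted later.

from jax import random

key = random.PRNGKey(42)                 # build an explicit key from an integer seed
key, subkey = random.split(key)          # split before each use instead of reusing the key
x = random.normal(subkey, (3, 2))        # standard-normal samples of shape (3, 2)
x_again = random.normal(subkey, (3, 2))  # same key -> identical samples; keys are deterministic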
Example 1: __init__
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def __init__(self, obs_dim, *, seed=None):
    """Internal setup for Jax-based reward models.

    Initialises reward model using given seed & input size (`obs_dim`).

    Args:
        obs_dim (int): dimensionality of observation space.
        seed (int or None): random seed for generating initial params. If
            None, seed will be chosen arbitrarily, as in
            LinearRewardModel.
    """
    # TODO: apply jax.jit() to everything in sight
    net_init, self._net_apply = self.make_stax_model()
    if seed is None:
        # oh well
        seed = np.random.randint((1 << 63) - 1)
    rng = jrandom.PRNGKey(seed)
    out_shape, self._net_params = net_init(rng, (-1, obs_dim))
    self._net_grads = jax.grad(self._net_apply)
    # output shape should just be batch dim, nothing else
    assert out_shape == (-1,), "got a weird output shape %s" % (out_shape,)
Example 2: jax_randint
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def jax_randint(key, shape, minval, maxval, dtype=np.int32):
    """Sample uniform random values in [minval, maxval) with given shape/dtype.

    Args:
        key: a PRNGKey used as the random key.
        shape: a tuple of nonnegative integers representing the shape.
        minval: int or array of ints broadcast-compatible with ``shape``, a minimum
            (inclusive) value for the range.
        maxval: int or array of ints broadcast-compatible with ``shape``, a maximum
            (exclusive) value for the range.
        dtype: optional, an int dtype for the returned values (default int32).

    Returns:
        A random array with the specified shape and dtype.
    """
    return jax_random.randint(key, shape, minval=minval, maxval=maxval,
                              dtype=dtype)
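For context, a hypothetical call to the helper above might look like the following; the alias `jax_random` (for jax.random) and the chosen bounds are assumptions for illustration, not part of the original project.

key = jax_random.PRNGKey(0)
dice = jax_randint(key, shape=(5,), minval=1, maxval=7)  # five ints drawn from 1..6 (maxval is exclusive)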
Example 3: _flat_reuse_dicts
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def _flat_reuse_dicts(reuse, *example_inputs):
    r = {}

    for module, parameters in reuse.items():
        inputs = example_inputs
        if isinstance(module, ShapedParametrized):
            module, inputs = module.parametrized, module.example_inputs

        if not isinstance(module, parametrized):
            raise ValueError('Keys for reuse must be parametrized or ShapedParametrized.')

        example_dict, _ = module._init_and_apply_parameters_dict(*inputs, key=PRNGKey(0))
        params_dict = parametrized._parameters_dict(parameters, example_dict)
        r.update(module._flatten_dict(params_dict))

    return r
Example 4: main
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def main(batch_size=32, nr_filters=8, epochs=10, step_size=.001, decay_rate=.999995,
         model_path=Path('./pixelcnn.params')):
    loss, _ = PixelCNNPP(nr_filters=nr_filters)
    get_train_batches, test_batches = dataset(batch_size)

    key, init_key = random.split(PRNGKey(0))
    opt = Adam(exponential_decay(step_size, 1, decay_rate))
    state = opt.init(loss.init_parameters(next(test_batches), key=init_key))

    for epoch in range(epochs):
        for batch in get_train_batches():
            key, update_key = random.split(key)
            i = opt.get_step(state)

            state, train_loss = opt.update_and_get_loss(loss.apply, state, batch, key=update_key,
                                                        jit=True)

            if i % 100 == 0 or i < 10:
                key, test_key = random.split(key)
                test_loss = loss.apply(opt.get_parameters(state), next(test_batches), key=test_key,
                                       jit=True)
                print(f"Epoch {epoch}, iteration {i}, "
                      f"train loss {train_loss:.3f}, "
                      f"test loss {test_loss:.3f} ")

        save(opt.get_parameters(state), model_path)
Example 5: test_readme
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def test_readme():
    net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), log_softmax)

    @parametrized
    def loss(inputs, targets):
        return -jnp.mean(net(inputs) * targets)

    def next_batch(): return jnp.zeros((3, 784)), jnp.zeros((3, 4))

    params = loss.init_parameters(*next_batch(), key=PRNGKey(0))

    print(params.sequential.dense2.bias)  # [-0.01101029, -0.00749435, -0.00952365, 0.00493979]
    assert jnp.allclose([-0.01101029, -0.00749435, -0.00952365, 0.00493979],
                        params.sequential.dense2.bias)

    out = loss.apply(params, *next_batch())
    assert () == out.shape

    out_ = loss.apply(params, *next_batch(), jit=True)
    assert out.shape == out_.shape
Example 6: test_Parameter_dense
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def test_Parameter_dense():
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias

        return dense

    net = Dense(2)
    inputs = jnp.zeros((1, 3))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert (3, 2) == params.parameter0.shape
    assert (2,) == params.parameter1.shape

    out = net.apply(params, inputs, jit=True)
    assert (1, 2) == out.shape
Example 7: test_mnist_classifier
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def test_mnist_classifier():
    from examples.mnist_classifier import predict, loss, accuracy

    next_batch = lambda: (jnp.zeros((3, 784)), jnp.zeros((3, 10)))
    opt = optimizers.Momentum(0.001, mass=0.9)
    state = opt.init(loss.init_parameters(*next_batch(), key=PRNGKey(0)))

    t = time.time()
    for _ in range(10):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)

    elapsed = time.time() - t
    assert 5 > elapsed

    params = opt.get_parameters(state)
    train_acc = accuracy.apply_from({loss: params}, *next_batch(), jit=True)
    assert () == train_acc.shape

    predict_params = predict.parameters_from({loss.shaped(*next_batch()): params}, next_batch()[0])
    predictions = predict.apply(predict_params, next_batch()[0], jit=True)
    assert (3, 10) == predictions.shape
Example 8: test_mnist_vae
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def test_mnist_vae():
    @parametrized
    def encode(input):
        input = Sequential(Dense(5), relu, Dense(5), relu)(input)
        mean = Dense(10)(input)
        variance = Sequential(Dense(10), softplus)(input)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        logits_x = decode(gaussian_sample(key, mu_z, sigmasq_z))
        return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

    params = elbo.init_parameters(PRNGKey(0), jnp.zeros((32, 5 * 5)), key=PRNGKey(0))
    assert (5, 10) == params.encode.sequential1.dense.kernel.shape
Example 9: test_submodule_order
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def test_submodule_order():
    @parametrized
    def net():
        p = Parameter(lambda key: jnp.zeros((1,)))
        a = p()
        b = parameter((2,), zeros)
        c = parameter((3,), zeros)
        d = parameter((4,), zeros)
        e = parameter((5,), zeros)
        f = parameter((6,), zeros)

        # must not mess up order (decided by first submodule call):
        k = p()

        return jnp.concatenate([a, f]) + jnp.concatenate([b, e]) + jnp.concatenate([c, d]) + k

    params = net.init_parameters(key=PRNGKey(0))
    assert jnp.zeros((1,)) == params.parameter0

    out = net.apply(params)
    assert (7,) == out.shape
Example 10: test_external_submodule
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def test_external_submodule():
    layer = Dense(3)

    @parametrized
    def net(inputs):
        return 2 * layer(inputs)

    inputs = random_inputs((2,))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert out.shape == (3,)

    out_ = net.apply(params, inputs)
    assert jnp.array_equal(out, out_)

    out_ = net.apply(params, inputs, jit=True)
    assert jnp.allclose(out, out_)
Example 11: test_inline_submodule
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def test_inline_submodule():
    @parametrized
    def net(inputs):
        layer = Dense(3)
        return 2 * layer(inputs)

    inputs = random_inputs((2,))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert out.shape == (3,)

    out_ = net.apply(params, inputs)
    assert jnp.array_equal(out, out_)

    out_ = net.apply(params, inputs, jit=True)
    assert jnp.allclose(out, out_)
Example 12: test_external_submodule2
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def test_external_submodule2():
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    inputs = jnp.zeros((1, 2))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((jnp.zeros((2, 2)), jnp.zeros(2)),), params)

    out = net.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = net.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_)
Example 13: test_submodule_reuse
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def test_submodule_reuse():
    inputs = jnp.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    net2 = Sequential(layer, Dense(3))

    layer_params = layer.init_parameters(inputs, key=PRNGKey(0))
    net1_params = net1.init_parameters(inputs, key=PRNGKey(1), reuse={layer: layer_params})
    net2_params = net2.init_parameters(inputs, key=PRNGKey(2), reuse={layer: layer_params})

    out1 = net1.apply(net1_params, inputs)
    assert out1.shape == (1, 2)

    out2 = net2.apply(net2_params, inputs)
    assert out2.shape == (1, 3)

    assert_dense_parameters_equal(layer_params, net1_params[0])
    assert_dense_parameters_equal(layer_params, net2_params[0])

    new_layer_params = layer.init_parameters(inputs, key=PRNGKey(3))
    combined_params = net1.parameters_from({net1: net1_params, layer: new_layer_params}, inputs)
    assert_dense_parameters_equal(new_layer_params, combined_params.dense0)
    assert_dense_parameters_equal(net1_params.dense1, combined_params.dense1)
Example 14: test_scan_parametrized_cell_without_params
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def test_scan_parametrized_cell_without_params():
    @parametrized
    def cell(carry, x):
        return jnp.array([2]) * carry * x, jnp.array([2]) * carry * x

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, jnp.zeros((2,)), inputs)
        return outs

    inputs = jnp.zeros((3,))

    params = rnn.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((),), params)

    outs = rnn.apply(params, inputs)
    assert (3, 2) == outs.shape
Example 15: test_scan_parametrized_cell
# Required import: from jax import random [as alias]
# Or: from jax.random import PRNGKey [as alias]
def test_scan_parametrized_cell():
    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        return scale * jnp.array([2]) * carry * x, scale * jnp.array([2]) * carry * x

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, jnp.zeros((2,)), inputs)
        return outs

    inputs = jnp.zeros((3,))

    rnn_params = rnn.init_parameters(inputs, key=PRNGKey(0))
    assert (2,) == rnn_params.cell.parameter.shape

    outs = rnn.apply(rnn_params, inputs)
    assert (3, 2) == outs.shape