本文整理汇总了Python中jax.random.normal方法的典型用法代码示例。如果您正苦于以下问题：Python random.normal方法的具体用法？Python random.normal怎么用？Python random.normal使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.random的用法示例。
在下文中一共展示了random.normal方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_Parameter_dense
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_Parameter_dense():
    """Build a dense layer from two ``parameter`` calls and check the resulting shapes."""

    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        # Factory returning a parametrized function that declares a kernel
        # and a bias and applies them to its input.
        @parametrized
        def dense(inputs):
            in_dim = inputs.shape[-1]
            kernel = parameter((in_dim, out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias

        return dense

    layer = Dense(2)
    batch = jnp.zeros((1, 3))
    weights = layer.init_parameters(batch, key=PRNGKey(0))

    # Kernel maps the 3 input features to 2 outputs; bias has one entry per output.
    assert weights.parameter0.shape == (3, 2)
    assert weights.parameter1.shape == (2,)

    result = layer.apply(weights, batch, jit=True)
    assert result.shape == (1, 2)
示例2: _onion
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def _onion(self, key, size):
    """Sample Cholesky factors via the construction the original cites as Algorithm 3.2 of [1].

    :param key: PRNG key; split once into a key for the beta draw and one for the normal draw.
    :param size: leading sample shape (a tuple), prepended to ``self.batch_shape``.
    :return: array of shape ``size + self.batch_shape + self.event_shape``.
    """
    beta_key, gauss_key = random.split(key)
    lead_shape = size + self.batch_shape

    # Generate the w term of Algorithm 3.2 in [1].
    beta_draw = self._beta.sample(beta_key, size)

    # Normal draws are normalised below to obtain points uniformly distributed on a
    # hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
    n_tril = self.dimension * (self.dimension - 1) // 2
    gauss = random.normal(gauss_key, shape=lead_shape + (n_tril,))
    gauss = vec_to_tril_matrix(gauss, diagonal=0)
    row_norm = jnp.linalg.norm(gauss, axis=-1, keepdims=True)
    u_hypersphere = gauss / row_norm
    w = jnp.sqrt(beta_draw)[..., None] * u_hypersphere

    # Place w into the strictly-lower (off-diagonal) triangular part.
    base = jnp.zeros(lead_shape + self.event_shape)
    cholesky = ops.index_add(base, ops.index[..., 1:, :-1], w)

    # Fill the diagonal from the remaining row mass.
    # NB: clip at zero to guard against numerical precision issues.
    row_sq_sum = jnp.sum(cholesky ** 2, axis=-1)
    diag = jnp.sqrt(jnp.clip(1 - row_sq_sum, a_min=0.))
    return cholesky + diag[..., None] * jnp.identity(self.dimension)
示例3: test_laplace_approximation_warning
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_laplace_approximation_warning():
    """Expect a "Hessian of log posterior" UserWarning when sampling the Laplace guide's posterior."""

    def model(x, y):
        # Cubic polynomial regression with a very small observation noise.
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    # Three inputs only, with noiseless cubic targets.
    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3

    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), ELBO(), x=x, y=y)

    def step(i, state):
        # svi.update returns a pair; only its first element is carried forward.
        return svi.update(state)[0]

    final_state = fori_loop(0, 10000, step, svi.init(random.PRNGKey(0)))
    params = svi.get_params(final_state)

    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)
示例4: test_correlated_mvn
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_correlated_mvn():
    """Sample a quadratic potential with dense-mass NUTS and compare mean and covariance."""
    # This requires dense mass matrix estimation.
    D = 5
    warmup_steps = 5000
    num_samples = 8000
    true_mean = 0.

    # Lower-triangular factor: scaled anti-diagonal plus small positive entries.
    noise = random.normal(random.PRNGKey(0), shape=(D, D))
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) + 0.1 * jnp.exp(noise))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # Quadratic form 0.5 * z^T P z with P the true precision matrix.
        prec_z = jnp.dot(true_prec, z)
        return 0.5 * jnp.dot(z.T, prec_z)

    start = jnp.zeros(D)
    sampler = MCMC(NUTS(potential_fn=potential_fn, dense_mass=True), warmup_steps, num_samples)
    sampler.run(random.PRNGKey(0), init_params=start)
    samples = sampler.get_samples()

    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    # Mean absolute error of the empirical covariance entries.
    cov_error = np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2
    assert cov_error < 0.02
示例5: test_uniform_normal
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_uniform_normal():
    """Run warmup and sampling separately and check sample counts and the mean of 'loc'."""
    true_coef = 0.9
    num_warmup = 1000
    num_samples = 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        # Only the 'loc' site is sampled under the reparam handler.
        reparam_config = {'loc': TransformReparam()}
        with numpyro.handlers.reparam(config=reparam_config):
            loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    mcmc = MCMC(NUTS(model=model), num_warmup=num_warmup, num_samples=num_samples)

    # Samples read right after warmup() are kept because collect_warmup=True.
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    warmup_samples = mcmc.get_samples()

    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()

    assert len(warmup_samples['loc']) == num_warmup
    assert len(samples['loc']) == num_samples
    assert_allclose(jnp.mean(samples['loc'], 0), true_coef, atol=0.05)
示例6: test_improper_normal
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_improper_normal():
    """Fit a model whose 'loc' site uses a masked Uniform base and check the posterior mean."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        reparam_config = {'loc': TransformReparam()}
        with numpyro.handlers.reparam(config=reparam_config):
            # Base Uniform(0, 1) is masked out, then scaled by alpha.
            base = dist.Uniform(0, 1).mask(False)
            scaled = dist.TransformedDistribution(base, AffineTransform(0, alpha))
            loc = numpyro.sample('loc', scaled)
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    mcmc = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    loc_samples = mcmc.get_samples()['loc']

    assert_allclose(jnp.mean(loc_samples, 0), true_coef, atol=0.05)
示例7: test_diverging
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_diverging(kernel_cls, adapt_step_size):
    """Count divergences over warmup and sampling when the kernel starts at step_size=10."""
    num_warmup = 1000
    num_samples = 1000
    data = random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        loc = numpyro.sample('loc', dist.Normal(0., 1.))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)

    kernel = kernel_cls(
        model,
        step_size=10.,
        adapt_step_size=adapt_step_size,
        adapt_mass_matrix=False,
    )
    mcmc = MCMC(kernel, num_warmup, num_samples)

    # Divergences are tallied in both phases and summed.
    mcmc.warmup(random.PRNGKey(1), data, extra_fields=['diverging'], collect_warmup=True)
    warmup_divergences = mcmc.get_extra_fields()['diverging'].sum()
    mcmc.run(random.PRNGKey(2), data, extra_fields=['diverging'])
    run_divergences = mcmc.get_extra_fields()['diverging'].sum()
    num_divergences = warmup_divergences + run_divergences

    if not adapt_step_size:
        # Without adaptation every transition is expected to diverge.
        assert_allclose(num_divergences, num_warmup + num_samples)
    else:
        assert num_divergences <= num_warmup
示例8: test_chain
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_chain(use_init_params, chain_method):
    """Run two NUTS chains on a logistic-regression model and check sample shapes and coefficients."""
    N = 3000
    dim = 3
    num_chains = 2
    num_warmup = 5000
    num_samples = 5000

    # Synthetic data: labels drawn from a Bernoulli whose logits are linear in `data`.
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    mcmc = MCMC(NUTS(model=model), num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method

    if use_init_params:
        # One row of ones per chain.
        ones_per_chain = jnp.tile(jnp.ones(dim), num_chains).reshape(num_chains, dim)
        init_params = {'coefs': ones_per_chain}
    else:
        init_params = None
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)

    flat = mcmc.get_samples()
    assert flat['coefs'].shape[0] == num_chains * num_samples

    grouped = mcmc.get_samples(group_by_chain=True)
    assert grouped['coefs'].shape[:2] == (num_chains, num_samples)

    assert_allclose(jnp.mean(flat['coefs'], 0), true_coefs, atol=0.21)
示例9: test_reuse_mcmc_run
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_reuse_mcmc_run(jit_args, shape):
    """Call ``run`` twice on one MCMC object and check the second run reflects the second dataset."""
    # First dataset is centred at 3, second at -3 with a parametrized length.
    y1 = np.random.normal(3, 0.1, (100,))
    y2 = np.random.normal(-3, 0.1, (shape,))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    mcmc = MCMC(NUTS(model), 300, 500, jit_model_args=jit_args)

    # Initial run on the first dataset.
    mcmc.run(random.PRNGKey(32), y1)
    # Re-run on new data - should be much faster.
    mcmc.run(random.PRNGKey(32), y2)

    mu_mean = mcmc.get_samples()['mu'].mean()
    assert_allclose(mu_mean, -3., atol=0.1)
示例10: test_block_neural_arn
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_block_neural_arn(input_dim, hidden_factors, residual, batch_shape):
    """Check output shapes, log-determinant and Jacobian structure of BlockNeuralAutoregressiveNN."""
    arn_init, arn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors, residual)

    input_shape = batch_shape + (input_dim,)
    out_shape, init_params = arn_init(random.PRNGKey(0), input_shape)
    assert out_shape == input_shape

    x = random.normal(random.PRNGKey(1), input_shape)
    output, logdet = arn(init_params, x)
    assert output.shape == input_shape
    assert logdet.shape == input_shape

    # Jacobian of the first output w.r.t. the input; mapped over the batch
    # axis when there is exactly one batch dimension.
    jac_fn = jacfwd(lambda x: arn(init_params, x)[0])
    if len(batch_shape) == 1:
        jac_fn = vmap(jac_fn)
    jac = jac_fn(x)

    assert_allclose(logdet.sum(-1), jnp.linalg.slogdet(jac)[1], rtol=1e-6)

    # make sure jacobians are lower triangular
    strictly_upper = np.triu(jac, k=1)
    assert np.sum(np.abs(strictly_upper)) == 0.0
    assert np.all(np.abs(matrix_to_tril_vec(jac)) > 0)
示例11: test_predictive_with_improper
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_predictive_with_improper():
    """Fit the masked-Uniform model with NUTS, then check the mean of predicted 'obs' values."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        reparam_config = {'loc': TransformReparam()}
        with handlers.reparam(config=reparam_config):
            # Base Uniform(0, 1) is masked out, then scaled by alpha.
            base = dist.Uniform(0, 1).mask(False)
            scaled = dist.TransformedDistribution(base, AffineTransform(0, alpha))
            loc = numpyro.sample('loc', scaled)
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    mcmc = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()

    # Predict with data=None so 'obs' is not passed as an observation.
    predictive = Predictive(model, samples)
    obs_pred = predictive(random.PRNGKey(1), data=None)["obs"]
    assert_allclose(jnp.mean(obs_pred), true_coef, atol=0.05)
示例12: normal
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def normal(self, *args, **kwargs):
    """Forward all arguments to the "random_normal" entry of the mapping returned by ``backend()``."""
    sampler = backend()["random_normal"]
    return sampler(*args, **kwargs)
示例13: get_batch
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def get_batch(input_size, output_size, batch_size, key):
    """Draw one random batch of normal inputs and one-hot integer targets.

    Returns ``((xs, ys), key)`` where ``key`` is the PRNG key left over after
    two splits, so the caller can keep drawing from it.
    """
    key, subkey = random.split(key)
    # Original note: jax.random will always generate float32 even if
    # jax_enable_x64==True, so the dtype is requested explicitly.
    float_dtype = canonicalize_dtype(onp.float64)
    xs = random.normal(subkey, shape=(batch_size, input_size), dtype=float_dtype)

    key, subkey = random.split(key)
    class_ids = random.randint(subkey, minval=0, maxval=output_size, shape=(batch_size,))
    ys = to_onehot(class_ids, output_size)

    batch = (xs, ys)
    return batch, key
示例14: testTridiagEigenvalues
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def testTridiagEigenvalues(self, shape):
    """Compare extreme eigenvalues and spectral density of the Lanczos tridiagonal matrix with the true ones."""
    onp.random.seed(100)
    sigma_squared = 1e-2
    # if order > matrix shape, lanczos may fail due to linear dependence.
    order = min(70, shape[0])
    atol = 1e-5

    key = random.PRNGKey(0)
    gauss = random.normal(key, shape)
    matrix = np.dot(gauss, gauss.T)  # symmetrize the matrix

    @jit
    def mvp(v):
        # Matrix-vector product handed to the Lanczos routine.
        return np.dot(matrix, v)

    eigs_true, _ = onp.linalg.eigh(matrix)

    @jit
    def get_tridiag(key):
        return lanczos.lanczos_alg(mvp, matrix.shape[0], order, rng_key=key)[0]

    tridiag_matrix = get_tridiag(key)
    eigs_tridiag, _ = onp.linalg.eigh(tridiag_matrix)

    # The true density is evaluated on the grid produced for the Lanczos estimate.
    density, grids = density_lib.eigv_to_density(
        onp.expand_dims(eigs_tridiag, 0), sigma_squared=sigma_squared)
    density_true, _ = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)

    largest_est, largest_true = np.max(eigs_tridiag), np.max(eigs_true)
    smallest_est, smallest_true = np.min(eigs_tridiag), np.min(eigs_true)
    self.assertAlmostEqual(largest_est, largest_true, delta=atol)
    self.assertAlmostEqual(smallest_est, smallest_true, delta=atol)
    self.assertArraysAllClose(density, density_true, True, atol=atol)
示例15: gaussian_sample
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def gaussian_sample(key, mu, sigmasq):
    """Sample a diagonal Gaussian.

    Args:
        key: PRNG key handed to ``random.normal``.
        mu: mean of the Gaussian; an array or a plain Python scalar.
        sigmasq: variance, combined with the noise by broadcasting.

    Returns:
        ``mu + sqrt(sigmasq) * eps`` where ``eps`` is standard-normal noise
        drawn with the shape of ``mu``.
    """
    # np.shape() also works for Python scalars, which have no ``.shape``
    # attribute; for array inputs it is identical to ``mu.shape``.
    noise = random.normal(key, np.shape(mu))
    return mu + np.sqrt(sigmasq) * noise