当前位置: 首页>>代码示例>>Python>>正文


Python random.normal方法代码示例

本文整理汇总了Python中jax.random.normal方法的典型用法代码示例。如果您正苦于以下问题:Python random.normal方法的具体用法?Python random.normal怎么用?Python random.normal使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.random的用法示例。


在下文中一共展示了random.normal方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_Parameter_dense

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_Parameter_dense():
    """Build a dense layer out of two `parameter` calls and check the resulting shapes."""

    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        """Return a parametrized affine map with `out_dim` output features."""

        @parametrized
        def dense(inputs):
            in_dim = inputs.shape[-1]
            # The kernel is requested before the bias; the assertions below
            # expect them as `parameter0` and `parameter1` respectively.
            kernel = parameter((in_dim, out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias

        return dense

    layer = Dense(2)
    batch = jnp.zeros((1, 3))
    params = layer.init_parameters(batch, key=PRNGKey(0))
    assert params.parameter0.shape == (3, 2)
    assert params.parameter1.shape == (2,)

    result = layer.apply(params, batch, jit=True)
    assert result.shape == (1, 2)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_examples.py

示例2: _onion

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def _onion(self, key, size):
        """Draw lower-triangular factors whose rows have unit Euclidean norm.

        The PRNG key is split in two: one half drives a draw from ``self._beta``,
        the other a standard-normal draw that is normalised along its last axis
        and scaled by the square root of the beta draw.

        Args:
            key: PRNG key, split internally into a beta key and a normal key.
            size: tuple prepended to ``self.batch_shape`` when building shapes.

        Returns:
            Array built on ``jnp.zeros(size + self.batch_shape + self.event_shape)``;
            presumably ``event_shape == (dimension, dimension)`` -- TODO confirm
            against the enclosing class.
        """
        key_beta, key_normal = random.split(key)
        # Now we generate w term in Algorithm 3.2 of [1].
        # NOTE(review): reference [1] is not visible in this snippet.
        beta_sample = self._beta.sample(key_beta, size)
        # The following Normal distribution is used to create a uniform distribution on
        # a hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
        normal_sample = random.normal(
            key_normal,
            shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,)
        )
        # presumably packs the flat vector into a lower-triangular matrix that
        # includes its diagonal (diagonal=0) -- verify against vec_to_tril_matrix
        normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
        # normalise along the last axis so each row has unit norm
        u_hypershere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
        # scale every row by sqrt(beta); expand_dims lets it broadcast over the last axis
        w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypershere

        # put w into the off-diagonal triangular part
        # (rows 1.., columns ..-1 of an all-zero array)
        cholesky = ops.index_add(jnp.zeros(size + self.batch_shape + self.event_shape),
                                 ops.index[..., 1:, :-1], w)
        # correct the diagonal
        # NB: we clip due to numerical precision
        # diag is sqrt(1 - squared row norm), so adding it makes every row unit norm
        diag = jnp.sqrt(jnp.clip(1 - jnp.sum(cholesky ** 2, axis=-1), a_min=0.))
        cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
        return cholesky
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:continuous.py

示例3: test_laplace_approximation_warning

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_laplace_approximation_warning():
    """Fit a Laplace guide to a cubic regression and expect the Hessian warning on sampling."""

    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        # Cubic polynomial in x with intercept a and coefficients b.
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    # Only three data points, generated exactly from the cubic 1 + 2x + 3x^2 + 4x^3.
    seed = random.PRNGKey(0)
    x = random.normal(seed, (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3

    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), ELBO(), x=x, y=y)

    def step(_, state):
        # svi.update returns a tuple; only its first element is carried forward.
        return svi.update(state)[0]

    final_state = fori_loop(0, 10000, step, svi.init(seed))
    params = svi.get_params(final_state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:18,代码来源:test_autoguide.py

示例4: test_correlated_mvn

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_correlated_mvn():
    """Sample a correlated zero-mean 5-d Gaussian with dense-mass NUTS; check mean and covariance."""
    # This requires dense mass matrix estimation.
    D = 5
    warmup_steps, num_samples = 5000, 8000
    true_mean = 0.

    # Lower-triangular factor: anti-diagonal part plus exponentiated Gaussian noise.
    noise = random.normal(random.PRNGKey(0), shape=(D, D))
    anti_diagonal = 0.5 * jnp.fliplr(jnp.eye(D))
    a = jnp.tril(anti_diagonal + 0.1 * jnp.exp(noise))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # Quadratic form 0.5 * z^T P z with P the true precision matrix.
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=jnp.zeros(D))
    samples = mcmc.get_samples()

    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    # Mean absolute entry-wise error of the empirical covariance.
    cov_error = np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2
    assert cov_error < 0.02
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:23,代码来源:test_mcmc.py

示例5: test_uniform_normal

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_uniform_normal():
    """Run NUTS warmup and sampling on a reparametrized Uniform 'loc' and check sample counts and mean."""
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    # 1000 observations: standard-normal noise shifted by true_coef.
    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    mcmc = MCMC(NUTS(model=model), num_warmup=num_warmup, num_samples=num_samples)

    # Warmup draws are collected so they can be counted separately from the main run.
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    warmup_samples = mcmc.get_samples()
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()

    assert len(warmup_samples['loc']) == num_warmup
    assert len(samples['loc']) == num_samples
    posterior_mean = jnp.mean(samples['loc'], 0)
    assert_allclose(posterior_mean, true_coef, atol=0.05)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:22,代码来源:test_mcmc.py

示例6: test_improper_normal

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_improper_normal():
    """Run NUTS with 'loc' drawn from a masked Uniform scaled by alpha; check its posterior mean."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            # Uniform(0, 1) with mask(False), pushed through an affine map of scale alpha.
            base = dist.Uniform(0, 1).mask(False)
            loc = numpyro.sample('loc', dist.TransformedDistribution(
                base, AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    # 1000 observations: standard-normal noise shifted by true_coef.
    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    mcmc = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    loc_samples = mcmc.get_samples()['loc']
    assert_allclose(jnp.mean(loc_samples, 0), true_coef, atol=0.05)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:19,代码来源:test_mcmc.py

示例7: test_diverging

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_diverging(kernel_cls, adapt_step_size):
    """Count divergent transitions when the kernel starts with step_size=10.

    With step-size adaptation the total number of divergences must not exceed
    the warmup length; without it, every warmup and sampling transition is
    expected to be flagged as diverging.
    """

    def model(data):
        loc = numpyro.sample('loc', dist.Normal(0., 1.))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (1000,))
    num_warmup = num_samples = 1000

    kernel = kernel_cls(model, step_size=10., adapt_step_size=adapt_step_size, adapt_mass_matrix=False)
    mcmc = MCMC(kernel, num_warmup, num_samples)

    # Divergences are tallied over the warmup phase and the sampling phase.
    mcmc.warmup(random.PRNGKey(1), data, extra_fields=['diverging'], collect_warmup=True)
    warmup_divergences = mcmc.get_extra_fields()['diverging'].sum()
    mcmc.run(random.PRNGKey(2), data, extra_fields=['diverging'])
    run_divergences = mcmc.get_extra_fields()['diverging'].sum()
    num_divergences = warmup_divergences + run_divergences

    if not adapt_step_size:
        assert_allclose(num_divergences, num_warmup + num_samples)
    else:
        assert num_divergences <= num_warmup
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:20,代码来源:test_mcmc.py

示例8: test_chain

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_chain(use_init_params, chain_method):
    """Run two NUTS chains on a logistic regression and check sample shapes and coefficient means."""
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000

    # Synthetic data: Gaussian features, coefficients 1..dim, Bernoulli labels.
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    mcmc = MCMC(NUTS(model=model), num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method

    if use_init_params:
        # One all-ones starting vector per chain.
        init_params = {'coefs': jnp.ones((num_chains, dim))}
    else:
        init_params = None
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)

    samples_flat = mcmc.get_samples()
    assert samples_flat['coefs'].shape[0] == num_chains * num_samples
    samples = mcmc.get_samples(group_by_chain=True)
    assert samples['coefs'].shape[:2] == (num_chains, num_samples)
    assert_allclose(jnp.mean(samples_flat['coefs'], 0), true_coefs, atol=0.21)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:27,代码来源:test_mcmc.py

示例9: test_reuse_mcmc_run

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_reuse_mcmc_run(jit_args, shape):
    """Run one MCMC object on two datasets in turn and check the second posterior mean of 'mu'."""

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    # Two datasets from NumPy's global RNG: 100 draws around 3, then `shape` draws around -3.
    first_obs = np.random.normal(3, 0.1, (100,))
    second_obs = np.random.normal(-3, 0.1, (shape,))

    mcmc = MCMC(NUTS(model), 300, 500, jit_model_args=jit_args)

    # First run, on the dataset centred at 3.
    mcmc.run(random.PRNGKey(32), first_obs)

    # Second run on the new data with the same MCMC object and the same key.
    # NOTE(review): the original note expects this re-run to be much faster;
    # nothing in this test measures that.
    mcmc.run(random.PRNGKey(32), second_obs)
    mu_mean = mcmc.get_samples()['mu'].mean()
    assert_allclose(mu_mean, -3., atol=0.1)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:19,代码来源:test_mcmc.py

示例10: test_block_neural_arn

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_block_neural_arn(input_dim, hidden_factors, residual, batch_shape):
    """Check output shapes, log-determinant and Jacobian structure of BlockNeuralAutoregressiveNN."""
    arn_init, arn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors, residual)

    input_shape = batch_shape + (input_dim,)
    out_shape, init_params = arn_init(random.PRNGKey(0), input_shape)
    assert out_shape == input_shape

    x = random.normal(random.PRNGKey(1), input_shape)
    output, logdet = arn(init_params, x)
    assert output.shape == input_shape
    assert logdet.shape == input_shape

    def transform(inp):
        # Only the transformed output (first element) is differentiated.
        return arn(init_params, inp)[0]

    jacobian_fn = jacfwd(transform)
    if len(batch_shape) == 1:
        # With a single batch axis, take the Jacobian per batch element.
        jacobian_fn = vmap(jacobian_fn)
    jac = jacobian_fn(x)
    assert_allclose(logdet.sum(-1), jnp.linalg.slogdet(jac)[1], rtol=1e-6)

    # make sure jacobians are lower triangular
    strictly_upper = np.triu(jac, k=1)
    assert np.sum(np.abs(strictly_upper)) == 0.0
    # and that every lower-triangular entry is non-zero
    assert np.all(np.abs(matrix_to_tril_vec(jac)) > 0)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:test_nn.py

示例11: test_predictive_with_improper

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def test_predictive_with_improper():
    """Fit a model with a masked-Uniform 'loc' by NUTS, then check the mean of predictive 'obs' draws."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with handlers.reparam(config={'loc': TransformReparam()}):
            # Uniform(0, 1) with mask(False), pushed through an affine map of scale alpha.
            base = dist.Uniform(0, 1).mask(False)
            loc = numpyro.sample('loc', dist.TransformedDistribution(
                base, AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    # 1000 observations: standard-normal noise shifted by true_coef.
    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    mcmc = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()

    # data=None, so 'obs' is not conditioned on the observations here.
    predictive = Predictive(model, samples)
    obs_pred = predictive(random.PRNGKey(1), data=None)["obs"]
    assert_allclose(jnp.mean(obs_pred), true_coef, atol=0.05)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:20,代码来源:test_infer_util.py

示例12: normal

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def normal(self, *args, **kwargs):
    """Forward all arguments to the current backend's "random_normal" entry and return its result."""
    random_normal = backend()["random_normal"]
    return random_normal(*args, **kwargs)
开发者ID:yyht,项目名称:BERT,代码行数:4,代码来源:backend.py

示例13: get_batch

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def get_batch(input_size, output_size, batch_size, key):
  """Draw a random batch of Gaussian inputs and one-hot integer labels.

  Args:
    input_size: number of features per input row.
    output_size: exclusive upper bound for the labels and the one-hot width.
    batch_size: number of rows in the batch.
    key: PRNG key; it is split twice and the remaining key is returned.

  Returns:
    A pair `((xs, ys), key)` where `key` is the advanced PRNG key.
  """
  key, x_key = random.split(key)
  # Ask for the canonicalized float64 dtype explicitly; per the original
  # note, jax.random otherwise yields float32 even with jax_enable_x64.
  float_dtype = canonicalize_dtype(onp.float64)
  xs = random.normal(x_key, shape=(batch_size, input_size), dtype=float_dtype)

  key, y_key = random.split(key)
  labels = random.randint(y_key, minval=0, maxval=output_size, shape=(batch_size,))
  return (xs, to_onehot(labels, output_size)), key
开发者ID:google,项目名称:spectral-density,代码行数:11,代码来源:spectral_density_test.py

示例14: testTridiagEigenvalues

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def testTridiagEigenvalues(self, shape):
    """Compare Lanczos tridiagonal eigenvalues and density with a dense eigendecomposition."""
    onp.random.seed(100)
    sigma_squared = 1e-2
    atol = 1e-5

    # if order > matrix shape, lanczos may fail due to linear dependence.
    order = min(70, shape[0])

    key = random.PRNGKey(0)
    gaussian = random.normal(key, shape)
    matrix = np.dot(gaussian, gaussian.T)  # symmetrize the matrix

    def matvec(v):
      return np.dot(matrix, v)

    mvp = jit(matvec)

    # Reference spectrum from a dense eigendecomposition.
    eigs_true = onp.linalg.eigh(matrix)[0]

    def run_lanczos(rng):
      # Only the first element of the Lanczos result is used.
      return lanczos.lanczos_alg(mvp, matrix.shape[0], order, rng_key=rng)[0]

    tridiag_matrix = jit(run_lanczos)(key)
    eigs_tridiag = onp.linalg.eigh(tridiag_matrix)[0]

    # Both densities are evaluated on the grid returned by the first call.
    density, grids = density_lib.eigv_to_density(
        onp.expand_dims(eigs_tridiag, 0), sigma_squared=sigma_squared)
    density_true, _ = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)

    self.assertAlmostEqual(np.max(eigs_tridiag), np.max(eigs_true), delta=atol)
    self.assertAlmostEqual(np.min(eigs_tridiag), np.min(eigs_true), delta=atol)
    self.assertArraysAllClose(density, density_true, True, atol=atol)
开发者ID:google,项目名称:spectral-density,代码行数:32,代码来源:lanczos_test.py

示例15: gaussian_sample

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import normal [as 别名]
def gaussian_sample(key, mu, sigmasq):
    """Draw one sample from a diagonal Gaussian.

    Args:
        key: PRNG key passed to `random.normal`.
        mu: mean array; its shape determines the shape of the noise drawn.
        sigmasq: variance, combined element-wise with the noise after a square root.

    Returns:
        `mu` plus standard-normal noise scaled by `sqrt(sigmasq)`.
    """
    std = np.sqrt(sigmasq)
    noise = random.normal(key, mu.shape)
    return mu + std * noise
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:5,代码来源:mnist_vae.py


注:本文中的jax.random.normal方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。