当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.array方法代码示例

本文整理汇总了Python中jax.numpy.array方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.array方法的具体用法是什么？jax.numpy.array怎么用？有哪些jax.numpy.array的使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。


下文共展示了jax.numpy.array方法的15个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。

示例1: __init__

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def __init__(self, *dim):
        """
        Build the identity tensor on a dimension.

        ``dim`` is either a single ``Dim`` or the integers to build one from.
        The array is the tensor (outer) product of one identity matrix per
        factor of the dimension; afterwards the even-indexed axes (the first
        axis of each identity factor) are moved to the front. The same
        dimension is used as both domain and codomain.

        >>> Id(1)
        Tensor(dom=Dim(1), cod=Dim(1), array=[1])
        >>> list(Id(2).array.flatten())
        [1.0, 0.0, 0.0, 1.0]
        >>> Id(2).array.shape
        (2, 2)
        >>> list(Id(2, 2).array.flatten())[:8]
        [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
        >>> list(Id(2, 2).array.flatten())[8:]
        [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]
        """
        head = dim[0]
        if isinstance(head, Dim):
            dim = head
        else:
            dim = Dim(*dim)
        # Start from a 0-d scalar; the first factor replaces it outright and
        # later factors are appended with an outer (axes=0) tensordot.
        array = np.array(1)
        for factor in dim:
            eye = np.identity(factor)
            array = np.tensordot(array, eye, 0) if array.shape else eye
        n_factors = len(dim)
        even_axes = list(range(0, 2 * n_factors, 2))
        array = np.moveaxis(array, even_axes, list(range(n_factors)))
        super().__init__(dim, dim, array)
开发者ID:oxford-quantum-group,项目名称:discopy,代码行数:22,代码来源:tensor.py

示例2: test_wrapped_policy_continuous

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def test_wrapped_policy_continuous(self, vocab_size):
    """Check output shapes of a policy wrapped for Box obs / MultiDiscrete actions.

    Action logits must have shape ``obs.shape[:2] + (n_controls, n_actions)``
    and values must have shape ``obs.shape[:2]``.
    """
    precision = 3
    n_controls = 2
    n_actions = 4
    gin.bind_parameter('BoxSpaceSerializer.precision', precision)

    observations = np.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0.01, 0.66]]])
    actions = np.array([[[0, 1], [2, 0], [1, 3]]])

    base_model = TestModel(extra_dim=vocab_size)  # pylint: disable=no-value-for-parameter
    obs_space = gym.spaces.Box(shape=(2,), low=-2, high=2)
    act_space = gym.spaces.MultiDiscrete([n_actions] * n_controls)
    policy = serialization_utils.wrap_policy(
        base_model,
        observation_space=obs_space,
        action_space=act_space,
        vocab_size=vocab_size,
    )

    batch = (observations, actions)
    policy.init(shapes.signature(batch))
    act_logits, values = policy(batch)
    leading_dims = observations.shape[:2]
    self.assertEqual(act_logits.shape, leading_dims + (n_controls, n_actions))
    self.assertEqual(values.shape, leading_dims)
开发者ID:google,项目名称:trax,代码行数:23,代码来源:serialization_utils_test.py

示例3: test_scan_parametrized_cell_without_params

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def test_scan_parametrized_cell_without_params():
    """Scanning a parameter-free cell inside a parametrized RNN yields no params."""

    @parametrized
    def cell(carry, x):
        # The same value serves as the next carry and as the per-step output.
        step = jnp.array([2]) * carry * x
        return step, step

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, jnp.zeros((2,)), inputs)
        return outs

    inputs = jnp.zeros((3,))

    params = rnn.init_parameters(inputs, key=PRNGKey(0))
    # Expected structure: a single, empty entry for the parameter-free cell.
    assert_parameters_equal(((),), params)

    outs = rnn.apply(params, inputs)
    # One (2,)-shaped output per scanned element of ``inputs``.
    assert outs.shape == (3, 2)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_core.py

示例4: test_scan_parametrized_nonflat_cell

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def test_scan_parametrized_nonflat_cell():
    """Scan a cell whose carry is a dict and which owns a (2,)-shaped parameter."""

    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        step = scale * jnp.array([2]) * carry['a'] * x
        # New carry keeps the dict structure; the output is the bare array.
        return {'a': step}, step

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, {'a': jnp.zeros((2,))}, inputs)
        return outs

    inputs = jnp.zeros((3,))

    rnn_params = rnn.init_parameters(inputs, key=PRNGKey(0))
    assert rnn_params.cell.parameter.shape == (2,)

    outs = rnn.apply(rnn_params, inputs)
    assert outs.shape == (3, 2)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:21,代码来源:test_core.py

示例5: test_Batched

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def test_Batched():
    """``Batched(module)`` must agree with ``vmap`` over the unbatched apply.

    Checks that the unbatched dense layer gives the expected value, that
    vmapping its ``apply`` over a batch replicates that value per row, and that
    ``Batched`` produces the same parameters and outputs as the vmapped version.
    """
    out_dim = 1

    @parametrized
    def unbatched_dense(x):
        # Renamed from ``input`` to avoid shadowing the builtin.
        kernel = parameter((out_dim, x.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        return jnp.dot(kernel, x) + bias

    batch_size = 4

    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    # Ones-initialised kernel (1, 2) and bias (1,) on ones(2): 1 + 1 + 1 = 3.
    # array_equal also pins the (1,) shape instead of relying on bool() of a
    # one-element comparison array; it matches the checks below.
    assert jnp.array_equal(jnp.array([3.]), out)

    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    out_batched = dense.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(out_batched_, out_batched)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:26,代码来源:test_modules.py

示例6: model

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def model(N, y=None):
    """
    Predator-prey (hare/lynx) ODE model with log-normal observations.

    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations with shape (N, 2)
    """
    # Two initial population values, each log-normal around log(10).
    z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1), sample_shape=(2,))
    # Measurement times 0.0, 1.0, ..., N - 1.
    ts = jnp.arange(float(N))
    # Priors for the alpha, beta, gamma, delta parameters of dz_dt.
    theta_loc = jnp.array([0.5, 0.05, 1.5, 0.05])
    theta_scale = jnp.array([0.5, 0.05, 0.5, 0.05])
    theta = numpyro.sample(
        "theta", dist.TruncatedNormal(low=0., loc=theta_loc, scale=theta_scale))
    # Integrate dz/dt over ``ts``; the solution has one row per time point.
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)
    # Per-series observation noise; the authors expect the hare series to be
    # noisier than the lynx series.
    noise_rate = jnp.array([1, 2])
    sigma = numpyro.sample("sigma", dist.Exponential(noise_rate))
    # Observations are modelled on the log scale.
    numpyro.sample("y", dist.Normal(jnp.log(z), sigma), obs=y)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:22,代码来源:ode.py

示例7: glmm

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def glmm(dept, male, applications, admit=None):
    """Binomial GLMM with a per-department intercept and ``male`` slope.

    Department effects are drawn with a non-centered parameterization: standard
    normal offsets ``z`` are scaled by a Cholesky factor built from ``sigma``
    and an LKJ correlation factor. When ``admit`` is None, the admission
    probabilities are also recorded as a ``probs`` Delta site.
    """
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))

    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    scale_tril = sigma[..., jnp.newaxis] * L_Rho

    # Non-centered parameterization: unit-normal offsets, then scale/correlate.
    num_dept = len(jnp.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    v = (scale_tril @ z.T).T

    intercept = v_mu[0] + v[dept, 0]
    slope = v_mu[1] + v[dept, 1]
    logits = intercept + slope * male

    if admit is None:
        # Delta site so predictive runs expose the probabilities.
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:19,代码来源:ucbadmit.py

示例8: print_results

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def print_results(posterior, dates):
    """Print 20/40/50/60/80% posterior quantiles for sigma, nu and volatility.

    :param dict posterior: posterior samples with keys ``'sigma'``, ``'nu'``
        and ``'s'``; samples run along axis 0.
    :param dates: row labels indexed in step with the columns of
        ``posterior['s']``; every 180th one is summarised, using
        ``jnp.exp`` of the corresponding column.
    """
    # Keep the levels as Python floats for the labels: iterating a float32 jnp
    # array made 0.6 * 100 format as "60.000004" in the header.
    quantile_levels = (0.2, 0.4, 0.5, 0.6, 0.8)
    quantiles = jnp.array(quantile_levels)
    n_cols = len(quantile_levels)
    row_name_fmt = '{:>8}'
    header_format = row_name_fmt + '{:>12}' * n_cols
    row_format = row_name_fmt + '{:>12.3f}' * n_cols
    columns = ['(p{})'.format(q * 100) for q in quantile_levels]

    def _print_row(values, row_name=''):
        # Quantiles over the sample axis.
        q_values = jnp.quantile(values, quantiles, axis=0)
        print(header_format.format('', *columns))
        print(row_format.format(row_name, *q_values))
        print('\n')

    print('=' * 20, 'sigma', '=' * 20)
    _print_row(posterior['sigma'])
    print('=' * 20, 'nu', '=' * 20)
    _print_row(posterior['nu'])
    print('=' * 20, 'volatility', '=' * 20)
    for i in range(0, len(dates), 180):
        _print_row(jnp.exp(posterior['s'][:, i]), dates[i])
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:stochastic_volatility.py

示例9: stirling_approx_tail

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def stirling_approx_tail(k):
    """Return the Stirling-series remainder for ``log(k!)``.

    The value is ``log(k!) - [(k + 1/2) log(k + 1) - (k + 1) + log(sqrt(2 pi))]``.
    For ``k < 10`` it is read from a precomputed table. Otherwise it uses the
    series ``1/(12(k+1)) - 1/(360(k+1)^3) + 1/(1260(k+1)^5)``.

    :param k: non-negative integer count(s), scalar or array; integer-valued
        floats are also accepted.
    :return: array broadcast to the shape of ``k``.
    """
    precomputed = jnp.array([
        0.08106146679532726,
        0.04134069595540929,
        0.02767792568499834,
        0.02079067210376509,
        0.01664469118982119,
        0.01387612882307075,
        0.01189670994589177,
        0.01041126526197209,
        0.009255462182712733,
        0.008330563433362871,
    ])
    k = jnp.asarray(k)
    # Clamp the lookup so k >= 10 never depends on out-of-bounds gather
    # behaviour (that branch is discarded by ``where`` anyway). Cast so that
    # integer-valued floats can index the table.
    table_index = jnp.minimum(k, precomputed.shape[0] - 1).astype(jnp.int32)
    # Do the series in floating point: with int32 k, (k + 1) ** 2 overflows for
    # k >= 46340 and wraps to 0 at k = 65535, which turned the result into inf.
    kp1 = k + 1.
    kp1sq = kp1 * kp1
    return jnp.where(k < 10,
                     precomputed[table_index],
                     (1. / 12 - (1. / 360 - (1. / 1260) / kp1sq) / kp1sq) / kp1)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:20,代码来源:util.py

示例10: scan_wrapper

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=None):
    """Run ``f`` with ``lax.scan``, recording a trace of every step.

    Each step calls ``f(carry, x)`` under ``handlers.block()`` and
    ``handlers.trace()``. It returns the usual scan results, where the stacked
    per-step outputs are ``(PytreeTrace(trace), y)``.

    :param f: step function ``(carry, x) -> (carry, y)``.
    :param init: initial carry.
    :param xs: pytree of values to scan over.
    :param length: number of steps; if None, the leading dimension of the first
        leaf of ``xs`` is used.
    :param reverse: passed through to ``lax.scan``.
    :param rng_key: if not None, it is split on every step and ``f`` is seeded
        with the fresh subkey.
    :param substitute_stack: sequence of ``(subs_type, subs_map)`` pairs.
        ``'condition'`` and ``'substitute'`` wrap ``f`` in the matching handler;
        any other type is ignored. Defaults to an empty sequence.
    :return: ``((length, rng_key, carry), (traces, ys))`` as produced by
        ``lax.scan``.
    """
    # None instead of a shared mutable ``[]`` default.
    if substitute_stack is None:
        substitute_stack = []

    def body_fn(wrapped_carry, x):
        # The step counter ``i`` and the RNG key ride along with the user carry.
        i, rng_key, carry = wrapped_carry
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        with handlers.block():
            seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
            for subs_type, subs_map in substitute_stack:
                # NOTE(review): _subs_wrapper presumably picks the step-``i``
                # slice of each substituted value; confirm against its definition.
                subs_fn = partial(_subs_wrapper, subs_map, i, length)
                if subs_type == 'condition':
                    seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
                elif subs_type == 'substitute':
                    seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

    if length is None:
        # ``length`` is needed by the substitution wrappers, so infer it here.
        length = tree_flatten(xs)[0][0].shape[0]
    return lax.scan(body_fn, (jnp.array(0), rng_key, init), xs, length=length, reverse=reverse)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:25,代码来源:scan.py

示例11: test_unnormalized_normal_x64

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    """Sample N(1, 0.5) from its unnormalized potential and check the moments.

    When ``JAX_ENABLE_X64`` is set in the environment, the samples must be float64.
    """
    true_mean, true_std = 1., 0.5

    def potential_fn(z):
        standardized = (z - true_mean) / true_std
        return 0.5 * jnp.sum(standardized ** 2)

    # SA needs far more steps than the gradient-based kernels.
    if kernel_cls is SA:
        warmup_steps, num_samples = 100000, 100000
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        warmup_steps, num_samples = 1000, 8000
        kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8,
                            dense_mass=dense_mass)

    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=jnp.array(0.))
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, rtol=0.07)
    assert_allclose(jnp.std(samples), true_std, rtol=0.07)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples.dtype == jnp.float64
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:23,代码来源:test_mcmc.py

示例12: test_beta_bernoulli_x64

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def test_beta_bernoulli_x64(kernel_cls):
    """Recover two Bernoulli probabilities under a Beta(1.1, 1.1) prior.

    When ``JAX_ENABLE_X64`` is set in the environment, the samples must be float64.
    """
    if kernel_cls is SA:
        warmup_steps, num_samples = 100000, 100000
    else:
        warmup_steps, num_samples = 500, 20000

    def model(data):
        prior_a = jnp.array([1.1, 1.1])
        prior_b = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(prior_a, prior_b))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    kernel = (SA(model=model) if kernel_cls is SA
              else kernel_cls(model=model, trajectory_length=0.1))
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    p_samples = mcmc.get_samples()['p_latent']
    assert_allclose(jnp.mean(p_samples, 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert p_samples.dtype == jnp.float64
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:26,代码来源:test_mcmc.py

示例13: test_dirichlet_categorical_x64

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    """Recover categorical probabilities under a flat Dirichlet prior.

    When ``JAX_ENABLE_X64`` is set in the environment, the samples must be float64.
    """
    warmup_steps = 100
    num_samples = 20000

    def model(data):
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(jnp.array([1.0, 1.0, 1.0])))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    mcmc = MCMC(kernel_cls(model, trajectory_length=1., dense_mass=dense_mass),
                warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    p_samples = mcmc.get_samples()['p_latent']
    assert_allclose(jnp.mean(p_samples, 0), true_probs, atol=0.02)

    if 'JAX_ENABLE_X64' in os.environ:
        assert p_samples.dtype == jnp.float64
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:test_mcmc.py

示例14: test_prior_with_sample_shape

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def test_prior_with_sample_shape():
    """A ``sample_shape`` on a latent site shows up in the posterior sample shape."""
    data = dict(
        J=8,
        y=jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
        sigma=jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
    )

    def schools_model():
        mu = numpyro.sample('mu', dist.Normal(0, 5))
        tau = numpyro.sample('tau', dist.HalfCauchy(5))
        # One theta per school, via sample_shape rather than a batched prior.
        theta = numpyro.sample('theta', dist.Normal(mu, tau), sample_shape=(data['J'],))
        numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y'])

    num_samples = 500
    mcmc = MCMC(NUTS(schools_model), num_warmup=500, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0))
    theta_samples = mcmc.get_samples()['theta']
    assert theta_samples.shape == (num_samples, data['J'])
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:19,代码来源:test_mcmc.py

示例15: test_functional_beta_bernoulli_x64

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import array [as 别名]
def test_functional_beta_bernoulli_x64(algo):
    """Recover Bernoulli probabilities through the functional ``hmc`` API.

    When ``JAX_ENABLE_X64`` is set in the environment, the samples must be float64.
    """
    warmup_steps, num_samples = 500, 20000

    def model(data):
        prior_a = jnp.array([1.1, 1.1])
        prior_b = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(prior_a, prior_b))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    init_params, potential_fn, constrain_fn, _ = initialize_model(
        random.PRNGKey(2), model, model_args=(data,))
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params, trajectory_length=1., num_warmup=warmup_steps)

    def to_constrained(state):
        # Map the unconstrained position back to the model's support.
        return constrain_fn(state.z)

    samples = fori_collect(0, num_samples, sample_kernel, hmc_state,
                           transform=to_constrained)
    p_samples = samples['p_latent']
    assert_allclose(jnp.mean(p_samples, 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert p_samples.dtype == jnp.float64
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:25,代码来源:test_mcmc.py


注:本文中的jax.numpy.array方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。