本文整理汇总了 Python 中 jax.numpy.array 方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.array 方法的具体用法？jax.numpy.array 怎么用？有哪些使用它的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 jax.numpy 的用法示例。
下文一共展示了 jax.numpy.array 方法的 15 个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。
示例1: __init__
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def __init__(self, *dim):
    """
    Build the identity tensor on a dimension.

    ``dim`` is either a single ``Dim`` or the integers to build one from.
    With no arguments, ``Dim()`` is used instead of indexing into the
    empty argument tuple (which used to raise ``IndexError``).

    >>> Id(1)
    Tensor(dom=Dim(1), cod=Dim(1), array=[1])
    >>> list(Id(2).array.flatten())
    [1.0, 0.0, 0.0, 1.0]
    >>> Id(2).array.shape
    (2, 2)
    >>> list(Id(2, 2).array.flatten())[:8]
    [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
    >>> list(Id(2, 2).array.flatten())[8:]
    [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]
    """
    # Check ``dim`` is non-empty before looking at ``dim[0]``.
    dim = dim[0] if dim and isinstance(dim[0], Dim) else Dim(*dim)

    def _extend(acc, n):
        # Outer product (tensordot with axes=0) of what has been built so far
        # with an n x n identity; the 0-d seed is simply replaced.
        return np.tensordot(acc, np.identity(n), 0) if acc.shape else np.identity(n)

    array = functools.reduce(_extend, dim, np.array(1))
    # The outer products pair the axes up as (a_0, b_0, a_1, b_1, ...);
    # move the first axis of every pair (the even axes) to the front.
    n_axes = len(dim)
    array = np.moveaxis(array, [2 * i for i in range(n_axes)], list(range(n_axes)))
    super().__init__(dim, dim, array)
示例2: test_wrapped_policy_continuous
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def test_wrapped_policy_continuous(self, vocab_size):
    """Check output shapes of a serialized policy on a continuous Box space.

    Observations live in a 2-dim Box; actions are MultiDiscrete with
    ``n_controls`` components of ``n_actions`` choices each.
    """
    precision = 3
    n_controls, n_actions = 2, 4
    gin.bind_parameter('BoxSpaceSerializer.precision', precision)

    observations = np.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0.01, 0.66]]])
    actions = np.array([[[0, 1], [2, 0], [1, 3]]])
    model = TestModel(extra_dim=vocab_size)  # pylint: disable=no-value-for-parameter
    policy = serialization_utils.wrap_policy(
        model,
        observation_space=gym.spaces.Box(shape=(2,), low=-2, high=2),
        action_space=gym.spaces.MultiDiscrete([n_actions] * n_controls),
        vocab_size=vocab_size,
    )

    batch = (observations, actions)
    policy.init(shapes.signature(batch))
    act_logits, values = policy(batch)

    # Both outputs share the observations' two leading dimensions.
    leading = observations.shape[:2]
    self.assertEqual(act_logits.shape, leading + (n_controls, n_actions))
    self.assertEqual(values.shape, leading)
示例3: test_scan_parametrized_cell_without_params
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def test_scan_parametrized_cell_without_params():
    """Scanning a cell that declares no parameters yields empty parameters."""

    @parametrized
    def cell(carry, x):
        out = jnp.array([2]) * carry * x
        # Same value serves as the new carry and the per-step output.
        return out, out

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, jnp.zeros((2,)), inputs)
        return outs

    inputs = jnp.zeros((3,))
    params = rnn.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((),), params)
    assert rnn.apply(params, inputs).shape == (3, 2)
示例4: test_scan_parametrized_nonflat_cell
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def test_scan_parametrized_nonflat_cell():
    """A cell with a dict carry and one parameter works under ``lax.scan``."""

    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        out = scale * jnp.array([2]) * carry['a'] * x
        # The dict carry and the per-step output hold the same value.
        return {'a': out}, out

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, {'a': jnp.zeros((2,))}, inputs)
        return outs

    inputs = jnp.zeros((3,))
    rnn_params = rnn.init_parameters(inputs, key=PRNGKey(0))
    assert rnn_params.cell.parameter.shape == (2,)
    outs = rnn.apply(rnn_params, inputs)
    assert outs.shape == (3, 2)
示例5: test_Batched
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def test_Batched():
    """``Batched`` matches a ``vmap`` of the unbatched module's ``apply``."""
    out_dim = 1
    batch_size = 4

    @parametrized
    def unbatched_dense(inputs):
        kernel = parameter((out_dim, inputs.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        return jnp.dot(kernel, inputs) + bias

    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    # array_equal also compares shapes; ``assert array == array`` only relied
    # on the truthiness of a one-element result and accepted e.g. a 0-d output.
    assert jnp.array_equal(jnp.array([3.]), out)

    # Reference: vmap the unbatched apply over the leading axis of the inputs.
    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    vmapped_out = dense_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([out] * batch_size), vmapped_out)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    out_batched = dense.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(vmapped_out, out_batched)
示例6: model
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def model(N, y=None):
    """
    Predator-prey ODE model observed at N equally spaced times.

    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations with shape (N, 2)
    """
    # Initial population of both species, log-normal with median 10.
    z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1), sample_shape=(2,))
    ts = jnp.arange(float(N))  # measurement times 0, 1, ..., N-1

    # ODE parameters alpha, beta, gamma, delta of dz_dt, truncated below at 0.
    theta_loc = jnp.array([0.5, 0.05, 1.5, 0.05])
    theta_scale = jnp.array([0.5, 0.05, 0.5, 0.05])
    theta = numpyro.sample(
        "theta", dist.TruncatedNormal(low=0., loc=theta_loc, scale=theta_scale))

    # Integrated trajectory, shape N x 2.
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)

    # Per-species measurement noise: hare is expected to be noisier than lynx.
    sigma = numpyro.sample("sigma", dist.Exponential(jnp.array([1, 2])))

    # Observed populations are modelled on the log scale.
    log_z = jnp.log(z)
    numpyro.sample("y", dist.Normal(log_z, sigma), obs=y)
示例7: glmm
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def glmm(dept, male, applications, admit=None):
    """Binomial logistic model with per-department intercept and slope.

    :param dept: department index of each row (used to index the offsets).
    :param male: covariate multiplied by the slope term.
    :param applications: number of Binomial trials per row.
    :param admit: observed admissions, or None to record ``probs``.
    """
    # Population-level intercept/slope and the scales of their offsets.
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))
    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    # Scale each row of the LKJ correlation factor by sigma.
    scale_tril = sigma[..., jnp.newaxis] * L_Rho

    # Non-centered parameterization: standard-normal draws mapped through
    # scale_tril give the per-department offsets.
    num_dept = len(jnp.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    v = jnp.dot(scale_tril, z.T).T

    intercept = v_mu[0] + v[dept, 0]
    slope = v_mu[1] + v[dept, 1]
    logits = intercept + slope * male

    if admit is None:
        # A Delta site records the probabilities for predictive draws.
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
示例8: print_results
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def print_results(posterior, dates):
    """Print posterior quantile tables for sigma, nu and the volatility.

    :param dict posterior: samples keyed by site name; reads 'sigma', 'nu'
        and 's' (indexed as ``posterior['s'][:, i]``).
    :param dates: labels indexed alongside the second axis of
        ``posterior['s']``; every 180th entry gets a volatility row.
    """
    # Plain Python floats, formatted with :g, so the labels read "(p60)"
    # rather than float noise (a float32 0.6 * 100 prints as 60.000004).
    quantiles = [0.2, 0.4, 0.5, 0.6, 0.8]
    row_name_fmt = '{:>8}'
    header_format = row_name_fmt + '{:>12}' * len(quantiles)
    row_format = row_name_fmt + '{:>12.3f}' * len(quantiles)
    columns = ['(p{:g})'.format(q * 100) for q in quantiles]
    quantile_array = jnp.array(quantiles)

    def _print_row(values, row_name=''):
        q_values = jnp.quantile(values, quantile_array, axis=0)
        print(header_format.format('', *columns))
        print(row_format.format(row_name, *q_values))
        print('\n')

    print('=' * 20, 'sigma', '=' * 20)
    _print_row(posterior['sigma'])
    print('=' * 20, 'nu', '=' * 20)
    _print_row(posterior['nu'])
    print('=' * 20, 'volatility', '=' * 20)
    for i in range(0, len(dates), 180):
        _print_row(jnp.exp(posterior['s'][:, i]), dates[i])
示例9: stirling_approx_tail
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def stirling_approx_tail(k):
    """Return the tail of Stirling's approximation to log(k!).

    That is ``log(k!) - [(k + 0.5) * log(k + 1) - (k + 1) + 0.5 * log(2 * pi)]``.
    For ``k < 10`` it is looked up from a table. Otherwise it is the truncated
    series ``1/(12 m) - 1/(360 m^3) + 1/(1260 m^5)`` with ``m = k + 1``.

    :param k: integer count (scalar or integer array); used as a table index,
        so it must be an integer type.
    """
    precomputed = jnp.array([
        0.08106146679532726,
        0.04134069595540929,
        0.02767792568499834,
        0.02079067210376509,
        0.01664469118982119,
        0.01387612882307075,
        0.01189670994589177,
        0.01041126526197209,
        0.009255462182712733,
        0.008330563433362871,
    ])
    kp1 = k + 1
    kp1sq = kp1 ** 2
    # jnp.where evaluates both branches, so the table is indexed even where
    # k >= 10. Clip the index so the lookup never goes out of bounds instead
    # of relying on JAX's implicit out-of-bounds gather behaviour.
    table = precomputed[jnp.clip(k, 0, precomputed.shape[0] - 1)]
    series = (1. / 12 - (1. / 360 - (1. / 1260) / kp1sq) / kp1sq) / kp1
    return jnp.where(k < 10, table, series)
示例10: scan_wrapper
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=None):
    """Run ``f`` under ``lax.scan``, recording a trace at every step.

    Each step does three things:
    - Seeds ``f`` with a fresh subkey (when ``rng_key`` is given).
    - Wraps ``f`` in the ``condition``/``substitute`` handlers listed in
      ``substitute_stack``.
    - Records the trace of that call.

    :param f: step function ``(carry, x) -> (carry, y)``.
    :param init: initial carry.
    :param xs: values scanned over. When ``length`` is None, the step count
        is taken from the leading dimension of the first leaf of ``xs``.
    :param length: number of steps, or None.
    :param bool reverse: forwarded to ``lax.scan``.
    :param rng_key: optional PRNG key, split once per step.
    :param substitute_stack: sequence of ``(subs_type, subs_map)`` pairs,
        where ``subs_type`` is ``'condition'`` or ``'substitute'`` (other
        values are ignored). Defaults to no handlers.
    :return: the ``lax.scan`` result. Its final carry is
        ``(step_count, rng_key, carry)``; its stacked outputs are
        ``(PytreeTrace(trace), y)``.
    """
    # None sentinel instead of a shared mutable ``[]`` default.
    if substitute_stack is None:
        substitute_stack = ()

    def body_fn(wrapped_carry, x):
        i, rng_key, carry = wrapped_carry
        # Split off a per-step subkey; the remaining key is carried forward.
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        with handlers.block():
            seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
            # Each handler gets a _subs_wrapper bound to this step index and length.
            for subs_type, subs_map in substitute_stack:
                subs_fn = partial(_subs_wrapper, subs_map, i, length)
                if subs_type == 'condition':
                    seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
                elif subs_type == 'substitute':
                    seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

    if length is None:
        # body_fn reads ``length`` from this scope when it runs, so it sees
        # the inferred value.
        length = tree_flatten(xs)[0][0].shape[0]
    return lax.scan(body_fn, (jnp.array(0), rng_key, init), xs, length=length, reverse=reverse)
示例11: test_unnormalized_normal_x64
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    """Sample N(1, 0.5) from its unnormalized potential; check moments and dtype."""
    true_mean, true_std = 1., 0.5
    # SA gets far longer chains than the other kernels.
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)

    def potential_fn(z):
        # Negative log density of N(true_mean, true_std) up to a constant.
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_params = jnp.array(0.)
    if kernel_cls is SA:
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass)
    # Pass the counts by keyword: newer numpyro releases accept them only as
    # keyword arguments.
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)
    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states.dtype == jnp.float64
示例12: test_beta_bernoulli_x64
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def test_beta_bernoulli_x64(kernel_cls):
    """Recover two Bernoulli probabilities under a Beta(1.1, 1.1) prior."""
    if kernel_cls is SA:
        warmup_steps, num_samples = 100000, 100000
    else:
        warmup_steps, num_samples = 500, 20000

    def model(data):
        prior_conc = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(prior_conc, prior_conc))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    kernel = SA(model=model) if kernel_cls is SA else kernel_cls(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()

    p_samples = mcmc.get_samples()['p_latent']
    assert_allclose(jnp.mean(p_samples, 0), true_probs, atol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert p_samples.dtype == jnp.float64
示例13: test_dirichlet_categorical_x64
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    """Recover categorical probabilities under a flat Dirichlet prior."""
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    kernel = kernel_cls(model, trajectory_length=1., dense_mass=dense_mass)
    # Pass the counts by keyword: newer numpyro releases accept them only as
    # keyword arguments.
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.02)
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == jnp.float64
示例14: test_prior_with_sample_shape
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def test_prior_with_sample_shape():
    """A latent site drawn with ``sample_shape`` keeps that extra dimension."""
    data = dict(
        J=8,
        y=jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
        sigma=jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
    )

    def schools_model():
        mu = numpyro.sample('mu', dist.Normal(0, 5))
        tau = numpyro.sample('tau', dist.HalfCauchy(5))
        # One theta per school, drawn via sample_shape rather than a batched dist.
        theta = numpyro.sample('theta', dist.Normal(mu, tau), sample_shape=(data['J'],))
        numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y'])

    num_samples = 500
    sampler = NUTS(schools_model)
    mcmc = MCMC(sampler, num_warmup=500, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0))
    theta_samples = mcmc.get_samples()['theta']
    assert theta_samples.shape == (num_samples, data['J'])
示例15: test_functional_beta_bernoulli_x64
# Required import: from jax import numpy [as alias]
# or: from jax.numpy import array [as alias]
def test_functional_beta_bernoulli_x64(algo):
    """Functional HMC API (``initialize_model`` + ``hmc``) on a Beta-Bernoulli model."""
    warmup_steps = 500
    num_samples = 20000

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    init_params, potential_fn, constrain_fn, _ = initialize_model(
        random.PRNGKey(2), model, model_args=(data,))
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params, trajectory_length=1., num_warmup=warmup_steps)

    def to_constrained(state):
        # Apply constrain_fn to the position ``z`` of each collected state.
        return constrain_fn(state.z)

    samples = fori_collect(0, num_samples, sample_kernel, hmc_state, transform=to_constrained)
    p_samples = samples['p_latent']
    assert_allclose(jnp.mean(p_samples, 0), true_probs, atol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert p_samples.dtype == jnp.float64