本文整理汇总了Python中jax.numpy.ones方法的典型用法代码示例。如果您正苦于以下问题：jax.numpy.ones方法的具体用法是什么？jax.numpy.ones怎么用？有哪些jax.numpy.ones的使用示例？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的其他用法示例。
在下文中一共展示了jax.numpy.ones方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: BatchNorm
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
    """Build a batch-normalization layer.

    The returned layer subtracts the mean and divides by ``sqrt(var + epsilon)``,
    both computed over ``axis``. It then optionally multiplies by a learned
    'gamma' (when ``scale``) and adds a learned 'beta' (when ``center``). Both
    parameters have the shape of ``x`` with the ``axis`` dimensions removed.

    Args:
        axis: axis or tuple of axes to normalize over; a scalar is wrapped in a tuple.
        epsilon: added to the variance before the square root.
        center: whether to add the learned 'beta' offset.
        scale: whether to multiply by the learned 'gamma' factor.
        beta_init: initializer for 'beta'.
        gamma_init: initializer for 'gamma'.
    """
    if np.isscalar(axis):
        axis = (axis,)

    @parametrized
    def batch_norm(x):
        # Re-insert the reduced axes as size-1 dims so parameters broadcast against x.
        broadcast_index = tuple(None if dim in axis else slice(None)
                                for dim in range(np.ndim(x)))
        mean = np.mean(x, axis, keepdims=True)
        var = fastvar(x, axis, keepdims=True)
        normalized = (x - mean) / np.sqrt(var + epsilon)
        param_shape = tuple(size for dim, size in enumerate(x.shape) if dim not in axis)
        if scale:
            normalized = normalized * parameter(param_shape, gamma_init, 'gamma')[broadcast_index]
        if not center:
            return normalized
        return normalized + parameter(param_shape, beta_init, 'beta')[broadcast_index]

    return batch_norm
示例2: ConvOrConvTranspose
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def ConvOrConvTranspose(out_chan, filter_shape=(3, 3), strides=None, padding='SAME', init_scale=1.,
                        transpose=False):
    """Build a weight-normalized convolution (or transposed convolution) layer.

    The effective kernel is ``g * _l2_normalize(V, (0, 1, 2))`` and the layer
    output is ``conv(inputs, kernel) - b``.

    ``g`` and ``b`` are initialized from a forward pass run with ``g = 1`` and
    ``b = 0``:

    - ``g = init_scale / sqrt(var + 1e-10)``
    - ``b = mean * g``

    Here mean/var are taken over axes (0, 1, 2) of that example output
    (presumably batch and spatial axes of channel-last data — confirm).

    Args:
        out_chan: number of output channels (length of ``g`` and ``b``).
        filter_shape: spatial kernel shape.
        strides: defaults to stride 1 along every spatial dimension.
        padding: padding mode passed to the convolution.
        init_scale: numerator of the initial ``g``.
        transpose: use ``lax.conv_transpose`` instead of ``_conv``.
    """
    strides = strides or (1,) * len(filter_shape)

    def apply(inputs, V, g, b):
        # Weight normalization: direction from normalized V, magnitude from g.
        V = g * _l2_normalize(V, (0, 1, 2))
        return (lax.conv_transpose if transpose else _conv)(inputs, V, strides, padding) - b

    @parametrized
    def conv_or_conv_transpose(inputs):
        V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')
        # Unit-scale, zero-offset output used only to derive initial g and b.
        example_out = apply(inputs, V=V, g=jnp.ones(out_chan), b=jnp.zeros(out_chan))
        # TODO remove need for `.aval.val` when capturing variables in initializer function:
        g = Parameter(lambda key: init_scale /
                      jnp.sqrt(jnp.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
        b = Parameter(lambda key: jnp.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val, 'b')()
        # Keyword arguments keep scale (g) and offset (b) from being swapped;
        # output = g * conv - b = g * (conv - mean) at initialization.
        return apply(inputs, V, g=g, b=b)

    return conv_or_conv_transpose
示例3: test_Regularized
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def test_Regularized():
    """Regularized adds sum(regularizer(param)) over all params to the loss."""
    @parametrized
    def loss(inputs):
        a = parameter((), ones, 'a')
        b = parameter((), lambda key, shape: 2 * jnp.ones(shape), 'b')
        return a + b

    regularized = Regularized(loss, regularizer=lambda x: x * x)
    scalar_input = jnp.zeros(())
    params = regularized.init_parameters(scalar_input, key=PRNGKey(0))

    expected_a, expected_b = jnp.ones(()), 2 * jnp.ones(())
    assert jnp.array_equal(expected_a, params.model.a)
    assert jnp.array_equal(expected_b, params.model.b)

    # Loss a + b = 1 + 2, plus squared penalties 1**2 + 2**2 = 1 + 4.
    total = regularized.apply(params, scalar_input)
    assert 1 + 2 + 1 + 4 == total
示例4: test_L2Regularized
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def test_L2Regularized():
    """L2Regularized with scale=2 adds the expected penalty to the loss."""
    @parametrized
    def loss(inputs):
        a = parameter((), ones, 'a')
        b = parameter((), lambda key, shape: 2 * jnp.ones(shape), 'b')
        return a + b

    regularized = L2Regularized(loss, scale=2)
    scalar_input = jnp.zeros(())
    params = regularized.init_parameters(scalar_input, key=PRNGKey(0))

    expected_a, expected_b = jnp.ones(()), 2 * jnp.ones(())
    assert jnp.array_equal(expected_a, params.model.a)
    assert jnp.array_equal(expected_b, params.model.b)

    # Loss a + b = 1 + 2; the expected total implies penalties of 1 (a=1) and 4 (b=2).
    total = regularized.apply(params, scalar_input)
    assert 1 + 2 + 1 + 4 == total
示例5: test_Batched
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def test_Batched():
    """Batched(layer) must match vmap over the unbatched layer's apply."""
    out_dim = 1

    @parametrized
    def unbatched_dense(x):
        kernel = parameter((out_dim, x.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        return jnp.dot(kernel, x) + bias

    batch_size = 4
    batch_input = jnp.ones((batch_size, 2))

    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    single_out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    # All-ones kernel and bias on an all-ones input of length 2: 1 + 1 + 1.
    assert jnp.array([3.]) == single_out

    # Reference result: map the unbatched apply over the leading axis.
    vmapped_apply = vmap(unbatched_dense.apply, (None, 0))
    reference = vmapped_apply(unbatched_params, batch_input)
    assert jnp.array_equal(jnp.stack([single_out] * batch_size), reference)

    batched = Batched(unbatched_dense)
    params = batched.init_parameters(batch_input, key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    assert jnp.array_equal(reference, batched.apply(params, batch_input))
示例6: model
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def model(X, Y, D_H):
    """Bayesian neural network with two hidden layers of width ``D_H``.

    All weights get unit normal priors and the observation precision gets a
    Gamma(3, 1) prior. ``Y`` is observed under a Normal likelihood centred on
    the network output. ``X`` is indexed as (N, D_X); the output has one column.
    """
    D_X, D_Y = X.shape[1], 1

    def unit_normal(shape):
        return dist.Normal(jnp.zeros(shape), jnp.ones(shape))

    w1 = numpyro.sample("w1", unit_normal((D_X, D_H)))
    hidden1 = nonlin(jnp.matmul(X, w1))        # (N, D_H)
    w2 = numpyro.sample("w2", unit_normal((D_H, D_H)))
    hidden2 = nonlin(jnp.matmul(hidden1, w2))  # (N, D_H)
    w3 = numpyro.sample("w3", unit_normal((D_H, D_Y)))
    output = jnp.matmul(hidden2, w3)           # (N, D_Y)

    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)
    numpyro.sample("Y", dist.Normal(output, sigma_obs), obs=Y)
# helper function for HMC inference
示例7: make_dataset
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Simulate a dataset where customers who get a sales call purchase again
    with probability 0.084, versus 0.061 for those who do not (~2% higher).

    Returns (design_matrix, made_purchase). The design matrix columns are
    [intercept, got_called, is_female]. Rows for called customers come first.
    """
    key_called, key_not_called, key_gender = random.split(rng_key, 3)
    num_calls = 51342
    num_no_calls = 48658
    total = num_calls + num_no_calls

    purchases_called = dist.Bernoulli(0.084).sample(key_called, sample_shape=(num_calls,))
    purchases_not_called = dist.Bernoulli(0.061).sample(key_not_called, sample_shape=(num_no_calls,))
    made_purchase = jnp.concatenate([purchases_called, purchases_not_called])

    is_female = dist.Bernoulli(0.5).sample(key_gender, sample_shape=(total,))
    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])

    columns = [jnp.ones((total, 1)), got_called.reshape(-1, 1), is_female.reshape(-1, 1)]
    return jnp.hstack(columns), made_purchase
示例8: _load_dataset
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def _load_dataset():
    """Load COVTYPE as standardized features plus an intercept and binary labels.

    Features are standardized column-wise and a constant column of ones is
    appended. Labels become a boolean array that is True where the label equals
    the most frequent label value.

    Returns:
        (features, labels)
    """
    _, fetch = load_dataset(COVTYPE, shuffle=False)
    features, labels = fetch()

    # normalize features and add intercept
    features = (features - features.mean(0)) / features.std(0)
    features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])

    # make binary feature. argmax(counts) is a position in `values`, not a label
    # value, so map it back before comparing. The two only coincide when labels
    # are exactly 0..K-1.
    values, counts = jnp.unique(labels, return_counts=True)
    specific_category = values[jnp.argmax(counts)]
    labels = (labels == specific_category)

    N = features.shape[0]
    print("Data shape:", features.shape)
    print("Label distribution: {} has label 1, {} has label 0"
          .format(labels.sum(), N - labels.sum()))
    return features, labels
示例9: _multinomial
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def _multinomial(key, p, n, n_max, shape=()):
    """Draw multinomial counts by sampling ``n_max`` categorical indices and tallying them.

    Args:
        key: PRNG key passed to ``categorical``.
        p: category probabilities. The last axis indexes categories; leading
            axes are batch dimensions broadcast against ``n``.
        n: number of trials, scalar or batched.
        n_max: number of categorical draws made per batch element. When ``n``
            is batched, draws past ``n`` are masked out. NOTE(review):
            presumably an upper bound on ``n`` — confirm with callers.
        shape: batch shape of the result; defaults to ``p.shape[:-1]`` when empty.

    Returns:
        Counts reshaped to ``shape + p.shape[-1:]``.
    """
    # Broadcast n against the batch dims of p (and p to match) when they differ.
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        # True for the first n draws of each batch element. The draw axis is
        # moved to the front to line up with `indices`.
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws become index 0 (indices * 0) and so are tallied into
        # category 0. `excess` holds n_max - n in category 0 and zeros
        # elsewhere, and is subtracted from the final counts.
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        # NOTE(review): with scalar n all n_max draws are counted, so this
        # branch presumably relies on n == n_max — confirm.
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
    # One zero-initialised count row per flattened batch element. Presumably
    # _scatter_add_one adds the given value (here 1) at each index — verify.
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
                                                             dtype=indices.dtype),
                                                   jnp.expand_dims(indices_2D, axis=-1),
                                                   jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
示例10: test_beta_bernoulli
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def test_beta_bernoulli(auto_class):
    """An autoguide fitted by SVI should recover the Beta-Bernoulli posterior mean."""
    # Two columns of 10 observations each: 8/10 and 4/10 successes.
    data = jnp.array([[1.0] * 8 + [0.0] * 2,
                      [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    optimizer = optim.Adam(0.01)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, optimizer, ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def step(_, state):
        new_state, _loss = svi.update(state, data)
        return new_state

    svi_state = fori_loop(0, 3000, step, svi_state)
    params = svi.get_params(svi_state)

    # Beta(1, 1) prior: posterior mean = (successes + 1) / (trials + 2).
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['beta'], 0), true_coefs, atol=0.05)
示例11: test_chain
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def test_chain(use_init_params, chain_method):
    """Multi-chain NUTS on logistic regression: check sample shapes and recovered coefficients."""
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000

    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    mcmc = MCMC(NUTS(model=model), num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method
    if use_init_params:
        # Every chain starts at all-ones coefficients.
        init_params = {'coefs': jnp.ones((num_chains, dim))}
    else:
        init_params = None
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)

    flat = mcmc.get_samples()
    assert flat['coefs'].shape[0] == num_chains * num_samples
    per_chain = mcmc.get_samples(group_by_chain=True)
    assert per_chain['coefs'].shape[:2] == (num_chains, num_samples)
    assert_allclose(jnp.mean(flat['coefs'], 0), true_coefs, atol=0.21)
示例12: test_gaussian_subposterior
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def test_gaussian_subposterior(method, diagonal):
    """Combining Gaussian subposteriors should recover the full-posterior mean and covariance."""
    D = 10
    n_samples, n_draws, n_subs = 10000, 9000, 8

    mean = jnp.arange(D)
    # 1.0 on the diagonal, 0.9 off the diagonal.
    cov = jnp.ones((D, D)) * 0.9 + jnp.identity(D) * 0.1
    # Each subposterior's covariance is n_subs times the full covariance.
    subcov = n_subs * cov
    samples = dist.MultivariateNormal(mean, subcov).sample(random.PRNGKey(1), (n_subs, n_samples))
    subposteriors = list(samples)

    draws = method(subposteriors, n_draws, diagonal=diagonal)
    assert draws.shape == (n_draws, D)
    assert_allclose(jnp.mean(draws, axis=0), mean, atol=0.03)
    if not diagonal:
        assert_allclose(jnp.cov(draws.T), cov, atol=0.05)
    else:
        assert_allclose(jnp.var(draws, axis=0), jnp.diag(cov), atol=0.05)
示例13: test_mask
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def test_mask(mask_last, use_jit):
    """Check that handlers.mask excludes masked sites from the log joint.

    The single element at index ``-mask_last`` is masked out. Both the 'y'
    and the scaled 'obs' terms there must drop out of the log joint.
    """
    N = 10
    # `np.bool` was a deprecated alias of builtin `bool`, removed in NumPy 1.24.
    mask = np.ones(N, dtype=bool)
    mask[-mask_last] = 0

    def model(data, mask):
        with numpyro.plate('N', N):
            x = numpyro.sample('x', dist.Normal(0, 1))
            with handlers.mask(mask_array=mask):
                numpyro.sample('y', dist.Delta(x, log_density=1.))
                with handlers.scale(scale=2):
                    numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (N,))
    x = random.normal(random.PRNGKey(1), (N,))
    if use_jit:
        log_joint = jit(lambda *args: log_density(*args)[0], static_argnums=(0,))(
            model, (data, mask), {}, {'x': x, 'y': x})
    else:
        log_joint = log_density(model, (data, mask), {}, {'x': x, 'y': x})[0]

    log_prob_x = dist.Normal(0, 1).log_prob(x)
    # Only read where mask is True (== 1), matching log_density=1. for 'y'.
    log_prob_y = mask
    log_prob_z = dist.Normal(x, 1).log_prob(data)
    expected = (log_prob_x + jnp.where(mask, log_prob_y + 2 * log_prob_z, 0.)).sum()
    assert_allclose(log_joint, expected, atol=1e-4)
示例14: test_numpyrooptim_no_double_jit
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def test_numpyrooptim_no_double_jit(optim_class, args):
    """A jitted optimizer update should run its Python body only once across three calls."""
    opt = optim_class(*args)
    state = opt.init(jnp.zeros(10))
    trace_count = 0

    @jit
    def update_step(state, g):
        # Python-level counter; the assertion below expects a single execution.
        nonlocal trace_count
        trace_count += 1
        return opt.update(g, state)

    for step_size in (1., 2., 3.):
        state = update_step(state, jnp.ones(10) * step_size)
    assert trace_count == 1
示例15: test_value
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ones [as 别名]
def test_value(x_shape, i_shape, j_shape, event_shape):
    """Compare Vindex advanced indexing against an explicit per-element loop."""
    x = jnp.array(np.random.rand(*(x_shape + (5, 6) + event_shape)))
    i = dist.Categorical(jnp.ones((5,))).sample(random.PRNGKey(1), i_shape)
    j = dist.Categorical(jnp.ones((6,))).sample(random.PRNGKey(2), j_shape)

    indexer = Vindex(x)
    actual = indexer[..., i, j, :] if event_shape else indexer[..., i, j]

    # Broadcast everything to a common batch shape and build the reference by hand.
    batch_shape = lax.broadcast_shapes(x_shape, i_shape, j_shape)
    x = jnp.broadcast_to(x, batch_shape + (5, 6) + event_shape)
    i = jnp.broadcast_to(i, batch_shape)
    j = jnp.broadcast_to(j, batch_shape)

    expected = np.empty(batch_shape + event_shape, dtype=x.dtype)
    # itertools.product() with no ranges yields exactly one empty tuple,
    # which covers the scalar-batch case.
    for pos in itertools.product(*map(range, batch_shape)):
        expected[pos] = x[pos + (i[pos].item(), j[pos].item())]
    assert jnp.all(actual == jnp.array(expected, dtype=x.dtype))