本文整理汇总了Python中jax.numpy.mean方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.mean方法的具体用法？Python jax.numpy.mean怎么用？Python jax.numpy.mean使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。
在下文中一共展示了jax.numpy.mean方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: BatchNorm
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
    """Layer construction function for a batch normalization layer.

    Args:
        axis: int or sequence of ints; axes over which the mean and variance
            are reduced. Negative values count from the last axis of ``x``.
        epsilon: constant added to the variance before taking the square root.
        center: if True, add a learned parameter named 'beta'.
        scale: if True, multiply by a learned parameter named 'gamma'.
        beta_init, gamma_init: initializers handed to ``parameter`` for the
            'beta' and 'gamma' parameters.

    Returns:
        A ``parametrized`` function ``batch_norm(x)``.
    """
    axis = (axis,) if np.isscalar(axis) else tuple(axis)

    @parametrized
    def batch_norm(x):
        ndim = np.ndim(x)
        # Map negative axes to their non-negative index so the membership tests
        # below agree with the axes actually reduced. Out-of-range axes are
        # left as-is so np.mean still reports them.
        reduced = tuple(a + ndim if a < 0 else a for a in axis)
        # Index that re-inserts the reduced axes so the per-feature parameters
        # broadcast against x.
        ed = tuple(None if i in reduced else slice(None) for i in range(ndim))
        mean = np.mean(x, reduced, keepdims=True)
        var = fastvar(x, reduced, keepdims=True)
        z = (x - mean) / np.sqrt(var + epsilon)
        # Parameters have one entry per position along each non-reduced axis.
        shape = tuple(d for i, d in enumerate(x.shape) if i not in reduced)
        scaled = z * parameter(shape, gamma_init, 'gamma')[ed] if scale else z
        return scaled + parameter(shape, beta_init, 'beta')[ed] if center else scaled

    return batch_norm
示例2: ConvOrConvTranspose
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def ConvOrConvTranspose(out_chan, filter_shape=(3, 3), strides=None, padding='SAME', init_scale=1.,
                        transpose=False):
    """Weight-normalized convolution (or transposed convolution) layer.

    The kernel is reparametrized as ``g * _l2_normalize(V, (0, 1, 2))`` and the
    layer computes ``conv(inputs, kernel) - b``. It uses ``lax.conv_transpose``
    when ``transpose`` is True and ``_conv`` otherwise.

    ``g`` and ``b`` use data-dependent initialization. A first pass with
    ``g = 1`` and ``b = 0`` produces ``example_out``. Then ``g`` is set to
    ``init_scale / sqrt(var(example_out) + 1e-10)`` and ``b`` to
    ``mean(example_out) * g``, both reduced over axes (0, 1, 2). The initial
    output is therefore ``g * (example_out - mean)``: centred, with spread of
    roughly ``init_scale``.

    Args:
        out_chan: number of output channels.
        filter_shape: spatial shape of the kernel.
        strides: per-spatial-axis strides; defaults to all ones.
        padding: padding mode forwarded to the convolution.
        init_scale: target scale of the initial output.
        transpose: select the transposed convolution.

    Returns:
        A ``parametrized`` function ``conv_or_conv_transpose(inputs)``.
    """
    strides = strides or (1,) * len(filter_shape)

    def apply(inputs, V, g, b):
        V = g * _l2_normalize(V, (0, 1, 2))
        return (lax.conv_transpose if transpose else _conv)(inputs, V, strides, padding) - b

    @parametrized
    def conv_or_conv_transpose(inputs):
        V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')
        example_out = apply(inputs, V=V, g=jnp.ones(out_chan), b=jnp.zeros(out_chan))
        # TODO remove need for `.aval.val` when capturing variables in initializer function:
        g = Parameter(lambda key: init_scale /
                      jnp.sqrt(jnp.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
        b = Parameter(lambda key: jnp.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val, 'b')()
        # Pass g and b by keyword. Both have shape (out_chan,), so passing
        # them positionally in the wrong order would silently scale the
        # kernel by b and shift the output by g.
        return apply(inputs, V, g=g, b=b)

    return conv_or_conv_transpose
示例3: test_readme
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def test_readme():
    """README example: build a classifier loss, check init values and output shapes."""
    net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), log_softmax)

    @parametrized
    def loss(inputs, targets):
        log_probs = net(inputs)
        return -jnp.mean(log_probs * targets)

    def next_batch():
        return jnp.zeros((3, 784)), jnp.zeros((3, 4))

    expected_bias = [-0.01101029, -0.00749435, -0.00952365, 0.00493979]
    params = loss.init_parameters(*next_batch(), key=PRNGKey(0))
    last_bias = params.sequential.dense2.bias
    print(last_bias)
    assert jnp.allclose(expected_bias, last_bias)

    # The loss is a scalar, with or without jit.
    eager_out = loss.apply(params, *next_batch())
    assert eager_out.shape == ()
    jitted_out = loss.apply(params, *next_batch(), jit=True)
    assert jitted_out.shape == eager_out.shape
示例4: test_mnist_vae
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def test_mnist_vae():
    """Small VAE: check that parameter initialization yields the expected kernel shape."""
    @parametrized
    def encode(inputs):
        # Layer creation order matters: the second Sequential below is the one
        # the final assertion reaches as `sequential1`.
        hidden = Sequential(Dense(5), relu, Dense(5), relu)(inputs)
        mean = Dense(10)(hidden)
        variance = Sequential(Dense(10), softplus)(hidden)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        z = gaussian_sample(key, mu_z, sigmasq_z)
        logits_x = decode(z)
        return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

    example_images = jnp.zeros((32, 5 * 5))
    params = elbo.init_parameters(PRNGKey(0), example_images, key=PRNGKey(0))
    assert params.encode.sequential1.dense.kernel.shape == (5, 10)
示例5: get_data
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
    """Build a synthetic 1-D regression dataset with polynomial features.

    Training inputs are N evenly spaced points on [-1, 1], expanded to the
    powers x**0 .. x**(D_X - 1). Targets combine a random linear model, the
    nonlinear term 0.5 * (0.5 + x)**2 * sin(4x), and Gaussian noise of scale
    sigma_obs. They are then standardized to zero mean and unit standard
    deviation. NumPy's global RNG is reseeded with 0.

    Returns:
        (X, Y, X_test) with shapes (N, D_X), (N, 1) and (N_test, D_X). X_test
        applies the same feature expansion to N_test points on [-1.3, 1.3].
    """
    output_dim = 1  # targets are one-dimensional
    np.random.seed(0)
    exponents = jnp.arange(D_X)

    grid = jnp.linspace(-1, 1, N)
    X = jnp.power(grid[:, np.newaxis], exponents)
    W = 0.5 * np.random.randn(D_X)
    x1 = X[:, 1]
    Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + x1, 2.0) * jnp.sin(4.0 * x1)
    Y += sigma_obs * np.random.randn(N)
    Y = Y[:, np.newaxis]
    Y -= jnp.mean(Y)
    Y /= jnp.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, output_dim)

    test_grid = jnp.linspace(-1.3, 1.3, N_test)
    X_test = jnp.power(test_grid[:, np.newaxis], exponents)
    return X, Y, X_test
示例6: main
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def main(args):
    """Fit the GLMM on the UCBadmit training split, report and plot the results.

    Runs inference via ``run_inference`` and draws posterior predictive
    ``probs`` with ``Predictive(glmm, ...)``. Prints a summary through
    ``print_results``, then saves a posterior predictive check plot (mean ± std
    plus 5th/95th percentiles per case) to ``ucbadmit_plot.pdf``.
    """
    _, fetch_train = load_dataset(UCBADMIT, split='train', shuffle=False)
    dept, male, applications, admit = fetch_train()
    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)
    pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)['probs']
    header = '=' * 30 + 'glmm - TRAIN' + '=' * 30
    print_results(header, pred_probs, dept, male, admit / applications)

    # make plots
    # One x position per data row, instead of a hard-coded range(1, 13).
    cases = range(1, len(admit) + 1)
    fig, ax = plt.subplots(1, 1)
    ax.plot(cases, admit / applications, "o", ms=7, label="actual rate")
    ax.errorbar(cases, jnp.mean(pred_probs, 0), jnp.std(pred_probs, 0),
                fmt="o", c="k", mfc="none", ms=7, elinewidth=1, label=r"mean $\pm$ std")
    ax.plot(cases, jnp.percentile(pred_probs, 5, 0), "k+")
    ax.plot(cases, jnp.percentile(pred_probs, 95, 0), "k+")
    ax.set(xlabel="cases", ylabel="admit rate", title="Posterior Predictive Check with 90% CI")
    ax.legend()
    # Lay out BEFORE saving: tight_layout after savefig has no effect on the file.
    fig.tight_layout()
    fig.savefig("ucbadmit_plot.pdf")
示例7: get_data
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def get_data(N=20, S=2, P=10, sigma_obs=0.05):
    """Build a synthetic sparse-regression dataset.

    Draws an (N, P) standard-normal design matrix. The target depends only on
    the first S columns (coefficients uniform on [0.5, 3.0)), plus the pairwise
    interaction X[:, 0] * X[:, 1] and Gaussian noise of scale sigma_obs. The
    target is centred, and the outputs are rescaled by its standard deviation.
    NumPy's global RNG is reseeded with 0.

    Returns:
        (X, Y / std, W / std, 1.0 / std), where std is the standard deviation
        of the centred target.
    """
    assert S < P and P > 1 and S > 0
    np.random.seed(0)

    X = np.random.randn(N, P)
    # S coefficients bounded away from zero.
    W = 0.5 + 2.5 * np.random.rand(S)
    linear_part = np.sum(X[:, 0:S] * W, axis=-1)
    interaction = X[:, 0] * X[:, 1]
    Y = linear_part + interaction + sigma_obs * np.random.randn(N)
    Y -= jnp.mean(Y)
    y_scale = jnp.std(Y)

    assert X.shape == (N, P)
    assert Y.shape == (N,)
    return X, Y / y_scale, W / y_scale, 1.0 / y_scale
# Helper function for analyzing the posterior statistics for coefficient theta_i
示例8: test_beta_bernoulli
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def test_beta_bernoulli(auto_class):
    """Fit a Beta-Bernoulli model with an autoguide and check posterior means."""
    data = jnp.array([[1.0] * 8 + [0.0] * 2,
                      [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    optimizer = optim.Adam(0.01)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, optimizer, ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, state):
        next_state, _ = svi.update(state, data)
        return next_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    # Beta(1, 1) prior: the exact posterior mean is (successes + 1) / (trials + 2).
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['beta'], 0), true_coefs, atol=0.05)
示例9: test_unnormalized_normal_x64
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    """Sample a 1-D normal from its unnormalized potential; check mean, std and dtype."""
    true_mean, true_std = 1., 0.5

    def potential_fn(z):
        standardized = (z - true_mean) / true_std
        return 0.5 * jnp.sum(standardized ** 2)

    # SA gets a much longer chain than the other kernels.
    if kernel_cls is SA:
        warmup_steps, num_samples = 100000, 100000
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        warmup_steps, num_samples = 1000, 8000
        kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass)

    init_params = jnp.array(0.)
    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, rtol=0.07)
    assert_allclose(jnp.std(samples), true_std, rtol=0.07)
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples.dtype == jnp.float64
示例10: test_correlated_mvn
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def test_correlated_mvn():
    """NUTS with a dense mass matrix on a correlated 5-D Gaussian: check mean and covariance."""
    # This requires dense mass matrix estimation.
    D = 5
    warmup_steps, num_samples = 5000, 8000
    true_mean = 0.

    noise = random.normal(random.PRNGKey(0), shape=(D, D))
    factor = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) + 0.1 * jnp.exp(noise))
    true_cov = jnp.dot(factor, factor.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # Quadratic form 0.5 * z^T P z with P the precision matrix.
        prec_z = jnp.dot(true_prec, z)
        return 0.5 * jnp.dot(z.T, prec_z)

    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=jnp.zeros(D))
    samples = mcmc.get_samples()

    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    mean_abs_cov_error = np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2
    assert mean_abs_cov_error < 0.02
示例11: test_improper_normal
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def test_improper_normal():
    """Infer a location whose prior is a masked (improper) transformed uniform."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        # mask(False) on the base distribution; TransformReparam handles the
        # AffineTransform at the 'loc' site.
        masked_base = dist.Uniform(0, 1).mask(False)
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.TransformedDistribution(
                masked_base, AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    mcmc = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    loc_samples = mcmc.get_samples()['loc']
    assert_allclose(jnp.mean(loc_samples, 0), true_coef, atol=0.05)
示例12: test_beta_bernoulli_x64
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def test_beta_bernoulli_x64(kernel_cls):
    """Beta-Bernoulli posterior via MCMC: check the mean of p_latent and its dtype."""
    use_sa = kernel_cls is SA
    if use_sa:
        warmup_steps, num_samples = 100000, 100000
    else:
        warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    kernel = SA(model=model) if use_sa else kernel_cls(model=model, trajectory_length=0.1)

    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    p_samples = mcmc.get_samples()['p_latent']
    assert_allclose(jnp.mean(p_samples, 0), true_probs, atol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert p_samples.dtype == jnp.float64
示例13: test_dirichlet_categorical_x64
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    """Dirichlet-Categorical posterior via MCMC: check the mean of p_latent and its dtype."""
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    observations = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))

    mcmc = MCMC(kernel_cls(model, trajectory_length=1., dense_mass=dense_mass),
                warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), observations)
    p_samples = mcmc.get_samples()['p_latent']
    assert_allclose(jnp.mean(p_samples, 0), true_probs, atol=0.02)
    if 'JAX_ENABLE_X64' in os.environ:
        assert p_samples.dtype == jnp.float64
示例14: test_binomial_stable_x64
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def test_binomial_stable_x64(with_logits):
    """Binomial with a very large trial count: posterior mean of p must match x / n."""
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            likelihood = dist.Binomial(data['n'], logits=logit(p))
        else:
            likelihood = dist.Binomial(data['n'], probs=p)
        numpyro.sample('obs', likelihood, obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    mcmc = MCMC(NUTS(model=model), warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(2), data)
    p_samples = mcmc.get_samples()['p']
    assert_allclose(jnp.mean(p_samples, 0), data['x'] / data['n'], rtol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert p_samples.dtype == jnp.float64
示例15: test_mcmc_progbar
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import mean [as 别名]
def test_mcmc_progbar():
    """Results must not depend on the progress bar once the same keys are replayed."""
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 10, 10

    def model(data):
        mean = numpyro.sample('mean', dist.Normal(0, 1).mask(False))
        std = numpyro.sample('std', dist.LogNormal(0, 1).mask(False))
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,))
    kernel = NUTS(model=model)

    # Reference: default progress_bar, explicit warmup (key 2) then run (key 3).
    reference = MCMC(kernel, num_warmup, num_samples)
    reference.warmup(random.PRNGKey(2), data)
    reference.run(random.PRNGKey(3), data)

    no_bar = MCMC(kernel, num_warmup, num_samples, progress_bar=False)
    # A single run() with key 2 follows a different key sequence, so it must differ.
    no_bar.run(random.PRNGKey(2), data)
    with pytest.raises(AssertionError):
        check_close(no_bar.get_samples(), reference.get_samples(), atol=1e-4, rtol=1e-4)

    # Replaying the reference's warmup/run keys must reproduce its results.
    no_bar.warmup(random.PRNGKey(2), data)
    no_bar.run(random.PRNGKey(3), data)
    check_close(no_bar.get_samples(), reference.get_samples(), atol=1e-4, rtol=1e-4)
    check_close(no_bar._warmup_state, reference._warmup_state, atol=1e-4, rtol=1e-4)