This article collects typical usage examples of the jax.numpy.sum method in Python. If you are wondering how exactly to use numpy.sum, or what it is good for, the curated code examples below may help. You can also explore other methods of the jax.numpy module.
Below are 15 code examples of numpy.sum, sorted by popularity by default. You can upvote the examples you like or find useful; your ratings help the system recommend better Python code examples.
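Before the collected examples, here is a minimal, self-contained sketch of what jax.numpy.sum itself does (the array values are made up for illustration):

import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)  # [[0., 1., 2.], [3., 4., 5.]]
jnp.sum(x)                         # sum of all elements -> 15.0
jnp.sum(x, axis=0)                 # column sums -> [3., 5., 7.]
jnp.sum(x, axis=1, keepdims=True)  # row sums with shape (2, 1) -> [[3.], [12.]]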
Example 1: approximate_kl
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def approximate_kl(log_prob_new, log_prob_old, mask):
    """Computes the approximate KL divergence between the old and new log-probs.

    Args:
      log_prob_new: (B, T+1, A) log-probs of the new policy.
      log_prob_old: (B, T+1, A) log-probs of the old policy.
      mask: (B, T) mask.

    Returns:
      Approximate KL.
    """
    diff = log_prob_old - log_prob_new
    # Cut the last time-step out.
    diff = diff[:, :-1]
    # Mask out the irrelevant part.
    diff *= mask[:, :, np.newaxis]  # make mask (B, T, 1)
    # Average on non-masked part.
    return np.sum(diff) / np.sum(mask)
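To make the shape conventions above concrete, here is a hypothetical call with dummy data; the values B=2, T=4, A=3 are made up, and `np` is assumed to alias jax.numpy as in the example:

import jax.numpy as np

B, T, A = 2, 4, 3
log_prob_new = np.zeros((B, T + 1, A))       # new policy log-probs
log_prob_old = np.full((B, T + 1, A), -0.1)  # old policy log-probs
mask = np.ones((B, T))                       # every time-step counts
kl = approximate_kl(log_prob_new, log_prob_old, mask)  # (-0.1 * B * T * A) / (B * T) = -0.3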
Example 2: masked_entropy
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def masked_entropy(log_probs, mask):
    """Computes the entropy for the given log-probs.

    Args:
      log_probs: (B, T+1, A) log-probs.
      mask: (B, T) mask.

    Returns:
      Entropy.
    """
    # Cut the last time-step out.
    lp = log_probs[:, :-1]
    # Mask out the irrelevant part.
    lp *= mask[:, :, np.newaxis]  # make mask (B, T, 1)
    p = np.exp(lp) * mask[:, :, np.newaxis]  # (B, T, 1)
    # Average on non-masked part and take negative.
    return -(np.sum(lp * p) / np.sum(mask))
Example 3: test_parameters_from_subsubmodule
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def test_parameters_from_subsubmodule():
    subsublayer = Dense(2)
    sublayer = Sequential(subsublayer, relu)
    net = Sequential(sublayer, jnp.sum)
    inputs = jnp.zeros((1, 3))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)

    subsublayer_params = subsublayer.init_parameters(inputs, key=PRNGKey(0))
    params_ = net.parameters_from({subsublayer: subsublayer_params}, inputs)
    assert_dense_parameters_equal(subsublayer_params, params_[0][0])
    out_ = net.apply(params_, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs, jit=True)
    assert out.shape == out_.shape
Example 4: test_parameters_from_sharing_between_multiple_parents
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def test_parameters_from_sharing_between_multiple_parents():
    a = Dense(2)
    b = Sequential(a, jnp.sum)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    inputs = jnp.zeros((1, 3))
    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_dense_parameters_equal(a_params, params.dense)
    assert_parameters_equal((), params.sequential)
    assert 2 == len(params)
    out_, _ = net.apply(params, inputs)
    assert jnp.array_equal(out, out_)
Example 5: get_data
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
# Note: in this example `np` is plain NumPy (it calls np.random.seed), while `jnp` is jax.numpy.
def get_data(N=20, S=2, P=10, sigma_obs=0.05):
    assert S < P and P > 1 and S > 0
    np.random.seed(0)

    X = np.random.randn(N, P)
    # generate S coefficients with non-negligible magnitude
    W = 0.5 + 2.5 * np.random.rand(S)
    # generate data using the S coefficients and a single pairwise interaction
    Y = np.sum(X[:, 0:S] * W, axis=-1) + X[:, 0] * X[:, 1] + sigma_obs * np.random.randn(N)
    Y -= jnp.mean(Y)
    Y_std = jnp.std(Y)

    assert X.shape == (N, P)
    assert Y.shape == (N,)

    return X, Y / Y_std, W / Y_std, 1.0 / Y_std


# Helper function for analyzing the posterior statistics for coefficient theta_i
Example 6: test_beta_bernoulli
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2,
                      [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['beta'], 0), true_coefs, atol=0.05)
Example 7: test_unnormalized_normal_x64
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    true_mean, true_std = 1., 0.5
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_params = jnp.array(0.)
    if kernel_cls is SA:
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass)
    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)

    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states.dtype == jnp.float64
Example 8: test_correlated_mvn
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5
    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) + 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    init_params = jnp.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02
Example 9: test_diverging
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def test_diverging(kernel_cls, adapt_step_size):
    data = random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        loc = numpyro.sample('loc', dist.Normal(0., 1.))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)

    kernel = kernel_cls(model, step_size=10., adapt_step_size=adapt_step_size, adapt_mass_matrix=False)
    num_warmup = num_samples = 1000
    mcmc = MCMC(kernel, num_warmup, num_samples)
    mcmc.warmup(random.PRNGKey(1), data, extra_fields=['diverging'], collect_warmup=True)
    warmup_divergences = mcmc.get_extra_fields()['diverging'].sum()
    mcmc.run(random.PRNGKey(2), data, extra_fields=['diverging'])
    num_divergences = warmup_divergences + mcmc.get_extra_fields()['diverging'].sum()
    if adapt_step_size:
        assert num_divergences <= num_warmup
    else:
        assert_allclose(num_divergences, num_warmup + num_samples)
Example 10: test_functional_map
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def test_functional_map(algo, map_fn):
    if map_fn is pmap and xla_bridge.device_count() == 1:
        pytest.skip('pmap test requires device_count greater than 1.')

    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = jnp.array([0., -1.])
    rng_keys = random.split(random.PRNGKey(0), 2)

    init_kernel_map = map_fn(lambda init_param, rng_key: init_kernel(
        init_param, trajectory_length=9, num_warmup=warmup_steps, rng_key=rng_key))
    init_states = init_kernel_map(init_params, rng_keys)

    fori_collect_map = map_fn(lambda hmc_state: fori_collect(0, num_samples, sample_kernel, hmc_state,
                                                             transform=lambda x: x.z, progbar=False))
    chain_samples = fori_collect_map(init_states)

    assert_allclose(jnp.mean(chain_samples, axis=1), jnp.repeat(true_mean, 2), rtol=0.06)
    assert_allclose(jnp.std(chain_samples, axis=1), jnp.repeat(true_std, 2), rtol=0.06)
Example 11: test_categorical_log_prob_grad
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def test_categorical_log_prob_grad():
    data = jnp.repeat(jnp.arange(3), 10)

    def f(x):
        return dist.Categorical(jax.nn.softmax(x * jnp.arange(1, 4))).log_prob(data).sum()

    def g(x):
        return dist.Categorical(logits=x * jnp.arange(1, 4)).log_prob(data).sum()

    x = 0.5
    fx, grad_fx = jax.value_and_grad(f)(x)
    gx, grad_gx = jax.value_and_grad(g)(x)
    assert_allclose(fx, gx)
    assert_allclose(grad_fx, grad_gx, atol=1e-4)


########################################
# Tests for constraints and transforms #
########################################
Example 12: clip_eta
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def clip_eta(eta, norm, eps):
    """
    Helper function to clip the perturbation eta into the epsilon norm ball.

    :param eta: A tensor with the current perturbation.
    :param norm: Order of the norm (mimics NumPy). Possible values: np.inf or 2.
    :param eps: Epsilon, bound of the perturbation.
    """
    # Clipping perturbation eta to the given norm ball
    if norm not in [np.inf, 2]:
        raise ValueError('norm must be np.inf or 2.')

    axis = list(range(1, len(eta.shape)))
    avoid_zero_div = 1e-12
    if norm == np.inf:
        eta = np.clip(eta, a_min=-eps, a_max=eps)
    elif norm == 2:
        # avoid_zero_div must go inside sqrt to avoid a divide by zero in the gradient through this operation
        norm = np.sqrt(np.maximum(avoid_zero_div, np.sum(np.square(eta), axis=axis, keepdims=True)))
        # We must *clip* to within the norm ball, not *normalize* onto the surface of the ball
        factor = np.minimum(1., np.divide(eps, norm))
        eta = eta * factor
    return eta
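As a quick sanity check of both branches, here is a hedged usage sketch (it assumes `np` aliases jax.numpy, matching the example above):

import jax.numpy as np

eta = np.array([[3.0, -4.0]])          # a single perturbation with L2 norm 5
print(clip_eta(eta, np.inf, eps=1.0))  # element-wise clip -> [[1., -1.]]
print(clip_eta(eta, 2, eps=1.0))       # rescaled into the L2 ball -> [[0.6, -0.8]]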
Example 13: evaluate_policy
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def evaluate_policy(eval_env,
                    get_predictions,
                    temperatures,
                    max_timestep=20000,
                    n_evals=1,
                    len_history_for_policy=32,
                    rng=None):
    """Evaluate the policy."""

    processed_reward_sums = collections.defaultdict(list)
    raw_reward_sums = collections.defaultdict(list)
    for eval_rng in jax_random.split(rng, num=n_evals):
        for temperature in temperatures:
            trajs, _, _ = env_problem_utils.play_env_problem_with_policy(
                eval_env,
                get_predictions,
                num_trajectories=eval_env.batch_size,
                max_timestep=max_timestep,
                reset=True,
                policy_sampling=env_problem_utils.GUMBEL_SAMPLING,
                temperature=temperature,
                rng=eval_rng,
                len_history_for_policy=len_history_for_policy)
            processed_reward_sums[temperature].extend(sum(traj[2]) for traj in trajs)
            raw_reward_sums[temperature].extend(sum(traj[3]) for traj in trajs)

    # Return the mean and standard deviation for each temperature.
    def compute_stats(reward_dict):
        return {
            temperature: {"mean": onp.mean(rewards), "std": onp.std(rewards)}
            for (temperature, rewards) in reward_dict.items()
        }

    return {
        "processed": compute_stats(processed_reward_sums),
        "raw": compute_stats(raw_reward_sums),
    }
Example 14: _sum
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def _sum(x, dim):
    # Thin wrapper that maps a `dim` argument onto jax.numpy's `axis` keyword.
    return np.sum(x, axis=dim)
Example 15: loss
# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import sum [as alias]
def loss(y, y_hat):
    # Negative sum of element-wise products; cross-entropy when y is one-hot and y_hat holds log-probs.
    return -np.sum(y * y_hat)
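Because this loss is just a negative sum of element-wise products, its gradient with respect to y_hat is simply -y. A minimal sketch with jax.grad (the target and prediction values are made up, and `np` is assumed to alias jax.numpy):

import jax
import jax.numpy as np

y = np.array([0.0, 1.0, 0.0])              # e.g. a one-hot target
y_hat = np.log(np.array([0.2, 0.5, 0.3]))  # e.g. predicted log-probabilities
print(loss(y, y_hat))                       # -log(0.5) ~= 0.693
print(jax.grad(loss, argnums=1)(y, y_hat))  # gradient w.r.t. y_hat is -y -> [0., -1., 0.]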