This article collects typical usage examples of the jax.vmap method in Python. If you are unsure exactly how to use jax.vmap, how it works, or what it looks like in real code, the curated examples below may help. You can also explore further usage examples from the jax module it belongs to.
The following presents 15 code examples of jax.vmap, sorted by popularity by default.
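Before the collected examples, a minimal sketch of what jax.vmap does may be useful: it takes a function written for a single example and maps it over a leading batch axis. The function and shapes below are invented for illustration.

import jax
import jax.numpy as jnp

def square_norm(x):
    # written for a single 1-D vector
    return jnp.sum(x ** 2)

xs = jnp.arange(12.).reshape(4, 3)    # a batch of 4 vectors
print(jax.vmap(square_norm)(xs))      # shape (4,): one result per row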
Example 1: test_Batched
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def test_Batched():
    out_dim = 1

    @parametrized
    def unbatched_dense(input):
        kernel = parameter((out_dim, input.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        return jnp.dot(kernel, input) + bias

    batch_size = 4

    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    assert jnp.array([3.]) == out

    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    out_batched = dense.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(out_batched_, out_batched)
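Example 1 relies on vmap(fn, (None, 0)): the parameters are shared across the batch (None) while the inputs are mapped over their leading axis (0). A standalone sketch of that pattern using only jax and jnp follows; the dense function and shapes are illustrative and not part of jaxnet.

import jax
import jax.numpy as jnp

def dense(params, x):
    kernel, bias = params
    return jnp.dot(kernel, x) + bias          # x is a single example

params = (jnp.ones((1, 2)), jnp.ones((1,)))
batch = jnp.ones((4, 2))

# in_axes=(None, 0): broadcast params, map over the leading axis of the inputs
batched_dense = jax.vmap(dense, in_axes=(None, 0))
print(batched_dense(params, batch).shape)     # (4, 1)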
Example 2: _predictive
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def _predictive(rng_key, model, posterior_samples, num_samples, return_sites=None,
                parallel=True, model_args=(), model_kwargs={}):
    rng_keys = random.split(rng_key, num_samples)

    def single_prediction(val):
        rng_key, samples = val
        model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
            *model_args, **model_kwargs)
        if return_sites is not None:
            if return_sites == '':
                sites = {k for k, site in model_trace.items() if site['type'] != 'plate'}
            else:
                sites = return_sites
        else:
            sites = {k for k, site in model_trace.items()
                     if (site['type'] == 'sample' and k not in samples) or (site['type'] == 'deterministic')}
        return {name: site['value'] for name, site in model_trace.items() if name in sites}

    if parallel:
        return vmap(single_prediction)((rng_keys, posterior_samples))
    else:
        return lax.map(single_prediction, (rng_keys, posterior_samples))
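In the parallel branch above, vmap maps over a pytree argument: single_prediction receives one rng key and one posterior draw per call because vmap slices the leading axis of every leaf in the (rng_keys, posterior_samples) tuple. A small sketch of that behaviour, with invented names and values:

import jax
import jax.numpy as jnp
from jax import random

def describe(val):
    key, sample = val                          # one key, one scalar draw
    return sample + random.normal(key)

keys = random.split(random.PRNGKey(0), 3)      # leading axis of length 3
samples = jnp.array([1., 2., 3.])
out = jax.vmap(describe)((keys, samples))      # maps over axis 0 of both leaves
print(out.shape)                               # (3,)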
Example 3: log_likelihood
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def log_likelihood(model, posterior_samples, *args, **kwargs):
    """
    (EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model,
    given samples of all latent variables.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior.
    :param args: model arguments.
    :param kwargs: model kwargs.
    :return: dict of log likelihoods at observation sites.
    """
    def single_loglik(samples):
        model_trace = trace(substitute(model, samples)).get_trace(*args, **kwargs)
        return {name: site['fn'].log_prob(site['value']) for name, site in model_trace.items()
                if site['type'] == 'sample' and site['is_observed']}

    return vmap(single_loglik)(posterior_samples)
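A hedged usage sketch of the log_likelihood helper above: the toy model, data, and hand-built posterior draws are assumptions for illustration only.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(y=None):
    # hypothetical model: one latent mean, observed data
    mu = numpyro.sample('mu', dist.Normal(0., 10.))
    numpyro.sample('obs', dist.Normal(mu, 1.), obs=y)

posterior_samples = {'mu': jnp.array([0.1, -0.2, 0.3])}   # three invented draws
y = jnp.array([0.0, 0.5])

ll = log_likelihood(model, posterior_samples, y=y)
print(ll['obs'].shape)    # (3, 2): pointwise log-likelihoods, one row per draw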
Example 4: _multinomial
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def _multinomial(key, p, n, n_max, shape=()):
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
                                                             dtype=indices.dtype),
                                                   jnp.expand_dims(indices_2D, axis=-1),
                                                   jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
Example 5: test_block_neural_arn
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def test_block_neural_arn(input_dim, hidden_factors, residual, batch_shape):
    arn_init, arn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors, residual)

    rng = random.PRNGKey(0)
    input_shape = batch_shape + (input_dim,)
    out_shape, init_params = arn_init(rng, input_shape)
    assert out_shape == input_shape

    x = random.normal(random.PRNGKey(1), input_shape)
    output, logdet = arn(init_params, x)
    assert output.shape == input_shape
    assert logdet.shape == input_shape

    if len(batch_shape) == 1:
        jac = vmap(jacfwd(lambda x: arn(init_params, x)[0]))(x)
    else:
        jac = jacfwd(lambda x: arn(init_params, x)[0])(x)
    assert_allclose(logdet.sum(-1), jnp.linalg.slogdet(jac)[1], rtol=1e-6)

    # make sure jacobians are lower triangular
    assert np.sum(np.abs(np.triu(jac, k=1))) == 0.0
    assert np.all(np.abs(matrix_to_tril_vec(jac)) > 0)
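The batched branch above composes vmap with jacfwd to get one Jacobian per example in a single call. A minimal sketch with a throwaway function (the function and shapes are invented):

import jax
import jax.numpy as jnp

def f(x):
    # per-example function: 3-vector in, 3-vector out
    return jnp.tanh(x) * jnp.array([1., 2., 3.])

xs = jnp.ones((5, 3))                      # batch of 5 examples
jacs = jax.vmap(jax.jacfwd(f))(xs)         # per-example Jacobians
print(jacs.shape)                          # (5, 3, 3)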
Example 6: test_binop_batch_rule
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def test_binop_batch_rule(prim):
    bx = jnp.array([1., 2., 3.])
    by = jnp.array([2., 3., 4.])
    x = jnp.array(1.)
    y = jnp.array(2.)

    actual_bx_by = vmap(lambda x, y: prim(x, y))(bx, by)
    for i in range(3):
        assert_allclose(actual_bx_by[i], prim(bx[i], by[i]))

    actual_x_by = vmap(lambda y: prim(x, y))(by)
    for i in range(3):
        assert_allclose(actual_x_by[i], prim(x, by[i]))

    actual_bx_y = vmap(lambda x: prim(x, y))(bx)
    for i in range(3):
        assert_allclose(actual_bx_y[i], prim(bx[i], y))
Example 7: grads
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def grads(self, inputs):
    in_grad_partial = jax.partial(self._net_grads, self._net_params)
    grad_vmap = jax.vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
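The grads method above is the usual per-example-gradient recipe: differentiate a single-example function and vmap it over the batch (jax.partial in older JAX releases was simply functools.partial). A plain-jax sketch of the recipe; loss_fn, params, and the batch arrays are assumptions, not part of the class above.

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # hypothetical single-example squared error
    pred = jnp.dot(params, x)
    return (pred - y) ** 2

params = jnp.ones(3)
batch_x = jnp.ones((8, 3))
batch_y = jnp.zeros(8)

# one gradient per example: share params, map over the batch axes of x and y
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, batch_x, batch_y)
print(per_example_grads.shape)             # (8, 3)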
Example 8: Batched
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def Batched(unbatched_model: parametrized, batch_dim=0):
    @parametrized
    def batched(*batched_args):
        args = tree_map(lambda x: x[0], batched_args)
        params = Parameter(lambda key: unbatched_model.init_parameters(*args, key=key), 'model')()
        batched_apply = vmap(partial(unbatched_model.apply, params), batch_dim)
        return batched_apply(*batched_args)

    return batched
Example 9: main
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def main(args):
    N, D_X, D_H = args.num_data, 3, args.num_hidden
    X, Y, X_test = get_data(N=N, D_X=D_X)

    # do inference
    rng_key, rng_key_predict = random.split(random.PRNGKey(0))
    samples = run_inference(model, args, rng_key, X, Y, D_H)

    # predict Y_test at inputs X_test
    vmap_args = (samples, random.split(rng_key_predict, args.num_samples * args.num_chains))
    predictions = vmap(lambda samples, rng_key: predict(model, rng_key, samples, X_test, D_H))(*vmap_args)
    predictions = predictions[..., 0]

    # compute mean prediction and confidence interval around median
    mean_prediction = jnp.mean(predictions, axis=0)
    percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

    # make plots
    fig, ax = plt.subplots(1, 1)

    # plot training data
    ax.plot(X[:, 1], Y[:, 0], 'kx')
    # plot 90% confidence level of predictions
    ax.fill_between(X_test[:, 1], percentiles[0, :], percentiles[1, :], color='lightblue')
    # plot mean prediction
    ax.plot(X_test[:, 1], mean_prediction, 'blue', ls='solid', lw=2.0)
    ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

    plt.savefig('bnn_plot.pdf')
    plt.tight_layout()
Example 10: analyze_dimension
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def analyze_dimension(samples, X, Y, dimension, hypers):
    vmap_args = (samples['msq'], samples['lambda'], samples['eta1'], samples['xisq'], samples['var_obs'])
    mus, variances = vmap(lambda msq, lam, eta1, xisq, var_obs:
                          compute_singleton_mean_variance(X, Y, dimension, msq, lam,
                                                          eta1, xisq, hypers['c'], var_obs))(*vmap_args)
    mean, variance = gaussian_mixture_stats(mus, variances)
    std = jnp.sqrt(variance)
    return mean, std


# Helper function for analyzing the posterior statistics for coefficient theta_ij
Example 11: analyze_pair_of_dimensions
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def analyze_pair_of_dimensions(samples, X, Y, dim1, dim2, hypers):
    vmap_args = (samples['msq'], samples['lambda'], samples['eta1'], samples['xisq'], samples['var_obs'])
    mus, variances = vmap(lambda msq, lam, eta1, xisq, var_obs:
                          compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam,
                                                         eta1, xisq, hypers['c'], var_obs))(*vmap_args)
    mean, variance = gaussian_mixture_stats(mus, variances)
    std = jnp.sqrt(variance)
    return mean, std
Example 12: main
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def main(args):
    X, Y, X_test = get_data(N=args.num_data)

    # do inference
    rng_key, rng_key_predict = random.split(random.PRNGKey(0))
    samples = run_inference(model, args, rng_key, X, Y)

    # do prediction
    vmap_args = (random.split(rng_key_predict, args.num_samples * args.num_chains), samples['kernel_var'],
                 samples['kernel_length'], samples['kernel_noise'])
    means, predictions = vmap(lambda rng_key, var, length, noise:
                              predict(rng_key, X, Y, X_test, var, length, noise))(*vmap_args)

    mean_prediction = np.mean(means, axis=0)
    percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

    # make plots
    fig, ax = plt.subplots(1, 1)

    # plot training data
    ax.plot(X, Y, 'kx')
    # plot 90% confidence level of predictions
    ax.fill_between(X_test, percentiles[0, :], percentiles[1, :], color='lightblue')
    # plot mean prediction
    ax.plot(X_test, mean_prediction, 'blue', ls='solid', lw=2.0)
    ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

    plt.savefig("gp_plot.pdf")
    plt.tight_layout()
Example 13: init
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}):
    # non-vectorized
    if rng_key.ndim == 1:
        rng_key, rng_key_init_model = random.split(rng_key)
    # vectorized
    else:
        rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)
    init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
    if self._potential_fn and init_params is None:
        raise ValueError('Valid value of `init_params` must be provided with'
                         ' `potential_fn`.')

    hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
        init_params,
        num_warmup=num_warmup,
        step_size=self._step_size,
        adapt_step_size=self._adapt_step_size,
        adapt_mass_matrix=self._adapt_mass_matrix,
        dense_mass=self._dense_mass,
        target_accept_prob=self._target_accept_prob,
        trajectory_length=self._trajectory_length,
        max_tree_depth=self._max_tree_depth,
        find_heuristic_step_size=self._find_heuristic_step_size,
        model_args=model_args,
        model_kwargs=model_kwargs,
        rng_key=rng_key,
    )
    if rng_key.ndim == 1:
        init_state = hmc_init_fn(init_params, rng_key)
    else:
        # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
        # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
        # wa_steps because those variables do not depend on traced args: init_params, rng_key.
        init_state = vmap(hmc_init_fn)(init_params, rng_key)
        sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
        self._sample_fn = sample_fn
    return init_state
Example 14: parametric_draws
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def parametric_draws(subposteriors, num_draws, diagonal=False, rng_key=None):
    """
    Merges subposteriors following (embarrassingly parallel) parametric Monte Carlo algorithm.

    **References:**

    1. *Asymptotically Exact, Embarrassingly Parallel MCMC*,
       Willie Neiswanger, Chong Wang, Eric Xing

    :param list subposteriors: a list in which each element is a collection of samples.
    :param int num_draws: number of draws from the merged posterior.
    :param bool diagonal: whether to compute weights using variance or covariance, defaults to
        `False` (using covariance).
    :param jax.random.PRNGKey rng_key: source of the randomness, defaults to `jax.random.PRNGKey(0)`.
    :return: a collection of `num_draws` samples with the same data structure as each subposterior.
    """
    rng_key = random.PRNGKey(0) if rng_key is None else rng_key

    if diagonal:
        mean, var = parametric(subposteriors, diagonal=True)
        samples_flat = dist.Normal(mean, jnp.sqrt(var)).sample(rng_key, (num_draws,))
    else:
        mean, cov = parametric(subposteriors, diagonal=False)
        samples_flat = dist.MultivariateNormal(mean, cov).sample(rng_key, (num_draws,))

    _, unravel_fn = ravel_pytree(tree_map(lambda x: x[0], subposteriors[0]))
    return vmap(lambda x: unravel_fn(x))(samples_flat)
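The last line above draws samples in a flat vector space and uses vmap to re-assemble one pytree per draw via the unravel function. A small sketch of that ravel/unravel pattern with jax.flatten_util; the pytree and number of draws are invented.

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

template = {'loc': jnp.zeros(2), 'scale': jnp.zeros(3)}
flat, unravel_fn = ravel_pytree(template)        # flat has shape (5,)

flat_draws = jnp.ones((10, flat.shape[0]))       # 10 flat draws
draws = jax.vmap(unravel_fn)(flat_draws)         # each leaf gains a leading axis of 10
print(draws['loc'].shape, draws['scale'].shape)  # (10, 2) (10, 3)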
Example 15: loss
# Required import: import jax [as alias]
# Or: from jax import vmap [as alias]
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
    """
    Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param dict param_map: dictionary of current parameter values keyed by site
        name.
    :param model: Python callable with NumPyro primitives for the model.
    :param guide: Python callable with NumPyro primitives for the guide.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
    """
    def single_particle_elbo(rng_key):
        model_seed, guide_seed = random.split(rng_key)
        seeded_model = seed(model, model_seed)
        seeded_guide = seed(guide, guide_seed)
        guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
        seeded_model = replay(seeded_model, guide_trace)
        model_log_density, _ = log_density(seeded_model, args, kwargs, param_map)

        # log p(z) - log q(z)
        elbo = model_log_density - guide_log_density
        return elbo

    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    if self.num_particles == 1:
        return - single_particle_elbo(rng_key)
    else:
        rng_keys = random.split(rng_key, self.num_particles)
        return - jnp.mean(vmap(single_particle_elbo)(rng_keys))
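The multi-particle branch above is the standard way to average a Monte Carlo estimator with vmap: split one key into num_particles keys and map the single-particle function over them. A stripped-down sketch with a placeholder estimator (not the ELBO itself):

import jax
import jax.numpy as jnp
from jax import random

def single_estimate(rng_key):
    # placeholder one-sample estimator
    return random.normal(rng_key) ** 2

num_particles = 16
rng_keys = random.split(random.PRNGKey(0), num_particles)
estimate = jnp.mean(jax.vmap(single_estimate)(rng_keys))
print(estimate)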