本文整理汇总了Python中jax.value_and_grad方法的典型用法代码示例。如果您正苦于以下问题：Python jax.value_and_grad方法的具体用法？Python jax.value_and_grad怎么用？Python jax.value_and_grad使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax的用法示例。
在下文中一共展示了jax.value_and_grad方法的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_categorical_log_prob_grad
# 需要导入模块: import jax [as 别名]
# 或者: from jax import value_and_grad [as 别名]
def test_categorical_log_prob_grad():
    """Value and gradient of a Categorical log-likelihood must agree whether the
    distribution is built from softmax-normalized weights (passed positionally)
    or from the same unnormalized scores passed as ``logits``."""
    data = jnp.repeat(jnp.arange(3), 10)
    # Per-category multiplier applied to the scalar being differentiated.
    scale = jnp.arange(1, 4)

    def via_softmax(x):
        weights = jax.nn.softmax(x * scale)
        return dist.Categorical(weights).log_prob(data).sum()

    def via_logits(x):
        return dist.Categorical(logits=x * scale).log_prob(data).sum()

    x = 0.5
    fx, grad_fx = jax.value_and_grad(via_softmax)(x)
    gx, grad_gx = jax.value_and_grad(via_logits)(x)
    assert_allclose(fx, gx)
    assert_allclose(grad_fx, grad_gx, atol=1e-4)
########################################
# Tests for constraints and transforms #
########################################
示例2: update
# 需要导入模块: import jax [as 别名]
# 或者: from jax import value_and_grad [as 别名]
def update(self, svi_state, *args, **kwargs):
    """
    Perform one SVI optimization step (possibly on a batch / minibatch of data)
    using the optimizer.

    :param svi_state: current state of SVI.
    :param args: positional arguments forwarded to the model / guide; they may
        change from one step to the next.
    :param kwargs: keyword arguments forwarded to the model / guide; they may
        change from one step to the next. They are passed to the loss together
        with ``self.static_kwargs``.
    :return: tuple of `(svi_state, loss)`.
    """
    # One half of the split key is stored in the returned state for the next
    # step; the other half is consumed by this step's loss evaluation.
    next_rng_key, step_rng_key = random.split(svi_state.rng_key)
    params = self.optim.get_params(svi_state.optim_state)

    def _objective(raw_params):
        # Gradients are taken w.r.t. the optimizer's parameters; the loss sees
        # them only after `self.constrain_fn` has been applied.
        constrained = self.constrain_fn(raw_params)
        return self.loss.loss(step_rng_key, constrained, self.model, self.guide,
                              *args, **kwargs, **self.static_kwargs)

    loss_val, grads = value_and_grad(_objective)(params)
    new_optim_state = self.optim.update(grads, svi_state.optim_state)
    return SVIState(new_optim_state, next_rng_key), loss_val
示例3: test_renyi_elbo
# 需要导入模块: import jax [as 别名]
# 或者: from jax import value_and_grad [as 别名]
def test_renyi_elbo(alpha):
    """With a guide that samples nothing, `RenyiELBO` (for the given `alpha`)
    must reproduce `ELBO`'s loss value and gradient w.r.t. the observation."""
    def model(x):
        numpyro.sample("obs", dist.Normal(0, 1), obs=x)

    def guide(x):
        pass

    def make_loss_fn(objective):
        # Same fixed PRNG key and empty param dict for every objective.
        def loss_fn(x):
            return objective.loss(random.PRNGKey(0), {}, model, guide, x)
        return loss_fn

    x_obs = 2.
    elbo_loss, elbo_grad = value_and_grad(make_loss_fn(ELBO()))(x_obs)
    renyi_loss, renyi_grad = value_and_grad(
        make_loss_fn(RenyiELBO(alpha=alpha, num_particles=10)))(x_obs)
    assert_allclose(elbo_loss, renyi_loss, rtol=1e-6)
    assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)
示例4: _check_sample
# 需要导入模块: import jax [as 别名]
# 或者: from jax import value_and_grad [as 别名]
def _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=1e-2,
                  num_samples=100000, statistic="mean", skip_grad=False, with_lazy=None):
    """Check a Monte Carlo estimate of a distribution statistic against its true value.

    The comparison itself is delegated to ``_get_stat_diff``. This helper only
    sizes the sample batch and dispatches on the active backend. It then asserts
    that the returned difference is close to zero. Unless ``skip_grad`` is set, it
    also asserts that the gradient of the summed difference w.r.t. ``params`` is
    close to zero. Backends other than ``"torch"`` and ``"jax"`` are not checked.

    :param funsor_dist_class: distribution class under test.
    :param params: sequence of parameters forwarded to ``_get_stat_diff``; gradients
        are taken with respect to them when gradient checking is enabled.
    :param sample_inputs: names of the sample dimensions. Each one becomes a
        ``bint(samples_per_dim)`` input.
    :param inputs: extra inputs forwarded to ``_get_stat_diff``.
    :param atol: absolute tolerance used by every closeness assertion.
    :param num_samples: total sample budget. ``samples_per_dim`` is its integer
        n-th root, where n is the number of sample dimensions (at least 1).
    :param statistic: name of the statistic to compare (default ``"mean"``).
    :param skip_grad: if true, only the value of the difference is checked.
    :param with_lazy: forwarded unchanged to ``_get_stat_diff``.
    """
    ndims = max(1, len(sample_inputs))
    # Integer n-th root of the sample budget. Plain ``int(num_samples ** (1. / ndims))``
    # can truncate an exact root downwards because of floating point error
    # (``1000 ** (1. / 3)`` evaluates to 9.999999999999998). So round to the nearest
    # integer first, then step down only if that overshoots the budget.
    samples_per_dim = int(round(num_samples ** (1. / ndims)))
    while samples_per_dim > 1 and samples_per_dim ** ndims > num_samples:
        samples_per_dim -= 1
    sample_inputs = OrderedDict((k, bint(samples_per_dim)) for k in sample_inputs)
    _get_stat_diff_fn = functools.partial(
        _get_stat_diff, funsor_dist_class, sample_inputs, inputs, num_samples, statistic, with_lazy)

    def _assert_zero(value):
        assert_close(value, ops.new_zeros(value, value.shape), atol=atol, rtol=None)

    backend = get_backend()
    if backend == "torch":
        import torch

        for param in params:
            param.requires_grad_()
        res = _get_stat_diff_fn(params)
        if sample_inputs:
            diff_sum, diff = res
            _assert_zero(diff)
            if not skip_grad:
                diff_grads = torch.autograd.grad(diff_sum, params, allow_unused=True)
                for diff_grad in diff_grads:
                    # With allow_unused=True, torch returns None for params that
                    # do not influence diff_sum. That is a zero gradient, which
                    # trivially passes, and it has no `.shape` to compare against.
                    if diff_grad is not None:
                        _assert_zero(diff_grad)
    elif backend == "jax":
        import jax

        if sample_inputs:
            if skip_grad:
                _, diff = _get_stat_diff_fn(params)
                _assert_zero(diff)
            else:
                # _get_stat_diff_fn returns (diff_sum, diff): differentiate
                # diff_sum and carry diff along as auxiliary output.
                (_, diff), diff_grads = jax.value_and_grad(_get_stat_diff_fn, has_aux=True)(params)
                _assert_zero(diff)
                for diff_grad in diff_grads:
                    _assert_zero(diff_grad)
        else:
            _get_stat_diff_fn(params)
示例5: _update_fun
# 需要导入模块: import jax [as 别名]
# 或者: from jax import value_and_grad [as 别名]
def _update_fun(self, loss_fun, return_loss=False):
    """Build a function performing one optimizer step on ``loss_fun``.

    The returned callable has the signature ``update(state, *inputs, **kwargs)``.
    It differentiates ``loss_fun(params, *inputs, **kwargs)`` w.r.t. ``params``,
    which are read from ``state`` via ``self.get_parameters``. It then applies
    the gradient with ``self.update_from_gradients``.

    :param loss_fun: scalar loss whose first argument is the parameters.
    :param return_loss: if true, ``update`` returns ``(new_state, loss)``;
        otherwise it returns only ``new_state``.
    """
    def update(state, *inputs, **kwargs):
        params = self.get_parameters(state)
        if not return_loss:
            gradient = grad(loss_fun)(params, *inputs, **kwargs)
            return self.update_from_gradients(gradient, state)
        loss, gradient = value_and_grad(loss_fun)(params, *inputs, **kwargs)
        return self.update_from_gradients(gradient, state), loss

    return update
示例6: velocity_verlet
# 需要导入模块: import jax [as 别名]
# 或者: from jax import value_and_grad [as 别名]
def velocity_verlet(potential_fn, kinetic_fn):
    r"""
    Build a second-order symplectic integrator (velocity Verlet) for position
    `z` and momentum `r`.

    :param potential_fn: Python callable returning the potential energy of a
        position. The position may be any Python collection type.
    :param kinetic_fn: Python callable returning the kinetic energy given the
        inverse mass matrix and the momentum.
    :return: a pair of (`init_fn`, `update_fn`).
    """
    def _half_kick(momentum, grads, step_size):
        # Leaf-wise r - 0.5 * step_size * dU/dz over the momentum collection.
        return tree_multimap(lambda r, g: r - 0.5 * step_size * g, momentum, grads)

    def init_fn(z, r, potential_energy=None, z_grad=None):
        """
        :param z: Position of the particle.
        :param r: Momentum of the particle.
        :param potential_energy: Potential energy at `z`. It is recomputed if
            either this or `z_grad` is None.
        :param z_grad: Gradient of the potential energy at `z`.
        :return: initial state for the integrator.
        """
        if potential_energy is None or z_grad is None:
            potential_energy, z_grad = value_and_grad(potential_fn)(z)
        return IntegratorState(z, r, potential_energy, z_grad)

    def update_fn(step_size, inverse_mass_matrix, state):
        """
        :param float step_size: Size of a single step.
        :param inverse_mass_matrix: Inverse of the mass matrix, passed to
            `kinetic_fn`.
        :param state: Current state of the integrator.
        :return: new state for the integrator.
        """
        z, r, _, z_grad = state
        r_half = _half_kick(r, z_grad, step_size)  # r(n+1/2)
        velocity = grad(kinetic_fn, argnums=1)(inverse_mass_matrix, r_half)
        z_new = tree_multimap(lambda pos, vel: pos + step_size * vel, z, velocity)  # z(n+1)
        energy_new, z_grad_new = value_and_grad(potential_fn)(z_new)
        r_new = _half_kick(r_half, z_grad_new, step_size)  # r(n+1)
        return IntegratorState(z_new, r_new, energy_new, z_grad_new)

    return init_fn, update_fn
示例7: find_reasonable_step_size
# 需要导入模块: import jax [as 别名]
# 或者: from jax import value_and_grad [as 别名]
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator,
                              init_step_size, inverse_mass_matrix, position, rng_key):
    """
    Tune `init_step_size` by repeated doubling or halving, so that HMC avoids
    working with a step size that is far too large or far too small.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A generator to get a random momentum variable.
    :param float init_step_size: Initial step size to be tuned.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
    :return: a reasonable value for step size.
    :rtype: float
    """
    # Work in log space: accept_prob := exp(-delta_energy), so comparing
    # -delta_energy against log(0.8) is comparing accept_prob against 0.8.
    # Above the target we grow the step size; otherwise we shrink it.
    log_target_accept_prob = jnp.log(0.8)
    _, verlet_step = velocity_verlet(potential_fn, kinetic_fn)
    z = position
    potential_energy, z_grad = value_and_grad(potential_fn)(z)
    dtype_info = jnp.finfo(get_dtype(init_step_size))

    def _cond_fn(state):
        step_size, last_direction, direction, _ = state
        # Keep going only while the step size stays inside the representable
        # range in the direction we are moving ...
        not_too_small = (step_size > dtype_info.tiny) | (direction >= 0)
        not_too_large = (step_size < dtype_info.max) | (direction <= 0)
        # ... and until the search direction flips. The first iteration, where
        # last_direction == 0, always runs.
        same_direction = (last_direction == 0) | (direction == last_direction)
        return not_too_small & not_too_large & same_direction

    def _body_fn(state):
        step_size, _, direction, key = state
        key, momentum_key = random.split(key)
        # direction is 0 (first pass), 1 (double) or -1 (halve). A NaN
        # delta_energy, e.g. from a diverging trajectory, compares False below
        # and therefore yields -1.
        step_size = (2.0 ** direction) * step_size
        momentum = momentum_generator(position, inverse_mass_matrix, momentum_key)
        _, momentum_new, potential_energy_new, _ = verlet_step(
            step_size, inverse_mass_matrix, (z, momentum, potential_energy, z_grad))
        energy_current = kinetic_fn(inverse_mass_matrix, momentum) + potential_energy
        energy_new = kinetic_fn(inverse_mass_matrix, momentum_new) + potential_energy_new
        delta_energy = energy_new - energy_current
        next_direction = jnp.where(log_target_accept_prob < -delta_energy, 1, -1)
        return step_size, direction, next_direction, key

    step_size, _, _, _ = while_loop(_cond_fn, _body_fn, (init_step_size, 0, 0, rng_key))
    return step_size