當前位置: 首頁>>代碼示例>>Python>>正文


Python jax.value_and_grad方法代碼示例

本文整理匯總了Python中jax.value_and_grad方法的典型用法代碼示例。如果您正苦於以下問題:Python jax.value_and_grad方法的具體用法?Python jax.value_and_grad怎麼用?Python jax.value_and_grad使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在jax的用法示例。


在下文中一共展示了jax.value_and_grad方法的7個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: test_categorical_log_prob_grad

# 需要導入模塊: import jax [as 別名]
# 或者: from jax import value_and_grad [as 別名]
def test_categorical_log_prob_grad():
    """Check that Categorical built from probs and from logits agree.

    The summed log-probability of the same observations is evaluated once
    through ``jax.nn.softmax`` probabilities and once through raw logits;
    both the values and their gradients with respect to the scalar input
    must match.
    """
    observations = jnp.repeat(jnp.arange(3), 10)

    def log_prob_from_probs(x):
        probs = jax.nn.softmax(x * jnp.arange(1, 4))
        return dist.Categorical(probs).log_prob(observations).sum()

    def log_prob_from_logits(x):
        logits = x * jnp.arange(1, 4)
        return dist.Categorical(logits=logits).log_prob(observations).sum()

    point = 0.5
    value_probs, grad_probs = jax.value_and_grad(log_prob_from_probs)(point)
    value_logits, grad_logits = jax.value_and_grad(log_prob_from_logits)(point)

    assert_allclose(value_probs, value_logits)
    # Gradients go through different code paths, so allow a small absolute slack.
    assert_allclose(grad_probs, grad_logits, atol=1e-4)


########################################
# Tests for constraints and transforms #
######################################## 
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:21,代碼來源:test_distributions.py

示例2: update

# 需要導入模塊: import jax [as 別名]
# 或者: from jax import value_and_grad [as 別名]
def update(self, svi_state, *args, **kwargs):
        """
        Perform one SVI optimization step (possibly on a batch / minibatch of data).

        The RNG key stored in `svi_state` is split in two: one half is consumed
        by the loss evaluation of this step, the other half is carried forward
        in the returned state.

        :param svi_state: current state of SVI.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: tuple of `(svi_state, loss)`.
        """
        next_key, step_key = random.split(svi_state.rng_key)

        def objective(raw_params):
            # The loss is evaluated on the output of `constrain_fn`, while the
            # gradient is taken with respect to the optimizer's own parameters.
            return self.loss.loss(step_key, self.constrain_fn(raw_params),
                                  self.model, self.guide,
                                  *args, **kwargs, **self.static_kwargs)

        current_params = self.optim.get_params(svi_state.optim_state)
        loss_val, grads = value_and_grad(objective)(current_params)
        new_optim_state = self.optim.update(grads, svi_state.optim_state)
        return SVIState(new_optim_state, next_key), loss_val
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:21,代碼來源:svi.py

示例3: test_renyi_elbo

# 需要導入模塊: import jax [as 別名]
# 或者: from jax import value_and_grad [as 別名]
def test_renyi_elbo(alpha):
    """Check that RenyiELBO and ELBO give the same loss and gradient.

    The model has a single observed Normal(0, 1) site and the guide has no
    sample sites; both objectives are evaluated at ``x = 2.`` with the same
    PRNG key and compared to a relative tolerance of 1e-6.

    :param alpha: value passed through to ``RenyiELBO(alpha=...)``.
    """
    def model(x):
        numpyro.sample("obs", dist.Normal(0, 1), obs=x)

    def guide(x):
        pass

    def make_loss_fn(build_objective):
        # The objective is constructed inside the differentiated function,
        # exactly once per evaluation.
        def loss_fn(x):
            return build_objective().loss(random.PRNGKey(0), {}, model, guide, x)
        return loss_fn

    elbo_loss_fn = make_loss_fn(lambda: ELBO())
    renyi_loss_fn = make_loss_fn(lambda: RenyiELBO(alpha=alpha, num_particles=10))

    observed = 2.
    elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(observed)
    renyi_loss, renyi_grad = value_and_grad(renyi_loss_fn)(observed)

    for expected, actual in ((elbo_loss, renyi_loss), (elbo_grad, renyi_grad)):
        assert_allclose(expected, actual, rtol=1e-6)
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:19,代碼來源:test_svi.py

示例4: _check_sample

# 需要導入模塊: import jax [as 別名]
# 或者: from jax import value_and_grad [as 別名]
def _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=1e-2,
                  num_samples=100000, statistic="mean", skip_grad=False, with_lazy=None):
    """Utility that compares a Monte Carlo estimate of a distribution statistic with its true value.

    The comparison itself is delegated to ``_get_stat_diff``; when there are
    sample inputs, its result is unpacked as ``(diff_sum, diff)``. ``diff`` is
    asserted to be close to zero within ``atol`` and, unless ``skip_grad`` is
    set, so is every gradient of ``diff_sum`` with respect to ``params``.

    :param funsor_dist_class: forwarded to ``_get_stat_diff``.
    :param params: distribution parameters; differentiated against when
        gradients are checked.
    :param sample_inputs: iterable of sample input names; each one is given a
        ``bint`` size derived from ``num_samples``.
    :param inputs: forwarded to ``_get_stat_diff``.
    :param atol: absolute tolerance for all closeness assertions.
    :param num_samples: total sample budget, spread evenly over the sample inputs.
    :param statistic: forwarded to ``_get_stat_diff``.
    :param skip_grad: when true, the gradient assertions are skipped.
    :param with_lazy: forwarded to ``_get_stat_diff``.
    """
    num_dims = max(1, len(sample_inputs))
    samples_per_dim = int(num_samples ** (1. / num_dims))
    sized_sample_inputs = OrderedDict(
        (name, bint(samples_per_dim)) for name in sample_inputs)
    stat_diff_fn = functools.partial(
        _get_stat_diff, funsor_dist_class, sized_sample_inputs, inputs,
        num_samples, statistic, with_lazy)

    def _assert_near_zero(value):
        # Compare against a zero tensor of matching shape using only the
        # absolute tolerance.
        assert_close(value, ops.new_zeros(value, value.shape), atol=atol, rtol=None)

    if get_backend() == "torch":
        import torch

        for param in params:
            param.requires_grad_()

        result = stat_diff_fn(params)
        if sized_sample_inputs:
            diff_sum, diff = result
            _assert_near_zero(diff)
            if not skip_grad:
                for diff_grad in torch.autograd.grad(diff_sum, params, allow_unused=True):
                    _assert_near_zero(diff_grad)
    elif get_backend() == "jax":
        import jax

        if not sized_sample_inputs:
            # Nothing to compare: just make sure the computation runs.
            stat_diff_fn(params)
        elif skip_grad:
            _, diff = stat_diff_fn(params)
            _assert_near_zero(diff)
        else:
            # has_aux=True: the first output is differentiated, the second is
            # returned alongside it untouched.
            (_, diff), diff_grads = jax.value_and_grad(stat_diff_fn, has_aux=True)(params)
            _assert_near_zero(diff)
            for diff_grad in diff_grads:
                _assert_near_zero(diff_grad)
開發者ID:pyro-ppl,項目名稱:funsor,代碼行數:38,代碼來源:test_distribution.py

示例5: _update_fun

# 需要導入模塊: import jax [as 別名]
# 或者: from jax import value_and_grad [as 別名]
def _update_fun(self, loss_fun, return_loss=False):
        def update(state, *inputs, **kwargs):
            params = self.get_parameters(state)
            if return_loss:
                loss, gradient = value_and_grad(loss_fun)(params, *inputs, **kwargs)
                return self.update_from_gradients(gradient, state), loss
            else:
                gradient = grad(loss_fun)(params, *inputs, **kwargs)
                return self.update_from_gradients(gradient, state)

        return update 
開發者ID:JuliusKunze,項目名稱:jaxnet,代碼行數:13,代碼來源:optimizers.py

示例6: velocity_verlet

# 需要導入模塊: import jax [as 別名]
# 或者: from jax import value_and_grad [as 別名]
def velocity_verlet(potential_fn, kinetic_fn):
    r"""
    Second order symplectic integrator that uses the velocity verlet algorithm
    for position `z` and momentum `r`.

    Each update does a half step on the momentum, a full step on the position,
    then a second half step on the momentum using the gradient at the new
    position.

    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum.
    :return: a pair of (`init_fn`, `update_fn`).
    """
    def init_fn(z, r, potential_energy=None, z_grad=None):
        """
        Build the initial integrator state.

        :param z: Position of the particle.
        :param r: Momentum of the particle.
        :param potential_energy: Potential energy at `z`.
        :param z_grad: gradient of potential energy at `z`.
        :return: initial state for the integrator.
        """
        # Both quantities are recomputed together if either one is missing.
        needs_eval = potential_energy is None or z_grad is None
        if needs_eval:
            potential_energy, z_grad = value_and_grad(potential_fn)(z)
        return IntegratorState(z, r, potential_energy, z_grad)

    def update_fn(step_size, inverse_mass_matrix, state):
        """
        Advance the integrator by one step.

        :param float step_size: Size of a single step.
        :param inverse_mass_matrix: Inverse of mass matrix, which is used to
            calculate kinetic energy.
        :param state: Current state of the integrator.
        :return: new state for the integrator.
        """
        position, momentum, _, position_grad = state
        half_step = 0.5 * step_size

        def half_kick(p, g):
            # Momentum half step, applied leaf-wise over the pytree.
            return tree_multimap(lambda p_leaf, g_leaf: p_leaf - half_step * g_leaf, p, g)

        momentum = half_kick(momentum, position_grad)  # r(n+1/2)
        # Gradient of the kinetic energy with respect to momentum (argument 1).
        momentum_grad = grad(kinetic_fn, argnums=1)(inverse_mass_matrix, momentum)
        position = tree_multimap(
            lambda z_leaf, v_leaf: z_leaf + step_size * v_leaf,
            position, momentum_grad)  # z(n+1)
        new_potential_energy, position_grad = value_and_grad(potential_fn)(position)
        momentum = half_kick(momentum, position_grad)  # r(n+1)
        return IntegratorState(position, momentum, new_potential_energy, position_grad)

    return init_fn, update_fn
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:43,代碼來源:hmc_util.py

示例7: find_reasonable_step_size

# 需要導入模塊: import jax [as 別名]
# 或者: from jax import value_and_grad [as 別名]
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator,
                              init_step_size, inverse_mass_matrix, position, rng_key):
    """
    Tune `init_step_size` to a reasonable step size, so that HMC does not work
    with a step size that is too large or too small.

    Starting from `init_step_size`, the step size is repeatedly doubled or
    halved. The search stops as soon as the search direction flips, or when
    the step size reaches the limits of its floating point type.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A generator to get a random momentum variable.
    :param float init_step_size: Initial step size to be tuned.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
    :return: a reasonable value for step size.
    :rtype: float
    """
    # The acceptance probability (Metropolis correction) is
    # accept_prob := exp(-delta_energy), so the comparison is done in log space:
    # -delta_energy against log(0.8). A small accept_prob means the step size
    # has to shrink; otherwise it grows.
    log_target_accept_prob = jnp.log(0.8)

    _, integrate_one_step = velocity_verlet(potential_fn, kinetic_fn)
    z = position
    potential_energy, z_grad = value_and_grad(potential_fn)(z)
    float_limits = jnp.finfo(get_dtype(init_step_size))

    def _cond_fn(state):
        step_size, last_direction, direction, _ = state
        # Keep going unless we are shrinking an already tiny step size ...
        can_shrink = (step_size > float_limits.tiny) | (direction >= 0)
        # ... or growing an already huge one.
        can_grow = (step_size < float_limits.max) | (direction <= 0)
        # last_direction == 0 marks the very first iteration; afterwards the
        # loop continues only while the direction has not flipped.
        direction_unchanged = (last_direction == 0) | (direction == last_direction)
        return (can_shrink & can_grow) & direction_unchanged

    def _body_fn(state):
        step_size, _, direction, key = state
        key, momentum_key = random.split(key)
        # Scale step_size by 2x up (direction=1) or down (otherwise).
        # Note that the direction is -1 if delta_energy is `NaN`, which may be
        # the case for a diverging trajectory (e.g. in the case of evaluating
        # log prob of a value simulated using a large step size for a
        # constrained sample site).
        step_size = (2.0 ** direction) * step_size
        momentum = momentum_generator(position, inverse_mass_matrix, momentum_key)
        start_state = (z, momentum, potential_energy, z_grad)
        _, momentum_new, potential_energy_new, _ = integrate_one_step(
            step_size, inverse_mass_matrix, start_state)
        energy_before = kinetic_fn(inverse_mass_matrix, momentum) + potential_energy
        energy_after = kinetic_fn(inverse_mass_matrix, momentum_new) + potential_energy_new
        delta_energy = energy_after - energy_before
        next_direction = jnp.where(log_target_accept_prob < -delta_energy, 1, -1)
        return step_size, direction, next_direction, key

    # State layout: (step_size, last_direction, direction, rng_key).
    initial_state = (init_step_size, 0, 0, rng_key)
    final_state = while_loop(_cond_fn, _body_fn, initial_state)
    return final_state[0]
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:63,代碼來源:hmc_util.py


注:本文中的jax.value_and_grad方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。