

Python jax.grad Method Code Examples

This article collects typical usage examples of the jax.grad method in Python. If you are looking for concrete answers to questions such as what jax.grad does and how to call it in practice, the curated code examples below may help. You can also explore further usage examples from the jax package, where this method lives.


The following presents 15 code examples of the jax.grad method, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
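Before the examples, here is a minimal, self-contained sketch of what jax.grad does (this snippet is illustrative only and is not taken from any of the projects below): it takes a scalar-valued Python function and returns a new function that computes its gradient.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)  # scalar-valued function of x

df = jax.grad(f)  # df(x) computes the gradient of f at x
print(df(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]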

Example 1: __init__

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def __init__(self, obs_dim, *, seed=None):
        """Internal setup for Jax-based reward models.

        Initialises reward model using given seed & input size (`obs_dim`).

        Args:
            obs_dim (int): dimensionality of observation space.
            seed (int or None): random seed for generating initial params. If
                None, seed will be chosen arbitrarily, as in
                LinearRewardModel.
        """
        # TODO: apply jax.jit() to everything in sight
        net_init, self._net_apply = self.make_stax_model()
        if seed is None:
            # oh well
            seed = np.random.randint((1 << 63) - 1)
        rng = jrandom.PRNGKey(seed)
        out_shape, self._net_params = net_init(rng, (-1, obs_dim))
        self._net_grads = jax.grad(self._net_apply)
        # output shape should just be batch dim, nothing else
        assert out_shape == (-1,), "got a weird output shape %s" % (out_shape,) 
Developer: HumanCompatibleAI, Project: imitation, Lines: 23, Source: tabular_irl.py

Example 2: test_dual_averaging

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def test_dual_averaging(jitted):
    def optimize(f):
        da_init, da_update = dual_averaging(gamma=0.5)
        da_state = da_init()
        for i in range(10):
            x = da_state[0]
            g = grad(f)(x)
            da_state = da_update(g, da_state)
        x_avg = da_state[1]
        return x_avg

    f = lambda x: (x + 1) ** 2  # noqa: E731
    fn = jit(optimize, static_argnums=(0,)) if jitted else optimize
    x_opt = fn(f)

    assert_allclose(x_opt, -1., atol=1e-3) 
Developer: pyro-ppl, Project: numpyro, Lines: 18, Source: test_hmc_util.py

Example 3: grad

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def grad(*args, **kwargs):
  return backend()["grad"](*args, **kwargs) 
Developer: yyht, Project: BERT, Lines: 4, Source: backend.py
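The wrapper in Example 3 simply dispatches grad to whichever numerical backend is currently active. The following standalone sketch uses an illustrative _JAX_BACKEND table (not the project's actual backend() implementation) to show the same dispatch pattern:

import jax

_JAX_BACKEND = {"grad": jax.grad}  # hypothetical function table for the JAX backend

def backend():
    # The real backend() can select among several implementations;
    # this sketch always returns the JAX table.
    return _JAX_BACKEND

def grad(*args, **kwargs):
    return backend()["grad"](*args, **kwargs)

df = grad(lambda x: x ** 2)
print(df(3.0))  # 6.0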

Example 4: test_tensor_distribution

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def test_tensor_distribution(event_inputs, batch_inputs, test_grad):
    num_samples = 50000
    sample_inputs = OrderedDict(n=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)
    p_data = random_tensor(be_inputs).data
    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
    probe = randn(p_data.shape)

    def diff_fn(p_data):
        p = Tensor(p_data, be_inputs)
        q = p.sample(sampled_vars, sample_inputs, rng_key=rng_key)
        mq = p.materialize(q).reduce(ops.logaddexp, 'n')
        mq = mq.align(tuple(p.inputs))

        _, (p_data, mq_data) = align_tensors(p, mq)
        assert p_data.shape == mq_data.shape
        return (ops.exp(mq_data) * probe).sum() - (ops.exp(p_data) * probe).sum(), mq

    if test_grad:
        if get_backend() == "jax":
            import jax

            diff_grad, mq = jax.grad(diff_fn, has_aux=True)(p_data)
        else:
            import torch

            p_data.requires_grad_(True)
            diff_grad = torch.autograd.grad(diff_fn(p_data)[0], [p_data])[0]

        assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=0.1, rtol=None)
    else:
        _, mq = diff_fn(p_data)
        assert_close(mq, Tensor(p_data, be_inputs), atol=0.1, rtol=None) 
Developer: pyro-ppl, Project: funsor, Lines: 38, Source: test_samplers.py
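Example 4 relies on jax.grad(diff_fn, has_aux=True): when the differentiated function returns a pair (value, aux), grad differentiates only the first element and passes the auxiliary output through unchanged. A minimal standalone illustration, independent of the funsor code above:

import jax
import jax.numpy as jnp

def f(x):
    loss = jnp.sum(x ** 2)
    aux = jnp.mean(x)  # extra output we want returned, but not differentiated
    return loss, aux

g, aux = jax.grad(f, has_aux=True)(jnp.array([1.0, 2.0]))
print(g)    # [2. 4.]
print(aux)  # 1.5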

Example 5: test_JAX

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def test_JAX(self):
        # importing inside the gpu-only test because these packages can't be
        # imported on the CPU image since they are not present there.
        from jax import grad, jit

        grad_tanh = grad(self.tanh)
        ag = grad_tanh(1.0)
        self.assertEqual(0.4199743, ag) 
Developer: Kaggle, Project: docker-python, Lines: 10, Source: test_jax.py
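The expected value in Example 5 is simply the derivative of tanh evaluated at 1.0:

$$\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x), \qquad 1 - \tanh^2(1.0) \approx 1 - 0.761594^2 \approx 0.4199743.$$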

Example 6: _update_fun

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def _update_fun(self, loss_fun, return_loss=False):
        def update(state, *inputs, **kwargs):
            params = self.get_parameters(state)
            if return_loss:
                loss, gradient = value_and_grad(loss_fun)(params, *inputs, **kwargs)
                return self.update_from_gradients(gradient, state), loss
            else:
                gradient = grad(loss_fun)(params, *inputs, **kwargs)
                return self.update_from_gradients(gradient, state)

        return update 
Developer: JuliusKunze, Project: jaxnet, Lines: 13, Source: optimizers.py
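Example 6 switches between grad and value_and_grad depending on whether the caller also wants the loss value; value_and_grad returns the function value and its gradient from a single evaluation. A standalone illustration, independent of the jaxnet optimizer above:

import jax
import jax.numpy as jnp

loss_fun = lambda w, x: jnp.sum((w * x) ** 2)
w = 2.0
x = jnp.array([1.0, 3.0])

loss, g = jax.value_and_grad(loss_fun)(w, x)  # loss and dloss/dw in one pass
g_only = jax.grad(loss_fun)(w, x)             # gradient only
print(loss, g, g_only)  # 40.0 40.0 40.0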

Example 7: update

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def update(i, opt_state, batch):
    params = get_params(opt_state)
    grad_loss = grad(loss)
    g = grad_loss(params, batch)
    return opt_update(i, g, opt_state) 
Developer: sharadmv, Project: deepx, Lines: 7, Source: lstm.py

Example 8: update

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def update(i, opt_state, batch):
    params = optimizers.get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state) 
Developer: sharadmv, Project: deepx, Lines: 5, Source: mnist_classifier.py

Example 9: velocity_verlet

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def velocity_verlet(potential_fn, kinetic_fn):
    r"""
    Second order symplectic integrator that uses the velocity verlet algorithm
    for position `z` and momentum `r`.

    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum.
    :return: a pair of (`init_fn`, `update_fn`).
    """
    def init_fn(z, r, potential_energy=None, z_grad=None):
        """
        :param z: Position of the particle.
        :param r: Momentum of the particle.
        :param potential_energy: Potential energy at `z`.
        :param z_grad: gradient of potential energy at `z`.
        :return: initial state for the integrator.
        """
        if potential_energy is None or z_grad is None:
            potential_energy, z_grad = value_and_grad(potential_fn)(z)
        return IntegratorState(z, r, potential_energy, z_grad)

    def update_fn(step_size, inverse_mass_matrix, state):
        """
        :param float step_size: Size of a single step.
        :param inverse_mass_matrix: Inverse of mass matrix, which is used to
            calculate kinetic energy.
        :param state: Current state of the integrator.
        :return: new state for the integrator.
        """
        z, r, _, z_grad = state
        r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad)  # r(n+1/2)
        r_grad = grad(kinetic_fn, argnums=1)(inverse_mass_matrix, r)
        z = tree_multimap(lambda z, r_grad: z + step_size * r_grad, z, r_grad)  # z(n+1)
        potential_energy, z_grad = value_and_grad(potential_fn)(z)
        r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad)  # r(n+1)
        return IntegratorState(z, r, potential_energy, z_grad)

    return init_fn, update_fn 
Developer: pyro-ppl, Project: numpyro, Lines: 43, Source: hmc_util.py
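For reference, update_fn in Example 9 implements the standard velocity verlet (leapfrog) step with step size $\epsilon$, potential energy $U(z)$ and kinetic energy $K(M^{-1}, r)$, exactly as the three tree_multimap lines above compute it:

$$r_{n+1/2} = r_n - \tfrac{\epsilon}{2}\,\nabla_z U(z_n), \qquad z_{n+1} = z_n + \epsilon\,\nabla_r K(M^{-1}, r_{n+1/2}), \qquad r_{n+1} = r_{n+1/2} - \tfrac{\epsilon}{2}\,\nabla_z U(z_{n+1}).$$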

Example 10: step

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def step(opt_state, optim):
    params = optim.get_params(opt_state)
    g = grad(loss)(params)
    return optim.update(g, opt_state) 
Developer: pyro-ppl, Project: numpyro, Lines: 6, Source: test_optimizers.py

Example 11: test_sample_gradient

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def test_sample_gradient(jax_dist, sp_dist, params):
    if not jax_dist.reparametrized_params:
        pytest.skip('{} not reparametrized.'.format(jax_dist.__name__))

    dist_args = [p.name for p in inspect.signature(jax_dist).parameters.values()]
    params_dict = dict(zip(dist_args[:len(params)], params))
    nonrepara_params_dict = {k: v for k, v in params_dict.items()
                             if k not in jax_dist.reparametrized_params}
    repara_params = tuple(v for k, v in params_dict.items()
                          if k in jax_dist.reparametrized_params)

    rng_key = random.PRNGKey(0)

    def fn(args):
        args_dict = dict(zip(jax_dist.reparametrized_params, args))
        return jnp.sum(jax_dist(**args_dict, **nonrepara_params_dict).sample(key=rng_key))

    actual_grad = jax.grad(fn)(repara_params)
    assert len(actual_grad) == len(repara_params)

    eps = 1e-3
    for i in range(len(repara_params)):
        if repara_params[i] is None:
            continue
        args_lhs = [p if j != i else p - eps for j, p in enumerate(repara_params)]
        args_rhs = [p if j != i else p + eps for j, p in enumerate(repara_params)]
        fn_lhs = fn(args_lhs)
        fn_rhs = fn(args_rhs)
        # finite diff approximation
        expected_grad = (fn_rhs - fn_lhs) / (2. * eps)
        assert jnp.shape(actual_grad[i]) == jnp.shape(repara_params[i])
        assert_allclose(jnp.sum(actual_grad[i]), expected_grad, rtol=0.02) 
Developer: pyro-ppl, Project: numpyro, Lines: 34, Source: test_distributions.py
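The "finite diff approximation" in Example 11 is the central-difference estimate that the jax.grad result is checked against, applied to the i-th reparametrized parameter with $\epsilon = 10^{-3}$:

$$\frac{\partial f}{\partial \theta_i} \approx \frac{f(\theta + \epsilon e_i) - f(\theta - \epsilon e_i)}{2\epsilon}.$$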

Example 12: test_log_prob_LKJCholesky

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def test_log_prob_LKJCholesky(dimension, concentration):
    # We will test against the fact that LKJCorrCholesky can be seen as a
    # TransformedDistribution whose base distribution is a distribution of partial
    # correlations in the C-vine method (modulo an affine transform to change domain
    # from (0, 1) to (-1, 1)) and whose transform is a signed stick-breaking process.
    d = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    beta_sample = d._beta.sample(random.PRNGKey(0))
    beta_log_prob = jnp.sum(d._beta.log_prob(beta_sample))
    partial_correlation = 2 * beta_sample - 1
    affine_logdet = beta_sample.shape[-1] * jnp.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    # compute signed stick breaking logdet
    inv_tanh = lambda t: jnp.log((1 + t) / (1 - t)) / 2  # noqa: E731
    inv_tanh_logdet = jnp.sum(jnp.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    corr_cholesky_logdet = biject_to(constraints.corr_cholesky).log_abs_det_jacobian(
        unconstrained,
        sample,
    )
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = d.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=2e-5)

    assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=1e-7) 
Developer: pyro-ppl, Project: numpyro, Lines: 30, Source: test_distributions.py

Example 13: test_bijective_transforms

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def test_bijective_transforms(transform, event_shape, batch_shape):
    shape = batch_shape + event_shape
    rng_key = random.PRNGKey(0)
    x = biject_to(transform.domain)(random.normal(rng_key, shape))
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), jnp.ones(batch_shape))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain
    assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert jnp.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if len(event_shape) == 1:
            expected = np.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            expected = jnp.log(jnp.abs(grad(transform)(x)))
            inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6) 
Developer: pyro-ppl, Project: numpyro, Lines: 31, Source: test_distributions.py

Example 14: policy_and_value_opt_step

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def policy_and_value_opt_step(i,
                              opt_state,
                              opt_update,
                              get_params,
                              policy_and_value_net_apply,
                              log_probab_actions_old,
                              value_predictions_old,
                              padded_observations,
                              padded_actions,
                              padded_rewards,
                              reward_mask,
                              c1=1.0,
                              c2=0.01,
                              gamma=0.99,
                              lambda_=0.95,
                              epsilon=0.1,
                              rng=None):
  """Policy and Value optimizer step."""

  # Combined loss function given the new params.
  def policy_and_value_loss(params):
    """Returns the combined loss given just parameters."""
    (loss, _, _, _) = combined_loss(
        params,
        log_probab_actions_old,
        value_predictions_old,
        policy_and_value_net_apply,
        padded_observations,
        padded_actions,
        padded_rewards,
        reward_mask,
        c1=c1,
        c2=c2,
        gamma=gamma,
        lambda_=lambda_,
        epsilon=epsilon,
        rng=rng)
    return loss

  new_params = get_params(opt_state)
  g = grad(policy_and_value_loss)(new_params)
  # TODO(afrozm): Maybe clip gradients?
  return opt_update(i, g, opt_state) 
Developer: yyht, Project: BERT, Lines: 45, Source: ppo.py

Example 15: test_reformer_lm_memory

# Required module import: import jax [as alias]
# Or: from jax import grad [as alias]
def test_reformer_lm_memory(self):
    lsh_self_attention = functools.partial(
        tl.LSHSelfAttention,
        attention_dropout=0.0,
        chunk_len=64,
        n_buckets=[128, 128],
        n_chunks_after=0,
        n_chunks_before=1,
        n_hashes=1,
        n_parallel_heads=1,
        predict_drop_len=128,
        predict_mem_len=1024,
    )
    timebin_self_attention = functools.partial(
        tl.SelfAttention,
        attention_dropout=0.05,
        chunk_len=64,
        n_chunks_before=1,
        n_parallel_heads=1,
    )

    model = reformer.ReformerLM(
        vocab_size=256,
        d_model=256,
        d_ff=512,
        d_attention_key=64,
        d_attention_value=64,
        n_layers=6,
        n_heads=2,
        dropout=0.05,
        max_len=1048576,
        attention_type=[timebin_self_attention, lsh_self_attention],
        axial_pos_shape=(1024, 1024),
        d_axial_pos_embs=(64, 192),
        ff_activation=tl.Relu,
        ff_use_sru=0,
        ff_chunk_size=131072,
        mode='train',
    )
    x = np.ones((1, 1048576)).astype(np.int32)
    weights, state = model.init(shapes.signature(x))

    @jax.jit
    def mock_training_step(x, weights, state, rng):
      def compute_mock_loss(weights):
        logits, new_state = model.pure_fn(x, weights, state, rng)
        loss = jnp.mean(logits[..., 0])
        return loss, (new_state, logits)
      gradients, (new_state, logits) = jax.grad(
          compute_mock_loss, has_aux=True)(weights)
      new_weights = fastmath.nested_map_multiarg(
          lambda w, g: w - 1e-4 * g, weights, gradients)
      return new_weights, new_state, logits

    weights, state, logits = mock_training_step(
        x, weights, state, jax.random.PRNGKey(0))
    self.assertEqual(logits.shape, (1, 1048576, 256)) 
Developer: google, Project: trax, Lines: 59, Source: reformer_memory_test.py


Note: The jax.grad method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are taken from open-source projects contributed by their respective authors, and copyright remains with the original authors. Please consult the corresponding project's License before distributing or reusing the code; do not reproduce this article without permission.