本文整理汇总了Python中jax.grad函数的典型用法代码示例。如果您正苦于以下问题：Python jax.grad函数的具体用法？Python jax.grad怎么用？Python jax.grad使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该函数所在模块jax的用法示例。
在下文中一共展示了jax.grad函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def __init__(self, obs_dim, *, seed=None):
    """Internal setup for Jax-based reward models.

    Initialises reward model using given seed & input size (`obs_dim`).

    Args:
        obs_dim (int): dimensionality of observation space.
        seed (int or None): random seed for generating initial params. If
            None, seed will be chosen arbitrarily, as in
            LinearRewardModel.

    Raises:
        AssertionError: if the network's init function reports an output
            shape other than ``(-1,)``.
    """
    # TODO: apply jax.jit() to everything in sight
    # make_stax_model() is expected to return a stax-style (init_fn, apply_fn)
    # pair; apply_fn is kept for later forward passes.
    net_init, self._net_apply = self.make_stax_model()
    if seed is None:
        # No seed supplied: draw one from NumPy's global RNG. The dtype is
        # pinned to int64 because the upper bound (2**63 - 1) does not fit the
        # platform-default integer where that is 32-bit (e.g. Windows with
        # NumPy < 2), in which case randint raises ValueError. int() keeps the
        # seed a plain Python int, as the default-dtype call returned before.
        seed = int(np.random.randint((1 << 63) - 1, dtype=np.int64))
    rng = jrandom.PRNGKey(seed)
    # -1 presumably marks the (unspecified) batch dimension of the input;
    # cf. the output-shape check below.
    out_shape, self._net_params = net_init(rng, (-1, obs_dim))
    # Gradient of apply_fn w.r.t. its first positional argument (jax.grad's
    # default argnums=0).
    self._net_grads = jax.grad(self._net_apply)
    # output shape should just be batch dim, nothing else
    assert out_shape == (-1,), "got a weird output shape %s" % (out_shape,)
示例2: test_dual_averaging
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def test_dual_averaging(jitted):
    """Dual averaging should drive ``(x + 1) ** 2`` towards its minimiser -1.

    Ten dual-averaging updates (gamma=0.5) are run from the optimiser's
    initial state, optionally under ``jit`` with the objective marked static,
    and the averaged iterate is compared with -1.
    """
    def objective(x):
        return (x + 1) ** 2

    def run(loss_fn):
        init_fn, update_fn = dual_averaging(gamma=0.5)
        state = init_fn()
        for _ in range(10):
            # state[0] is the point at which the gradient is taken
            step_grad = grad(loss_fn)(state[0])
            state = update_fn(step_grad, state)
        # state[1] is the averaged iterate
        return state[1]

    runner = jit(run, static_argnums=(0,)) if jitted else run
    assert_allclose(runner(objective), -1., atol=1e-3)
示例3: grad
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def grad(*args, **kwargs):
    """Dispatch to the ``"grad"`` entry of the currently active backend.

    ``backend()`` is looked up on every call, and all positional and keyword
    arguments are forwarded unchanged to that backend's implementation.
    """
    active_backend = backend()
    backend_grad = active_backend["grad"]
    return backend_grad(*args, **kwargs)
示例4: test_tensor_distribution
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def test_tensor_distribution(event_inputs, batch_inputs, test_grad):
    """Sampling a ``Tensor`` and re-aggregating the samples should recover it.

    50000 draws over the event inputs are log-sum-exp reduced along the
    sample input ``n`` and compared with the original tensor (atol=0.1).
    With ``test_grad`` set, the test instead differentiates a randomly probed
    difference between the two and expects a gradient close to zero.
    """
    num_samples = 50000
    sample_inputs = OrderedDict(n=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)
    p_data = random_tensor(be_inputs).data
    # The torch backend is given no explicit key; other backends get a fixed one.
    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
    probe = randn(p_data.shape)

    def diff_fn(data):
        """Return (probe-weighted exp difference, re-aggregated tensor)."""
        p = Tensor(data, be_inputs)
        draws = p.sample(sampled_vars, sample_inputs, rng_key=rng_key)
        mq = p.materialize(draws).reduce(ops.logaddexp, 'n').align(tuple(p.inputs))
        _, (aligned_p, aligned_mq) = align_tensors(p, mq)
        assert aligned_p.shape == aligned_mq.shape
        weighted_mq = (ops.exp(aligned_mq) * probe).sum()
        weighted_p = (ops.exp(aligned_p) * probe).sum()
        return weighted_mq - weighted_p, mq

    if not test_grad:
        _, mq = diff_fn(p_data)
        assert_close(mq, Tensor(p_data, be_inputs), atol=0.1, rtol=None)
        return

    if get_backend() == "jax":
        import jax
        diff_grad, _ = jax.grad(diff_fn, has_aux=True)(p_data)
    else:
        import torch
        p_data.requires_grad_(True)
        diff_grad = torch.autograd.grad(diff_fn(p_data)[0], [p_data])[0]
    assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=0.1, rtol=None)
示例5: test_JAX
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def test_JAX(self):
# importing inside the gpu-only test because these packages can't be
# imported on the CPU image since they are not present there.
from jax import grad, jit
grad_tanh = grad(self.tanh)
ag = grad_tanh(1.0)
self.assertEqual(0.4199743, ag)
示例6: _update_fun
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def _update_fun(self, loss_fun, return_loss=False):
    """Build an ``update(state, *inputs, **kwargs)`` step function.

    The returned closure extracts the parameters from ``state`` via
    ``self.get_parameters``, differentiates ``loss_fun`` with respect to its
    first argument (the parameters) and passes the gradient to
    ``self.update_from_gradients``. With ``return_loss`` the closure returns
    ``(new_state, loss)``; otherwise just the new state.
    """
    def update(state, *inputs, **kwargs):
        params = self.get_parameters(state)
        if not return_loss:
            grads = grad(loss_fun)(params, *inputs, **kwargs)
            return self.update_from_gradients(grads, state)
        loss_value, grads = value_and_grad(loss_fun)(params, *inputs, **kwargs)
        new_state = self.update_from_gradients(grads, state)
        return new_state, loss_value

    return update
示例7: update
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def update(i, opt_state, batch):
    """Apply one optimizer step using the gradient of ``loss`` on ``batch``.

    ``loss`` is differentiated w.r.t. its first argument, the parameters
    extracted from ``opt_state`` by ``get_params``.
    """
    current_params = get_params(opt_state)
    loss_grad_fn = grad(loss)
    step_grads = loss_grad_fn(current_params, batch)
    return opt_update(i, step_grads, opt_state)
示例8: update
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def update(i, opt_state, batch):
    """Take one optimizer step with the gradient of ``loss`` at the current params."""
    current_params = optimizers.get_params(opt_state)
    step_grads = grad(loss)(current_params, batch)
    return opt_update(i, step_grads, opt_state)
示例9: velocity_verlet
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def velocity_verlet(potential_fn, kinetic_fn):
    r"""
    Second order symplectic integrator that uses the velocity verlet algorithm
    for position `z` and momentum `r`.

    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum.
    :return: a pair of (`init_fn`, `update_fn`).
    """
    def init_fn(z, r, potential_energy=None, z_grad=None):
        """
        :param z: Position of the particle.
        :param r: Momentum of the particle.
        :param potential_energy: Potential energy at `z`.
        :param z_grad: gradient of potential energy at `z`.
        :return: initial state for the integrator.
        """
        # Reuse caller-supplied energy/gradient only when both are present.
        have_both = potential_energy is not None and z_grad is not None
        if not have_both:
            potential_energy, z_grad = value_and_grad(potential_fn)(z)
        return IntegratorState(z, r, potential_energy, z_grad)

    def update_fn(step_size, inverse_mass_matrix, state):
        """
        :param float step_size: Size of a single step.
        :param inverse_mass_matrix: Inverse of mass matrix, which is used to
            calculate kinetic energy.
        :param state: Current state of the integrator.
        :return: new state for the integrator.
        """
        def half_kick(momentum, potential_grad):
            # Move every momentum leaf half a step against the potential gradient.
            return tree_multimap(lambda p, g: p - 0.5 * step_size * g,
                                 momentum, potential_grad)

        z, r, _, z_grad = state
        r = half_kick(r, z_grad)  # r(n+1/2)
        # Gradient of the kinetic energy w.r.t. the momentum (argument 1).
        r_grad = grad(kinetic_fn, argnums=1)(inverse_mass_matrix, r)
        z = tree_multimap(lambda q, v: q + step_size * v, z, r_grad)  # z(n+1)
        potential_energy, z_grad = value_and_grad(potential_fn)(z)
        r = half_kick(r, z_grad)  # r(n+1)
        return IntegratorState(z, r, potential_energy, z_grad)

    return init_fn, update_fn
示例10: step
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def step(opt_state, optim):
    """Take one ``optim`` step using the gradient of ``loss`` at the current params."""
    current_params = optim.get_params(opt_state)
    return optim.update(grad(loss)(current_params), opt_state)
示例11: test_sample_gradient
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def test_sample_gradient(jax_dist, sp_dist, params):
    """Compare autodiff sample gradients with central finite differences.

    For every reparametrized parameter of ``jax_dist``, the gradient of the
    summed sample (drawn with a fixed PRNG key) computed by ``jax.grad`` is
    checked against a central finite-difference estimate with step 1e-3.
    Distributions without reparametrized parameters are skipped.
    """
    if not jax_dist.reparametrized_params:
        pytest.skip('{} not reparametrized.'.format(jax_dist.__name__))
    dist_args = [p.name for p in inspect.signature(jax_dist).parameters.values()]
    params_dict = dict(zip(dist_args[:len(params)], params))
    nonrepara_params_dict = {k: v for k, v in params_dict.items()
                             if k not in jax_dist.reparametrized_params}
    # Collect the reparametrized names and their values in the SAME order, so
    # that ``fn`` re-pairs each value with its own name. Zipping the values
    # against ``jax_dist.reparametrized_params`` instead would misassign them
    # whenever that attribute orders names differently from the constructor
    # signature, or lists a name that was not supplied in ``params``.
    repara_names = tuple(k for k in params_dict
                         if k in jax_dist.reparametrized_params)
    repara_params = tuple(params_dict[k] for k in repara_names)
    rng_key = random.PRNGKey(0)

    def fn(args):
        args_dict = dict(zip(repara_names, args))
        return jnp.sum(jax_dist(**args_dict, **nonrepara_params_dict).sample(key=rng_key))

    actual_grad = jax.grad(fn)(repara_params)
    assert len(actual_grad) == len(repara_params)

    eps = 1e-3
    for i, param in enumerate(repara_params):
        if param is None:
            continue
        args_lhs = [p if j != i else p - eps for j, p in enumerate(repara_params)]
        args_rhs = [p if j != i else p + eps for j, p in enumerate(repara_params)]
        # central finite-difference approximation
        expected_grad = (fn(args_rhs) - fn(args_lhs)) / (2. * eps)
        assert jnp.shape(actual_grad[i]) == jnp.shape(param)
        assert_allclose(jnp.sum(actual_grad[i]), expected_grad, rtol=0.02)
示例12: test_log_prob_LKJCholesky
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def test_log_prob_LKJCholesky(dimension, concentration):
    """Check LKJCholesky.log_prob against a change-of-variables computation.

    LKJCholesky (C-vine sampling) can be viewed as a transformed distribution:
    a base distribution of partial correlations (Beta samples affinely mapped
    from (0, 1) to (-1, 1)) pushed through a signed stick-breaking transform.
    The expected log density is the base log density minus the log-Jacobians
    of the affine map and of the stick-breaking transform.
    """
    lkj = dist.LKJCholesky(dimension, concentration, sample_method="cvine")
    beta_draw = lkj._beta.sample(random.PRNGKey(0))
    beta_log_prob = jnp.sum(lkj._beta.log_prob(beta_draw))

    # affine map x -> 2x - 1 contributes log(2) per coordinate
    partial_correlation = 2 * beta_draw - 1
    affine_logdet = beta_draw.shape[-1] * jnp.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    def inv_tanh(t):
        return jnp.log((1 + t) / (1 - t)) / 2

    # signed stick-breaking logdet = corr_cholesky bijector logdet
    # + log of the elementwise derivative of inv_tanh
    inv_tanh_logdet = jnp.sum(jnp.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    corr_cholesky_logdet = biject_to(constraints.corr_cholesky).log_abs_det_jacobian(
        unconstrained,
        sample,
    )
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = lkj.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=2e-5)
    assert_allclose(jax.jit(lkj.log_prob)(sample), lkj.log_prob(sample), atol=1e-7)
assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=1e-7)
示例13: test_bijective_transforms
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def test_bijective_transforms(transform, event_shape, batch_shape):
    """Check codomain/domain membership, inversion and log|det J| of a transform.

    The log-Jacobian is additionally compared with an autodiff reference when
    the full sample shape has exactly ``transform.event_dim`` axes.
    """
    shape = batch_shape + event_shape
    x = biject_to(transform.domain)(random.normal(random.PRNGKey(0), shape))
    y = transform(x)
    # y must lie in the codomain
    assert_array_equal(transform.codomain(y), jnp.ones(batch_shape))
    # the inverse must recover x, and land back in the domain
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)
    assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

    actual = transform.log_abs_det_jacobian(x, y)
    assert jnp.shape(actual) == batch_shape
    if len(shape) != transform.event_dim:
        return

    if len(event_shape) == 1:
        # vector-valued transform: log|det| of the full Jacobian matrix
        expected = np.linalg.slogdet(jax.jacobian(transform)(x))[1]
        inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
    else:
        # scalar-valued transform (grad needs a scalar output): log|f'(x)|
        expected = jnp.log(jnp.abs(grad(transform)(x)))
        inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))
    assert_allclose(actual, expected, atol=1e-6)
    assert_allclose(actual, -inv_expected, atol=1e-6)
示例14: policy_and_value_opt_step
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def policy_and_value_opt_step(i,
                              opt_state,
                              opt_update,
                              get_params,
                              policy_and_value_net_apply,
                              log_probab_actions_old,
                              value_predictions_old,
                              padded_observations,
                              padded_actions,
                              padded_rewards,
                              reward_mask,
                              c1=1.0,
                              c2=0.01,
                              gamma=0.99,
                              lambda_=0.95,
                              epsilon=0.1,
                              rng=None):
    """Policy and Value optimizer step.

    Differentiates the combined loss (the first of the four values returned
    by ``combined_loss``) with respect to the parameters held in
    ``opt_state``, then applies ``opt_update`` with that gradient.

    Returns:
      Whatever ``opt_update(i, gradient, opt_state)`` returns.
    """
    def loss_of_params(params):
        """Scalar combined loss as a function of the parameters alone."""
        combined, _, _, _ = combined_loss(
            params,
            log_probab_actions_old,
            value_predictions_old,
            policy_and_value_net_apply,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            c1=c1,
            c2=c2,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon,
            rng=rng,
        )
        return combined

    current_params = get_params(opt_state)
    gradient = grad(loss_of_params)(current_params)
    # TODO(afrozm): Maybe clip gradients?
    return opt_update(i, gradient, opt_state)
示例15: test_reformer_lm_memory
# 需要导入模块: import jax [as 别名]
# 或者: from jax import grad [as 别名]
def test_reformer_lm_memory(self):
    """Run one jitted mock training step of a ~1M-token ReformerLM.

    Builds a 6-layer ReformerLM alternating plain and LSH self-attention over
    a 1048576-token sequence, takes one gradient step on a mock loss (mean of
    the first logit channel) and checks the logits shape.
    """
    seq_len = 1048576  # == 1024 * 1024, the product of axial_pos_shape below
    vocab_size = 256

    lsh_self_attention = functools.partial(
        tl.LSHSelfAttention,
        attention_dropout=0.0,
        chunk_len=64,
        n_buckets=[128, 128],
        n_chunks_after=0,
        n_chunks_before=1,
        n_hashes=1,
        n_parallel_heads=1,
        predict_drop_len=128,
        predict_mem_len=1024,
    )
    timebin_self_attention = functools.partial(
        tl.SelfAttention,
        attention_dropout=0.05,
        chunk_len=64,
        n_chunks_before=1,
        n_parallel_heads=1,
    )
    model = reformer.ReformerLM(
        vocab_size=vocab_size,
        d_model=256,
        d_ff=512,
        d_attention_key=64,
        d_attention_value=64,
        n_layers=6,
        n_heads=2,
        dropout=0.05,
        max_len=seq_len,
        attention_type=[timebin_self_attention, lsh_self_attention],
        axial_pos_shape=(1024, 1024),
        d_axial_pos_embs=(64, 192),
        ff_activation=tl.Relu,
        ff_use_sru=0,
        ff_chunk_size=131072,
        mode='train',
    )
    tokens = np.ones((1, seq_len)).astype(np.int32)
    weights, state = model.init(shapes.signature(tokens))

    @jax.jit
    def mock_training_step(x, weights, state, rng):
        def compute_mock_loss(w):
            logits, new_state = model.pure_fn(x, w, state, rng)
            return jnp.mean(logits[..., 0]), (new_state, logits)

        gradients, (new_state, logits) = jax.grad(
            compute_mock_loss, has_aux=True)(weights)
        # plain SGD step with learning rate 1e-4
        new_weights = fastmath.nested_map_multiarg(
            lambda w, g: w - 1e-4 * g, weights, gradients)
        return new_weights, new_state, logits

    _, _, logits = mock_training_step(
        tokens, weights, state, jax.random.PRNGKey(0))
    self.assertEqual(logits.shape, (1, seq_len, vocab_size))