本文整理汇总了Python中jax.jit方法的典型用法代码示例。如果您正苦于以下问题：Python jax.jit方法的具体用法？Python jax.jit怎么用？Python jax.jit使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax的用法示例。
在下文中一共展示了jax.jit方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def __init__(self, obs_dim, *, seed=None):
    """Internal setup for Jax-based reward models.

    Initialises the reward model's network parameters using the given seed
    and input size (`obs_dim`).

    Args:
        obs_dim (int): dimensionality of observation space.
        seed (int or None): random seed for generating initial params. If
            None, a seed is drawn from NumPy's global RNG, as in
            LinearRewardModel.

    Sets ``self._net_apply``, ``self._net_params`` and ``self._net_grads``.
    """
    # TODO: apply jax.jit() to everything in sight
    net_init, self._net_apply = self.make_stax_model()
    if seed is None:
        # No seed supplied: draw one from NumPy's global RNG. The explicit
        # int64 dtype matters because randint's default dtype is a C long,
        # which is 32 bits on Windows (NumPy < 2) and cannot represent
        # (1 << 63) - 1. Convert back to a plain Python int, which is what
        # the default-dtype call returns on platforms where it works.
        seed = int(np.random.randint((1 << 63) - 1, dtype=np.int64))
    rng = jrandom.PRNGKey(seed)
    # -1 is used as a placeholder for the (unknown) batch dimension.
    out_shape, self._net_params = net_init(rng, (-1, obs_dim))
    # jax.grad differentiates w.r.t. the first argument (the params).
    # NOTE(review): jax.grad requires a scalar output, while the network
    # output is per-batch; callers presumably reduce first — confirm.
    self._net_grads = jax.grad(self._net_apply)
    # output shape should just be batch dim, nothing else
    assert out_shape == (-1,), "got a weird output shape %s" % (out_shape,)
示例2: test_external_submodule
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_external_submodule():
    """A Dense layer created outside a @parametrized function works inside it.

    Checks the output shape, that a second eager application with the same
    parameters is bit-identical, and that the jitted path agrees approximately.
    """
    layer = Dense(3)

    @parametrized
    def net(inputs):
        return 2 * layer(inputs)

    x = random_inputs((2,))
    params = net.init_parameters(x, key=PRNGKey(0))

    reference = net.apply(params, x)
    assert reference.shape == (3,)

    # Re-applying eagerly with the same parameters must reproduce exactly.
    assert jnp.array_equal(reference, net.apply(params, x))
    # The jitted result is compared with allclose rather than exact equality,
    # since compilation may change floating-point rounding.
    assert jnp.allclose(reference, net.apply(params, x, jit=True))
示例3: test_external_submodule2
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_external_submodule2():
    """An external zero-initialised Dense layer gives zero params and outputs."""
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    x = jnp.zeros((1, 2))
    params = net.init_parameters(x, key=PRNGKey(0))

    # One submodule holding a (2, 2) array and a (2,) array, all zeros.
    expected_params = ((jnp.zeros((2, 2)), jnp.zeros(2)),)
    assert_parameters_equal(expected_params, params)

    eager = net.apply(params, x)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager)
    # The jitted application must match the eager one exactly.
    assert jnp.array_equal(eager, net.apply(params, x, jit=True))
示例4: test_external_sequential_submodule
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_external_sequential_submodule():
    """Parameters of an externally built, nested Sequential are reachable by name."""
    layer = Sequential(
        Conv(4, (2, 2)),
        flatten,
        relu,
        Dense(3),
        relu,
        Dense(2),
        Sequential(Dense(2), relu),
    )
    x = jnp.zeros((1, 5, 5, 2))
    params = layer.init_parameters(x, key=PRNGKey(0))

    # Repeated layer types are addressed with numeric suffixes (dense0,
    # dense1); the inner Sequential exposes its own Dense as `dense`.
    assert params.conv.bias.shape == (4,)
    assert params.dense0.bias.shape == (3,)
    assert params.dense1.kernel.shape == (3, 2)
    assert params.dense1.bias.shape == (2,)
    assert params.sequential.dense.bias.shape == (2,)

    eager = layer.apply(params, x)
    assert eager.shape == (1, 2)
    assert jnp.allclose(eager, layer.apply(params, x, jit=True))
示例5: test_param_and_submodule_mixed
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_param_and_submodule_mixed():
    """A module may own a direct parameter alongside a parametrized submodule."""
    @parametrized
    def linear_map(inputs):
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel)

    @parametrized
    def dense(inputs):
        # The submodule is invoked before the bias parameter is declared.
        mapped = linear_map(inputs)
        return mapped + parameter((2,), zeros, 'bias')

    x = jnp.zeros((1, 3))
    params = dense.init_parameters(x, key=PRNGKey(0))
    assert params.bias.shape == (2,)
    assert params.linear_map.kernel.shape == (3, 2)

    eager = dense.apply(params, x)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager)
    assert jnp.array_equal(eager, dense.apply(params, x, jit=True))
示例6: test_mixed_up_execution_order
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_mixed_up_execution_order():
    """Parameters declared out of their usage order are still named correctly."""
    @parametrized
    def dense(inputs):
        # Deliberately declare the bias before the kernel.
        bias = parameter((2,), zeros, 'bias')
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        projected = jnp.dot(inputs, kernel)
        return projected + bias

    x = jnp.zeros((1, 3))
    params = dense.init_parameters(x, key=PRNGKey(0))
    assert params.bias.shape == (2,)
    assert params.kernel.shape == (3, 2)

    eager = dense.apply(params, x)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager)
    assert jnp.array_equal(eager, dense.apply(params, x, jit=True))
示例7: test_parameters_from_subsubmodule
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_parameters_from_subsubmodule():
    """Parameters can be supplied for a module nested two levels deep."""
    subsublayer = Dense(2)
    sublayer = Sequential(subsublayer, relu)
    net = Sequential(sublayer, jnp.sum)
    x = jnp.zeros((1, 3))

    params = net.init_parameters(x, key=PRNGKey(0))
    reference = net.apply(params, x)

    injected = subsublayer.init_parameters(x, key=PRNGKey(0))
    rebuilt = net.parameters_from({subsublayer: injected}, x)
    # The injected Dense parameters land at rebuilt[0][0].
    assert_dense_parameters_equal(injected, rebuilt[0][0])

    assert net.apply(rebuilt, x).shape == reference.shape
    assert net.apply_from({subsublayer: injected}, x).shape == reference.shape
    assert (net.apply_from({subsublayer: injected}, x, jit=True).shape
            == reference.shape)
示例8: test_parameters_from_top_level
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_parameters_from_top_level():
    """Supplying parameters for the top-level module itself round-trips exactly."""
    net = Dense(2)
    x = jnp.zeros((1, 3))

    params = net.init_parameters(x, key=PRNGKey(0))
    reference = net.apply(params, x)

    rebuilt = net.parameters_from({net: params}, x)
    assert_dense_parameters_equal(params, rebuilt)

    assert jnp.array_equal(reference, net.apply(rebuilt, x))
    assert jnp.array_equal(reference, net.apply_from({net: params}, x))
    assert jnp.array_equal(reference,
                           net.apply_from({net: params}, x, jit=True))
示例9: test_jit_args
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_jit_args():
    """backend.jit matches eager evaluation for positional and keyword calls."""
    backend = jax_backend.JaxBackend()

    def fun(x, A, y):
        return jax.numpy.dot(x, jax.numpy.dot(A, y))

    compiled = backend.jit(fun)

    # Draw order (x, then y, then A) matters for the global NumPy RNG stream.
    x = jax.numpy.array(np.random.rand(4))
    y = jax.numpy.array(np.random.rand(4))
    A = jax.numpy.array(np.random.rand(4, 4))

    expected = fun(x, A, y)
    positional = compiled(x, A, y)
    # Keyword arguments passed out of declaration order.
    keyword = compiled(x, y=y, A=A)
    for result in (positional, keyword):
        np.testing.assert_allclose(expected, result)
示例10: test_arnoldi_factorization
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_arnoldi_factorization(dtype):
    """Check the Arnoldi relation  A V_m - V_m H_m = f_m e_m^T  for the jax backend.

    The residual has the shape of ``Vm`` (D x it). The expected zero matrix
    is therefore built with ``Vm.shape``. The previous ``(it, it)`` shape only
    matched when the factorization ran for exactly ``D`` steps.
    """
    np.random.seed(10)
    D = 20
    mat = np.random.rand(D, D).astype(dtype)
    x = np.random.rand(D).astype(dtype)

    @jax.tree_util.Partial
    @jax.jit
    def matvec(vector, matrix):
        return matrix @ vector

    arnoldi = _generate_arnoldi_factorization(jax)
    ncv = 40
    kv = jax.numpy.zeros((ncv + 1, D), dtype=dtype)
    H = jax.numpy.zeros((ncv + 1, ncv), dtype=dtype)
    start = 0
    # NOTE(review): presumably `it` is the number of completed Arnoldi steps
    # and 0.01 a breakdown tolerance — confirm against the arnoldi helper.
    kv, H, it, _ = arnoldi(matvec, [mat], x, kv, H, start, ncv, 0.01)
    Vm = jax.numpy.transpose(kv[:it, :])  # shape (D, it)
    Hm = H[:it, :it]
    fm = kv[it, :] * H[it, it - 1]  # shape (D,)
    em = np.zeros((1, Vm.shape[1]))  # last unit row vector e_m^T
    em[0, -1] = 1
    residual = mat @ Vm - Vm @ Hm - fm[:, None] * em
    np.testing.assert_almost_equal(residual, np.zeros(Vm.shape, dtype=dtype))
示例11: step
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def step(self, *args, rng_key=None, **kwargs):
    """Run one SVI update and publish the resulting params to the param store.

    On the first call (``self.svi_state is None``) the SVI state is
    initialised with ``self.init``, using ``rng_key`` when given.

    Returns:
        The loss produced by the jitted ``self.update``.

    Raises:
        TypeError: if the arguments are not valid JAX types. The original
            error is chained as ``__cause__``. Any other ``TypeError`` is
            re-raised unchanged.
    """
    if self.svi_state is None:
        if rng_key is None:
            # NOTE(review): presumably draws an init key from an enclosing
            # numpyro seed handler — confirm.
            rng_key = numpyro.sample('svi.init', dist.PRNGIdentity())
        self.svi_state = self.init(rng_key, *args, **kwargs)
    try:
        self.svi_state, loss = jit(self.update)(self.svi_state, *args, **kwargs)
    except TypeError as e:
        if 'not a valid JAX type' in str(e):
            # Translate JAX's message into a backend-specific hint, keeping
            # the underlying error as the explicit cause.
            raise TypeError('NumPyro backend requires args, kwargs to be arrays or tuples, '
                            'dicts of arrays.') from e
        # Bare raise preserves the original traceback untouched.
        raise
    params = jit(super(SVI, self).get_params)(self.svi_state)
    get_param_store().update(params)
    return loss
示例12: test_dual_averaging
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_dual_averaging(jitted):
    """Dual averaging drives the averaged iterate of (x + 1)**2 towards -1."""
    def optimize(f):
        da_init, da_update = dual_averaging(gamma=0.5)
        state = da_init()
        for _ in range(10):
            # state[0] is the current iterate; step along the gradient there.
            state = da_update(grad(f)(state[0]), state)
        # state[1] is the averaged iterate.
        return state[1]

    def f(x):
        return (x + 1) ** 2

    if jitted:
        # `f` is a Python callable, so it must be a static argument.
        fn = jit(optimize, static_argnums=(0,))
    else:
        fn = optimize
    assert_allclose(fn(f), -1., atol=1e-3)
示例13: test_jitted_update_fn
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_jitted_update_fn():
    """A jitted SVI.update step yields the same params as the eager one."""
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        positive = constraints.positive
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.05), ELBO())
    initial_state = svi.init(random.PRNGKey(1), data)

    # Element 0 of update's result is the new SVI state.
    eager_state = svi.update(initial_state, data)[0]
    expected = svi.get_params(eager_state)
    jitted_state = jit(svi.update)(initial_state, data=data)[0]
    actual = svi.get_params(jitted_state)
    check_close(actual, expected, atol=1e-5)
示例14: test_mask
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_mask(mask_last, use_jit):
    """log_density honours handlers.mask and handlers.scale (optionally under jit).

    Sites 'y' and 'obs' sit inside the mask context, and 'obs' is also
    scaled by 2. The expected joint is therefore
    sum(log p(x) + where(mask, log p(y) + 2 * log p(obs), 0)).
    """
    N = 10
    # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # `bool` is the same dtype.
    mask = np.ones(N, dtype=bool)
    # NOTE(review): this zeroes the single element at index -mask_last, not the
    # last `mask_last` elements. The expected value below uses the same mask,
    # so the test is self-consistent — confirm intent.
    mask[-mask_last] = 0

    def model(data, mask):
        with numpyro.plate('N', N):
            x = numpyro.sample('x', dist.Normal(0, 1))
            with handlers.mask(mask_array=mask):
                numpyro.sample('y', dist.Delta(x, log_density=1.))
                with handlers.scale(scale=2):
                    numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (N,))
    x = random.normal(random.PRNGKey(1), (N,))
    if use_jit:
        # The model callable is not an array, so it must be static under jit.
        log_joint = jit(lambda *args: log_density(*args)[0], static_argnums=(0,))(
            model, (data, mask), {}, {'x': x, 'y': x})
    else:
        log_joint = log_density(model, (data, mask), {}, {'x': x, 'y': x})[0]
    log_prob_x = dist.Normal(0, 1).log_prob(x)
    # The Delta site was given log_density=1., so each unmasked 'y' adds 1.
    log_prob_y = mask
    log_prob_z = dist.Normal(x, 1).log_prob(data)
    expected = (log_prob_x + jnp.where(mask, log_prob_y + 2 * log_prob_z, 0.)).sum()
    assert_allclose(log_joint, expected, atol=1e-4)
示例15: test_numpyrooptim_no_double_jit
# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_numpyrooptim_no_double_jit(optim_class, args):
    """Calling opt.update inside a user-jitted function traces it only once."""
    opt = optim_class(*args)
    state = opt.init(jnp.zeros(10))
    my_fn_calls = 0

    @jit
    def my_fn(state, g):
        nonlocal my_fn_calls
        # The Python body only runs while tracing, so this counts traces.
        my_fn_calls += 1
        return opt.update(g, state)

    for scale in (1., 2., 3.):
        state = my_fn(state, jnp.ones(10) * scale)
    assert my_fn_calls == 1