当前位置: 首页>>代码示例>>Python>>正文


Python jax.jit方法代码示例

本文整理汇总了Python中jax.jit方法的典型用法代码示例。如果您正苦于以下问题:Python jax.jit方法的具体用法?Python jax.jit怎么用?Python jax.jit使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax的用法示例。


在下文中一共展示了jax.jit方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def __init__(self, obs_dim, *, seed=None):
        """Build the stax network and its initial parameters.

        Args:
            obs_dim (int): size of the observation vector; used as the
                trailing dimension of the input shape given to the
                network initialiser.
            seed (int or None): seed for the PRNG key used to initialise
                the parameters. When None, a seed is drawn from
                `np.random.randint`.
        """
        # TODO: apply jax.jit() to everything in sight
        init_fn, self._net_apply = self.make_stax_model()
        if seed is None:
            # no seed supplied, so pick one at random
            seed = np.random.randint((1 << 63) - 1)
        # -1 stands in for the (unspecified) batch dimension
        input_shape = (-1, obs_dim)
        out_shape, self._net_params = init_fn(jrandom.PRNGKey(seed), input_shape)
        self._net_grads = jax.grad(self._net_apply)
        # the initialiser must report a batch-only output shape
        assert out_shape == (-1,), "got a weird output shape %s" % (out_shape,)
开发者ID:HumanCompatibleAI,项目名称:imitation,代码行数:23,代码来源:tabular_irl.py

示例2: test_external_submodule

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_external_submodule():
    """A module defined outside a parametrized function can be used inside it."""
    layer = Dense(3)

    @parametrized
    def net(inputs):
        return 2 * layer(inputs)

    x = random_inputs((2,))
    weights = net.init_parameters(x, key=PRNGKey(0))
    reference = net.apply(weights, x)
    assert (3,) == reference.shape

    # A second plain application must reproduce the first exactly.
    repeated = net.apply(weights, x)
    assert jnp.array_equal(reference, repeated)

    # The jitted application only needs to agree up to tolerance.
    jitted = net.apply(weights, x, jit=True)
    assert jnp.allclose(reference, jitted)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:19,代码来源:test_core.py

示例3: test_external_submodule2

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_external_submodule2():
    """An externally defined zero-initialised Dense layer yields zero params and output."""
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    x = jnp.zeros((1, 2))
    weights = net.init_parameters(x, key=PRNGKey(0))

    expected_weights = ((jnp.zeros((2, 2)), jnp.zeros(2)),)
    assert_parameters_equal(expected_weights, weights)

    plain = net.apply(weights, x)
    assert jnp.array_equal(jnp.zeros((1, 2)), plain)

    jitted = net.apply(weights, x, jit=True)
    assert jnp.array_equal(plain, jitted)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:19,代码来源:test_core.py

示例4: test_external_sequential_submodule

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_external_sequential_submodule():
    """Parameter shapes and output shape of a nested Sequential model, with and without jit."""
    tail = Sequential(Dense(2), relu)
    model = Sequential(Conv(4, (2, 2)), flatten, relu, Dense(3), relu, Dense(2),
                       tail)
    x = jnp.zeros((1, 5, 5, 2))

    weights = model.init_parameters(x, key=PRNGKey(0))
    assert weights.conv.bias.shape == (4,)
    assert weights.dense0.bias.shape == (3,)
    assert weights.dense1.kernel.shape == (3, 2)
    assert weights.dense1.bias.shape == (2,)
    assert weights.sequential.dense.bias.shape == (2,)

    plain = model.apply(weights, x)
    assert plain.shape == (1, 2)

    jitted = model.apply(weights, x, jit=True)
    assert jnp.allclose(plain, jitted)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:19,代码来源:test_core.py

示例5: test_param_and_submodule_mixed

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_param_and_submodule_mixed():
    """A parametrized function may combine a submodule call with its own parameter."""
    @parametrized
    def linear_map(inputs):
        in_features = inputs.shape[-1]
        kernel = parameter((in_features, 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel)

    @parametrized
    def dense(inputs):
        # Submodule first, then the bias parameter, same order as one expression.
        mapped = linear_map(inputs)
        return mapped + parameter((2,), zeros, 'bias')

    x = jnp.zeros((1, 3))
    weights = dense.init_parameters(x, key=PRNGKey(0))
    assert weights.bias.shape == (2,)
    assert weights.linear_map.kernel.shape == (3, 2)

    plain = dense.apply(weights, x)
    assert jnp.array_equal(jnp.zeros((1, 2)), plain)

    jitted = dense.apply(weights, x, jit=True)
    assert jnp.array_equal(plain, jitted)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:23,代码来源:test_core.py

示例6: test_mixed_up_execution_order

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_mixed_up_execution_order():
    """Parameters are found by name even when the bias is declared before the kernel."""
    @parametrized
    def dense(inputs):
        # Deliberately declare the bias before the kernel.
        bias = parameter((2,), zeros, 'bias')
        in_features = inputs.shape[-1]
        kernel = parameter((in_features, 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel) + bias

    x = jnp.zeros((1, 3))
    weights = dense.init_parameters(x, key=PRNGKey(0))
    assert weights.bias.shape == (2,)
    assert weights.kernel.shape == (3, 2)

    plain = dense.apply(weights, x)
    assert jnp.array_equal(jnp.zeros((1, 2)), plain)

    jitted = dense.apply(weights, x, jit=True)
    assert jnp.array_equal(plain, jitted)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_core.py

示例7: test_parameters_from_subsubmodule

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_parameters_from_subsubmodule():
    """Parameters of a doubly nested layer can be supplied via parameters_from / apply_from."""
    inner = Dense(2)
    middle = Sequential(inner, relu)
    net = Sequential(middle, jnp.sum)
    x = jnp.zeros((1, 3))

    full_params = net.init_parameters(x, key=PRNGKey(0))
    reference = net.apply(full_params, x)

    inner_params = inner.init_parameters(x, key=PRNGKey(0))

    # Rebuild the full parameter tree from the innermost layer's parameters.
    rebuilt = net.parameters_from({inner: inner_params}, x)
    assert_dense_parameters_equal(inner_params, rebuilt[0][0])
    result = net.apply(rebuilt, x)
    assert reference.shape == result.shape

    # Same idea, but applying directly, first eagerly and then jitted.
    result = net.apply_from({inner: inner_params}, x)
    assert reference.shape == result.shape

    result = net.apply_from({inner: inner_params}, x, jit=True)
    assert reference.shape == result.shape
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:22,代码来源:test_core.py

示例8: test_parameters_from_top_level

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_parameters_from_top_level():
    """parameters_from / apply_from work when the reused module is the network itself."""
    net = Dense(2)
    x = jnp.zeros((1, 3))
    weights = net.init_parameters(x, key=PRNGKey(0))
    reference = net.apply(weights, x)

    rebuilt = net.parameters_from({net: weights}, x)
    assert_dense_parameters_equal(weights, rebuilt)
    result = net.apply(rebuilt, x)
    assert jnp.array_equal(reference, result)

    # Applying straight from the mapping, eagerly and then jitted.
    result = net.apply_from({net: weights}, x)
    assert jnp.array_equal(reference, result)

    result = net.apply_from({net: weights}, x, jit=True)
    assert jnp.array_equal(reference, result)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:18,代码来源:test_core.py

示例9: test_jit_args

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_jit_args():
  """The backend's jit wrapper accepts positional and keyword arguments alike."""
  backend = jax_backend.JaxBackend()

  def fun(x, A, y):
    inner = jax.numpy.dot(A, y)
    return jax.numpy.dot(x, inner)

  jitted = backend.jit(fun)
  # Draw order (x, y, A) is kept so the random stream is consumed identically.
  x = jax.numpy.array(np.random.rand(4))
  y = jax.numpy.array(np.random.rand(4))
  A = jax.numpy.array(np.random.rand(4, 4))

  eager = fun(x, A, y)
  positional = jitted(x, A, y)
  keyword = jitted(x, y=y, A=A)
  for candidate in (positional, keyword):
    np.testing.assert_allclose(eager, candidate)
开发者ID:google,项目名称:TensorNetwork,代码行数:18,代码来源:jax_backend_test.py

示例10: test_arnoldi_factorization

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_arnoldi_factorization(dtype):
  """Check the residual of the factorization returned by the Arnoldi routine.

  Asserts that `mat @ Vm - Vm @ Hm - fm[:, None] * em` is approximately
  zero, where `Vm`, `Hm` and `fm` are sliced out of the arrays returned by
  the routine built with `_generate_arnoldi_factorization(jax)`.

  Args:
    dtype: dtype used for the random matrix, the start vector and the
      work arrays `kv` and `H`.
  """
  np.random.seed(10)
  D = 20
  mat = np.random.rand(D, D).astype(dtype)
  x = np.random.rand(D).astype(dtype)

  @jax.tree_util.Partial
  @jax.jit
  def matvec(vector, matrix):
    return matrix @ vector

  arnoldi = _generate_arnoldi_factorization(jax)
  ncv = 40
  kv = jax.numpy.zeros((ncv + 1, D), dtype=dtype)
  H = jax.numpy.zeros((ncv + 1, ncv), dtype=dtype)
  start = 0
  kv, H, it, _ = arnoldi(matvec, [mat], x, kv, H, start, ncv, 0.01)
  Vm = jax.numpy.transpose(kv[:it, :])
  Hm = H[:it, :it]
  fm = kv[it, :] * H[it, it - 1]
  em = np.zeros((1, Vm.shape[1]))
  em[0, -1] = 1
  residual = mat @ Vm - Vm @ Hm - fm[:, None] * em
  # The residual has D rows (mat is D x D, Vm is D x it), so the expected
  # zero array must be (D, it); sizing it (it, it) only matches when it == D.
  expected = np.zeros((D, Vm.shape[1])).astype(dtype)
  np.testing.assert_almost_equal(residual, expected)
开发者ID:google,项目名称:TensorNetwork,代码行数:26,代码来源:jitted_functions_test.py

示例11: step

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def step(self, *args, rng_key=None, **kwargs):
        """Run one jitted update and push the resulting params to the param store.

        On the first call (``self.svi_state`` is None) the state is created
        with ``self.init``; if no ``rng_key`` is given, one is obtained via
        ``numpyro.sample('svi.init', dist.PRNGIdentity())``.

        Args:
            *args: positional arguments forwarded to ``init`` and ``update``.
            rng_key: optional key used only when the state is initialised.
            **kwargs: keyword arguments forwarded to ``init`` and ``update``.

        Returns:
            The loss returned by ``self.update``.

        Raises:
            TypeError: with a backend-specific message when the update fails
                because an argument is not a valid JAX type; any other
                TypeError is re-raised unchanged.
        """
        needs_init = self.svi_state is None
        if needs_init:
            if rng_key is None:
                rng_key = numpyro.sample('svi.init', dist.PRNGIdentity())
            self.svi_state = self.init(rng_key, *args, **kwargs)
        try:
            self.svi_state, loss = jit(self.update)(self.svi_state, *args, **kwargs)
        except TypeError as err:
            # Only translate the "invalid JAX type" failure; pass others through.
            if 'not a valid JAX type' not in str(err):
                raise err
            raise TypeError('NumPyro backend requires args, kwargs to be arrays or tuples, '
                            'dicts of arrays.')
        fetch_params = jit(super(SVI, self).get_params)
        get_param_store().update(fetch_params(self.svi_state))
        return loss
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:18,代码来源:infer.py

示例12: test_dual_averaging

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_dual_averaging(jitted):
    """Ten dual-averaging steps on (x + 1) ** 2 should average close to -1."""
    def optimize(f):
        da_init, da_update = dual_averaging(gamma=0.5)
        state = da_init()
        df = grad(f)
        for _ in range(10):
            # state[0] is the current iterate fed to the gradient
            state = da_update(df(state[0]), state)
        # state[1] is returned as the averaged iterate
        return state[1]

    def f(x):
        return (x + 1) ** 2

    if jitted:
        fn = jit(optimize, static_argnums=(0,))
    else:
        fn = optimize

    assert_allclose(fn(f), -1., atol=1e-3)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:18,代码来源:test_hmc_util.py

示例13: test_jitted_update_fn

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_jitted_update_fn():
    """A jitted ``svi.update`` must yield the same params as the eager one."""
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0,
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0,
                               constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.05), ELBO())
    initial_state = svi.init(random.PRNGKey(1), data)

    # Eager reference: one update step, then read the params back.
    eager_state = svi.update(initial_state, data)[0]
    expected = svi.get_params(eager_state)

    # Same step through jit, passing the data by keyword.
    jitted_state = jit(svi.update)(initial_state, data=data)[0]
    actual = svi.get_params(jitted_state)

    check_close(actual, expected, atol=1e-5)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:23,代码来源:test_svi.py

示例14: test_mask

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_mask(mask_last, use_jit):
    """Compare ``log_density`` of a masked, scaled model with a hand-built value.

    Args:
        mask_last: offset from the end of the mask; the entry at index
            ``-mask_last`` is set to 0 (masked out).
        use_jit: when true, ``log_density`` is called through ``jit`` with
            the model passed as a static argument.
    """
    N = 10
    # Builtin ``bool`` instead of ``np.bool``: that alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    mask = np.ones(N, dtype=bool)
    # NOTE(review): this clears one entry, not a trailing slice; confirm intent.
    mask[-mask_last] = 0

    def model(data, mask):
        with numpyro.plate('N', N):
            x = numpyro.sample('x', dist.Normal(0, 1))
            with handlers.mask(mask_array=mask):
                numpyro.sample('y', dist.Delta(x, log_density=1.))
                with handlers.scale(scale=2):
                    numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (N,))
    x = random.normal(random.PRNGKey(1), (N,))
    if use_jit:
        log_joint = jit(lambda *args: log_density(*args)[0], static_argnums=(0,))(
            model, (data, mask), {}, {'x': x, 'y': x})
    else:
        log_joint = log_density(model, (data, mask), {}, {'x': x, 'y': x})[0]
    log_prob_x = dist.Normal(0, 1).log_prob(x)
    log_prob_y = mask
    log_prob_z = dist.Normal(x, 1).log_prob(data)
    # Masked sites contribute only where mask is true; 'obs' is scaled by 2.
    expected = (log_prob_x + jnp.where(mask, log_prob_y + 2 * log_prob_z, 0.)).sum()
    assert_allclose(log_joint, expected, atol=1e-4)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:27,代码来源:test_handlers.py

示例15: test_numpyrooptim_no_double_jit

# 需要导入模块: import jax [as 别名]
# 或者: from jax import jit [as 别名]
def test_numpyrooptim_no_double_jit(optim_class, args):
    """The Python body of a jitted update wrapper must run only once over three calls."""
    opt = optim_class(*args)
    state = opt.init(jnp.zeros(10))

    my_fn_calls = 0

    @jit
    def my_fn(state, g):
        # Counts executions of the Python body of this function.
        nonlocal my_fn_calls
        my_fn_calls += 1
        return opt.update(g, state)

    for scale in (1., 2., 3.):
        state = my_fn(state, jnp.ones(10)*scale)

    assert my_fn_calls == 1
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:22,代码来源:test_optimizers.py


注:本文中的jax.jit方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。