当前位置: 首页>>代码示例>>Python>>正文


Python jax.vmap方法代码示例

本文整理汇总了 Python 中 jax.vmap 方法的典型用法代码示例。如果您想了解 jax.vmap 的具体用法、调用方式或使用示例，下面精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在模块 jax 的其他用法示例。


下文共展示了 jax.vmap 方法的 15 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐出更好的 Python 代码示例。

示例1: test_Batched

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def test_Batched():
    """`Batched` must match a hand-written `vmap` of the unbatched model."""
    out_dim = 1
    batch_size = 4
    unbatched_input = jnp.ones(2)
    batched_input = jnp.ones((batch_size, 2))

    @parametrized
    def unbatched_dense(input):
        kernel = parameter((out_dim, input.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        return jnp.dot(kernel, input) + bias

    # Single example: ones-kernel . ones-input (= 2) plus ones-bias (= 1) gives 3.
    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, unbatched_input)
    assert out == jnp.array([3.])

    # Reference result: map over the input only, sharing the parameters (in_axes None).
    manual_batched = vmap(unbatched_dense.apply, (None, 0))(unbatched_params, batched_input)
    expected = jnp.stack([out for _ in range(batch_size)])
    assert jnp.array_equal(expected, manual_batched)

    # Batched must create the same parameters and reproduce the reference outputs.
    dense = Batched(unbatched_dense)
    params = dense.init_parameters(batched_input, key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    assert jnp.array_equal(manual_batched, dense.apply(params, batched_input))
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:26,代码来源:test_modules.py

示例2: _predictive

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def _predictive(rng_key, model, posterior_samples, num_samples, return_sites=None,
                parallel=True, model_args=(), model_kwargs=None):
    """Run ``model`` once per posterior draw and collect the requested site values.

    :param rng_key: PRNG key; it is split into ``num_samples`` keys, one per draw.
    :param model: Python callable containing NumPyro primitives.
    :param posterior_samples: samples keyed by site name. Each run substitutes one
        slice along the leading axis into the model.
    :param int num_samples: number of keys to split ``rng_key`` into.
    :param return_sites: which sites to return.

        - ``None``: sample sites not present in the posterior samples, plus all
          deterministic sites.
        - ``''``: every site whose type is not ``'plate'``.
        - otherwise: a collection of site names to keep.
    :param bool parallel: if True, vectorize with ``vmap``; otherwise run the draws
        one after another with ``lax.map``.
    :param tuple model_args: positional arguments for the model.
    :param dict model_kwargs: keyword arguments for the model (``None`` means none).
    :return: dict mapping site name to values stacked along a new leading axis.
    """
    # ``None`` instead of a shared mutable ``{}`` default; callers passing nothing
    # or an explicit dict behave exactly as before.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    rng_keys = random.split(rng_key, num_samples)

    def single_prediction(val):
        rng_key, samples = val
        model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
            *model_args, **model_kwargs)
        if return_sites is None:
            # Default: what the posterior does not already contain, plus deterministics.
            sites = {name for name, site in model_trace.items()
                     if (site['type'] == 'sample' and name not in samples)
                     or site['type'] == 'deterministic'}
        elif return_sites == '':
            # Empty string is a sentinel meaning "all non-plate sites".
            sites = {name for name, site in model_trace.items() if site['type'] != 'plate'}
        else:
            sites = return_sites
        return {name: site['value'] for name, site in model_trace.items() if name in sites}

    if parallel:
        return vmap(single_prediction)((rng_keys, posterior_samples))
    return lax.map(single_prediction, (rng_keys, posterior_samples))
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:util.py

示例3: log_likelihood

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def log_likelihood(model, posterior_samples, *args, **kwargs):
    """
    (EXPERIMENTAL INTERFACE) Compute, for every posterior draw, the log likelihood
    of each observed sample site of ``model``.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: posterior samples keyed by site name; the leading
        axis is mapped over with ``vmap``.
    :param args: positional arguments for the model.
    :param kwargs: keyword arguments for the model.
    :return: dict mapping each observed site name to its per-draw log likelihood.
    """

    def _observed_log_probs(samples):
        sites = trace(substitute(model, samples)).get_trace(*args, **kwargs)
        observed = {}
        for name, site in sites.items():
            if site['type'] == 'sample' and site['is_observed']:
                observed[name] = site['fn'].log_prob(site['value'])
        return observed

    batched = vmap(_observed_log_probs)
    return batched(posterior_samples)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:20,代码来源:util.py

示例4: _multinomial

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def _multinomial(key, p, n, n_max, shape=()):
    """Sample multinomial counts by drawing ``n_max`` categorical indices and tallying them.

    :param key: PRNG key passed to ``categorical``.
    :param p: event probabilities; the last axis indexes the categories.
    :param n: total count(s); broadcast against ``p.shape[:-1]`` when the shapes differ.
    :param n_max: number of categorical draws per batch element. It is used as an
        array shape, so it must be a concrete Python int. It presumably must be
        >= every entry of ``n`` (TODO confirm with callers).
    :param tuple shape: batch shape of the draws; defaults to ``p.shape[:-1]``
        (after broadcasting).
    :return: counts of shape ``shape + p.shape[-1:]``.
    """
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then scatter-add them into per-category counts
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        # mask[j, ...] is 1 for the first n draws and 0 for the remaining n_max - n draws
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked-out draws become index 0 below and are counted in category 0. `excess`
        # removes those n_max - n spurious counts from category 0 only.
        # NOTE(review): jnp.zeros defaults to a float dtype here, so the returned counts
        # may be promoted to float in this branch — confirm this is intended.
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        # NOTE(review): with a scalar n nothing is masked, so all n_max draws are counted
        # and each row sums to n_max; correct only if callers pass n_max == n here.
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
    # Per batch row: start from zero counts and add one at each drawn index
    # (presumably what _scatter_add_one does — verify against its definition).
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
                                                             dtype=indices.dtype),
                                                   jnp.expand_dims(indices_2D, axis=-1),
                                                   jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:25,代码来源:util.py

示例5: test_block_neural_arn

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def test_block_neural_arn(input_dim, hidden_factors, residual, batch_shape):
    """The BNAF network must be shape-preserving with a lower-triangular Jacobian."""
    arn_init, arn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors, residual)

    expected_shape = batch_shape + (input_dim,)
    out_shape, init_params = arn_init(random.PRNGKey(0), expected_shape)
    assert out_shape == expected_shape

    x = random.normal(random.PRNGKey(1), expected_shape)
    output, logdet = arn(init_params, x)
    assert output.shape == expected_shape
    assert logdet.shape == expected_shape

    def forward(z):
        return arn(init_params, z)[0]

    # One Jacobian per batch element when inputs are batched along a single axis.
    jacobian_fn = jacfwd(forward)
    if len(batch_shape) == 1:
        jacobian_fn = vmap(jacobian_fn)
    jac = jacobian_fn(x)
    # The reported log-determinant terms must sum to the true log|det J|.
    assert_allclose(logdet.sum(-1), jnp.linalg.slogdet(jac)[1], rtol=1e-6)

    # Autoregressive structure: nothing above the diagonal...
    assert np.sum(np.abs(np.triu(jac, k=1))) == 0.0
    # ...and (presumably) every lower-triangular entry, diagonal included, is non-zero.
    assert np.all(np.abs(matrix_to_tril_vec(jac)) > 0)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:test_nn.py

示例6: test_binop_batch_rule

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def test_binop_batch_rule(prim):
    """`prim` under vmap must match elementwise calls for every batching pattern."""
    bx = jnp.array([1., 2., 3.])
    by = jnp.array([2., 3., 4.])
    x = jnp.array(1.)
    y = jnp.array(2.)

    # both operands batched
    both_batched = vmap(lambda a, b: prim(a, b))(bx, by)
    for i, (a, b) in enumerate(zip(bx, by)):
        assert_allclose(both_batched[i], prim(a, b))

    # only the second operand batched
    y_batched = vmap(lambda b: prim(x, b))(by)
    for i, b in enumerate(by):
        assert_allclose(y_batched[i], prim(x, b))

    # only the first operand batched
    x_batched = vmap(lambda a: prim(a, y))(bx)
    for i, a in enumerate(bx):
        assert_allclose(x_batched[i], prim(a, y))
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:19,代码来源:test_distributions_util.py

示例7: grads

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def grads(self, inputs):
        """Return per-example gradients of the network as a 2-D array.

        Args:
            inputs: batch of inputs; ``jax.vmap`` maps ``self._net_grads`` over the
                leading axis with ``self._net_params`` held fixed.

        Returns:
            ``self._flatten_batch`` of the per-example gradients, converted with
            ``np.asarray``. It has one row per input example.
        """
        import functools

        # ``jax.partial`` was only an alias of ``functools.partial`` and has been
        # removed from JAX; bind the current parameters with the stdlib version.
        per_example_grad = functools.partial(self._net_grads, self._net_params)
        rich_grads = jax.vmap(per_example_grad)(inputs)
        flat_grads = np.asarray(self._flatten_batch(rich_grads))
        # Internal sanity check: exactly one flattened gradient row per input.
        assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
        return flat_grads
开发者ID:HumanCompatibleAI,项目名称:imitation,代码行数:9,代码来源:tabular_irl.py

示例8: Batched

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def Batched(unbatched_model: parametrized, batch_dim=0):
    """Wrap a parametrized single-example model so that it applies to a batch via ``vmap``.

    The parameters are created once, by initializing ``unbatched_model`` on the first
    example of each argument. They are then shared across the whole batch, because
    ``vmap`` maps only the arguments and not the closed-over ``params``.

    :param unbatched_model: parametrized model that processes one example.
    :param batch_dim: batch axis of the inputs, forwarded to ``vmap`` as ``in_axes``.
        For a non-integer value (e.g. a per-argument tuple) the init example is still
        taken along axis 0, as before.
    :return: a parametrized model taking batched positional arguments.
    """
    import jax.numpy as jnp

    def first_example(x):
        # ``x[0]`` slices axis 0 only. For any other integer batch axis, take the
        # first entry along that axis so the unbatched model is initialized with
        # correctly shaped (unbatched) inputs that match what vmap will feed it.
        if isinstance(batch_dim, int) and batch_dim != 0:
            return jnp.take(x, 0, axis=batch_dim)
        return x[0]

    @parametrized
    def batched(*batched_args):
        args = tree_map(first_example, batched_args)
        params = Parameter(lambda key: unbatched_model.init_parameters(*args, key=key), 'model')()
        batched_apply = vmap(partial(unbatched_model.apply, params), batch_dim)
        return batched_apply(*batched_args)

    return batched
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:11,代码来源:modules.py

示例9: main

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def main(args):
    """Run inference for the BNN, predict at ``X_test`` and save the plot to ``bnn_plot.pdf``."""
    N, D_X, D_H = args.num_data, 3, args.num_hidden
    X, Y, X_test = get_data(N=N, D_X=D_X)

    # do inference
    rng_key, rng_key_predict = random.split(random.PRNGKey(0))
    samples = run_inference(model, args, rng_key, X, Y, D_H)

    # predict Y_test at inputs X_test: vmap over the samples with one PRNG key each
    # (num_samples * num_chains keys; presumably one per posterior sample)
    vmap_args = (samples, random.split(rng_key_predict, args.num_samples * args.num_chains))
    predictions = vmap(lambda samples, rng_key: predict(model, rng_key, samples, X_test, D_H))(*vmap_args)
    predictions = predictions[..., 0]

    # compute mean prediction and the 90% interval (5th-95th percentiles) across samples
    mean_prediction = jnp.mean(predictions, axis=0)
    percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

    # make plots
    fig, ax = plt.subplots(1, 1)

    # plot training data
    ax.plot(X[:, 1], Y[:, 0], 'kx')
    # plot 90% confidence level of predictions
    ax.fill_between(X_test[:, 1], percentiles[0, :], percentiles[1, :], color='lightblue')
    # plot mean prediction
    ax.plot(X_test[:, 1], mean_prediction, 'blue', ls='solid', lw=2.0)
    ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

    # Lay out before saving: calling tight_layout() after savefig() never affects
    # the file already written to disk.
    fig.tight_layout()
    fig.savefig('bnn_plot.pdf')
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:32,代码来源:bnn.py

示例10: analyze_dimension

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def analyze_dimension(samples, X, Y, dimension, hypers):
    """Posterior mean and standard deviation of the coefficient for ``dimension``.

    Each posterior draw yields a (mean, variance) pair from
    ``compute_singleton_mean_variance``; ``gaussian_mixture_stats`` combines them.
    """
    draws = [samples[name] for name in ('msq', 'lambda', 'eta1', 'xisq', 'var_obs')]
    c = hypers['c']

    def per_draw(msq, lam, eta1, xisq, var_obs):
        return compute_singleton_mean_variance(X, Y, dimension, msq, lam,
                                               eta1, xisq, c, var_obs)

    mus, variances = vmap(per_draw)(*draws)
    mean, variance = gaussian_mixture_stats(mus, variances)
    return mean, jnp.sqrt(variance)


# Helper function for analyzing the posterior statistics for coefficient theta_ij 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:13,代码来源:sparse_regression.py

示例11: analyze_pair_of_dimensions

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def analyze_pair_of_dimensions(samples, X, Y, dim1, dim2, hypers):
    """Posterior mean and standard deviation of the pairwise term for ``(dim1, dim2)``.

    Each posterior draw yields a (mean, variance) pair from
    ``compute_pairwise_mean_variance``; ``gaussian_mixture_stats`` combines them.
    """
    draws = [samples[name] for name in ('msq', 'lambda', 'eta1', 'xisq', 'var_obs')]
    c = hypers['c']

    def per_draw(msq, lam, eta1, xisq, var_obs):
        return compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam,
                                              eta1, xisq, c, var_obs)

    mus, variances = vmap(per_draw)(*draws)
    mean, variance = gaussian_mixture_stats(mus, variances)
    return mean, jnp.sqrt(variance)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:10,代码来源:sparse_regression.py

示例12: main

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def main(args):
    """Run inference for the GP, predict at ``X_test`` and save the plot to ``gp_plot.pdf``."""
    X, Y, X_test = get_data(N=args.num_data)

    # do inference
    rng_key, rng_key_predict = random.split(random.PRNGKey(0))
    samples = run_inference(model, args, rng_key, X, Y)

    # do prediction: one `predict` call per draw of the kernel hyperparameters,
    # each with its own PRNG key (num_samples * num_chains keys)
    vmap_args = (random.split(rng_key_predict, args.num_samples * args.num_chains), samples['kernel_var'],
                 samples['kernel_length'], samples['kernel_noise'])
    means, predictions = vmap(lambda rng_key, var, length, noise:
                              predict(rng_key, X, Y, X_test, var, length, noise))(*vmap_args)

    # average the per-draw means; 90% interval from the 5th/95th percentiles of predictions
    mean_prediction = np.mean(means, axis=0)
    percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

    # make plots
    fig, ax = plt.subplots(1, 1)

    # plot training data
    ax.plot(X, Y, 'kx')
    # plot 90% confidence level of predictions
    ax.fill_between(X_test, percentiles[0, :], percentiles[1, :], color='lightblue')
    # plot mean prediction
    ax.plot(X_test, mean_prediction, 'blue', ls='solid', lw=2.0)
    ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

    # Lay out before saving: calling tight_layout() after savefig() never affects
    # the file already written to disk.
    fig.tight_layout()
    fig.savefig("gp_plot.pdf")
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:31,代码来源:gp.py

示例13: init

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}):
        """Create the initial HMC sampler state via ``self._init_fn``.

        Accepts either a single PRNG key or a stack of keys. With a stack, initialization
        is vectorized with ``vmap`` (one initialization per key), and ``self._sample_fn``
        is replaced by a vmapped version of itself.

        :param rng_key: a single PRNG key (``ndim == 1``) or a stack of keys (``ndim > 1``).
        :param int num_warmup: forwarded to ``self._init_fn``.
        :param init_params: optional initial parameter values. They are passed through
            ``self._init_state``, whose return value replaces them.
        :param tuple model_args: forwarded to ``self._init_state`` and ``self._init_fn``.
        :param dict model_kwargs: forwarded to ``self._init_state`` and ``self._init_fn``.
            NOTE(review): mutable ``{}`` default; not mutated here, but confirm the
            callees do not mutate it either.
        :return: the state from ``self._init_fn``, batched along axis 0 when
            ``rng_key`` is a stack of keys.
        :raises ValueError: if ``self._potential_fn`` is set and ``init_params`` is
            still ``None`` after ``self._init_state``.
        """
        # non-vectorized
        if rng_key.ndim == 1:
            rng_key, rng_key_init_model = random.split(rng_key)
        # vectorized
        else:
            # vmap(random.split) yields (batch, 2, ...); swap to (2, batch, ...) so the
            # unpacking gives one batch of keys for sampling and one for model init.
            rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)
        init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
        if self._potential_fn and init_params is None:
            raise ValueError('Valid value of `init_params` must be provided with'
                             ' `potential_fn`.')

        # Bind every sampler setting so only (init_params, rng_key) vary; this is the
        # signature that vmap maps over in the vectorized branch below.
        hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
            init_params,
            num_warmup=num_warmup,
            step_size=self._step_size,
            adapt_step_size=self._adapt_step_size,
            adapt_mass_matrix=self._adapt_mass_matrix,
            dense_mass=self._dense_mass,
            target_accept_prob=self._target_accept_prob,
            trajectory_length=self._trajectory_length,
            max_tree_depth=self._max_tree_depth,
            find_heuristic_step_size=self._find_heuristic_step_size,
            model_args=model_args,
            model_kwargs=model_kwargs,
            rng_key=rng_key,
        )
        if rng_key.ndim == 1:
            init_state = hmc_init_fn(init_params, rng_key)
        else:
            # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
            # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
            # wa_steps because those variables do not depend on traced args: init_params, rng_key.
            init_state = vmap(hmc_init_fn)(init_params, rng_key)
            # Map the first argument of the sample function; the other two are shared.
            # NOTE(review): this wraps the current self._sample_fn, so calling init() again
            # with batched keys would vmap an already-vmapped function — confirm init()
            # is only called once per instance.
            sample_fn = vmap(self._sample_fn, in_axes=(0, None, None))
            self._sample_fn = sample_fn
        return init_state
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:39,代码来源:mcmc.py

示例14: parametric_draws

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def parametric_draws(subposteriors, num_draws, diagonal=False, rng_key=None):
    """
    Combine subposteriors with the (embarrassingly parallel) parametric Monte Carlo
    algorithm and sample from the resulting Gaussian approximation.

    **References:**

    1. *Asymptotically Exact, Embarrassingly Parallel MCMC*,
       Willie Neiswanger, Chong Wang, Eric Xing

    :param list subposteriors: list whose elements are each a collection of samples.
    :param int num_draws: how many draws to take from the merged posterior.
    :param bool diagonal: if True, use per-coordinate variances (independent Normal);
        otherwise use a full covariance (MultivariateNormal). Defaults to `False`.
    :param jax.random.PRNGKey rng_key: randomness source; `jax.random.PRNGKey(0)` when omitted.
    :return: `num_draws` samples structured like a single subposterior.
    """
    if rng_key is None:
        rng_key = random.PRNGKey(0)
    sample_shape = (num_draws,)
    if not diagonal:
        mean, cov = parametric(subposteriors, diagonal=False)
        samples_flat = dist.MultivariateNormal(mean, cov).sample(rng_key, sample_shape)
    else:
        mean, var = parametric(subposteriors, diagonal=True)
        samples_flat = dist.Normal(mean, jnp.sqrt(var)).sample(rng_key, sample_shape)

    # The first draw of the first subposterior serves only as a structural template,
    # so that each flat sample vector can be mapped back to that pytree layout.
    template = tree_map(lambda leaf: leaf[0], subposteriors[0])
    unravel_fn = ravel_pytree(template)[1]
    return vmap(unravel_fn)(samples_flat)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:28,代码来源:hmc_util.py

示例15: loss

# 需要导入模块: import jax [as 别名]
# 或者: from jax import vmap [as 别名]
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
        """
        Estimate the negative ELBO by averaging over ``self.num_particles`` particles.

        :param jax.random.PRNGKey rng_key: seed for the random number generator.
        :param dict param_map: current parameter values, keyed by site name.
        :param model: Python callable with NumPyro primitives for the model.
        :param guide: Python callable with NumPyro primitives for the guide.
        :param args: positional arguments passed to both model and guide (they may
            change between fitting steps).
        :param kwargs: keyword arguments passed to both model and guide (they may
            change between fitting steps).
        :return: the negated Evidence Lower Bound (ELBO), suitable for minimization.
        """
        def particle_elbo(key):
            model_key, guide_key = random.split(key)
            seeded_model = seed(model, model_key)
            seeded_guide = seed(guide, guide_key)
            guide_lp, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
            # Score the model at the latent values the guide just sampled.
            model_lp, _ = log_density(replay(seeded_model, guide_trace), args, kwargs, param_map)
            # single-sample estimate of log p(z) - log q(z)
            return model_lp - guide_lp

        # Optimizers minimize, while the ELBO is to be maximized, hence the negation.
        if self.num_particles != 1:
            particle_keys = random.split(rng_key, self.num_particles)
            return -jnp.mean(vmap(particle_elbo)(particle_keys))
        return -particle_elbo(rng_key)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:36,代码来源:elbo.py


注:本文中的jax.vmap方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。