当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.stack方法代码示例

本文整理汇总了Python中jax.numpy.stack方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.stack方法的具体用法？Python jax.numpy.stack怎么用？Python jax.numpy.stack使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。


在下文中一共展示了jax.numpy.stack方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_get_proposal_loc_and_scale

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def test_get_proposal_loc_and_scale(dense_mass):
    """Compare ``_get_proposal_loc_and_scale`` with a leave-one-out reference.

    The helper is given the first ``N - 1`` samples (plus their mean and
    scale) together with the last sample.  For each ``i < N - 1`` the
    expected location/scale is recomputed from all ``N`` samples with
    sample ``i`` deleted.
    """
    num_samples, dim = 10, 3
    samples = random.normal(random.PRNGKey(0), (num_samples, dim))
    history, newest = samples[:-1], samples[-1]

    def _scale_of(x):
        # Dense mass: Cholesky factor of the biased covariance.
        # Diagonal mass: per-coordinate standard deviation.
        if dense_mass:
            return jnp.linalg.cholesky(jnp.cov(x, rowvar=False, bias=True))
        return jnp.std(x, 0)

    actual_loc, actual_scale = _get_proposal_loc_and_scale(
        history, jnp.mean(history, 0), _scale_of(history), newest)

    leave_one_out = [np.delete(samples, i, axis=0) for i in range(num_samples - 1)]
    expected_loc = jnp.stack([jnp.mean(subset, 0) for subset in leave_one_out])
    expected_scale = jnp.stack([_scale_of(subset) for subset in leave_one_out])

    assert_allclose(actual_loc, expected_loc, rtol=1e-4)
    assert_allclose(actual_scale, expected_scale, atol=1e-6, rtol=0.05)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:test_mcmc.py

示例2: _stack

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def _stack(dim, *x):
    return np.stack(x, axis=dim) 
开发者ID:pyro-ppl,项目名称:funsor,代码行数:4,代码来源:ops.py

示例3: serialize

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def serialize(self, data):
    """Encode every element of ``data`` as base-``self._vocab_size`` digits.

    Values are rescaled to [0, 1] with ``self._space.low`` /
    ``self._space.high`` and clipped.  Each value is then expanded into
    ``self._precision`` digits, most significant first.  The result is
    reshaped to ``(batch_size, -1)``, where ``batch_size`` is
    ``data.shape[0]``.
    """
    batch_size = data.shape[0]
    span = self._space.high - self._space.low
    remainder = np.clip((data - self._space.low) / span, 0, 1)
    digits = []
    for place in range(1, self._precision + 1):
        threshold = self._vocab_size ** -place
        digit = np.array(remainder / threshold).astype(np.int32)
        # A value exactly equal to ``high`` would otherwise produce the
        # out-of-range digit ``vocab_size``; fold it into the top digit.
        digit = np.where(digit == self._vocab_size, digit - 1, digit)
        digits.append(digit)
        # In-place update on purpose: keeps the remainder's dtype under numpy.
        remainder -= digit * threshold
    return np.reshape(np.stack(digits, axis=-1), (batch_size, -1))
开发者ID:google,项目名称:trax,代码行数:17,代码来源:space_serializer.py

示例4: conditional_params_from_outputs

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def conditional_params_from_outputs(image, theta):
    """
    Split the model output ``theta`` into the parameters of a mixture of
    ``nr_mix`` logistics.  The green and blue means are conditioned on the
    preceding channels of ``image``.

    With input shapes

    image.shape == (h, w, c)
    theta.shape == (h, w, 10 * nr_mix)

    the returned arrays have shapes

    means.shape == inv_scales.shape == (nr_mix, h, w, c)
    logit_probs.shape == (nr_mix, h, w)

    The reshape and the explicit channel indices 0-2 below require ``c == 3``.
    """
    assert theta.shape[2] % 10 == 0
    nr_mix = theta.shape[2] // 10

    # The first nr_mix values per pixel are mixture logits.  The remaining
    # 9 * nr_mix hold means, log-scales and coupling coefficients.
    logit_probs, rest = jnp.split(theta, [nr_mix], axis=-1)
    logit_probs = jnp.moveaxis(logit_probs, -1, 0)
    rest = jnp.reshape(rest, image.shape + (3 * nr_mix,))
    rest = jnp.moveaxis(rest, -1, 0)
    base_means, log_scales, raw_coeffs = jnp.split(rest, 3)
    coeffs = jnp.tanh(raw_coeffs)

    # Red is unconditioned.  Green depends on red; blue depends on red and green.
    red, green = image[..., 0], image[..., 1]
    means = jnp.stack(
        (
            base_means[..., 0],
            base_means[..., 1] + coeffs[..., 0] * red,
            (base_means[..., 2] + coeffs[..., 1] * red) + coeffs[..., 2] * green,
        ),
        axis=-1,
    )
    return means, softplus(log_scales), logit_probs
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:31,代码来源:pixelcnn.py

示例5: stack

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def stack(self, values, axis=0, name='stack'):
    """Stack ``values`` along a new axis at position ``axis``.

    ``name`` is accepted for interface compatibility and is not used here.
    """
    # ``stack`` takes the keyword ``axis``.  The previous ``dim=`` keyword
    # (a PyTorch-ism) raised TypeError on every call.
    return np.stack(values, axis=axis)
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例6: pack

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def pack(self, *args, **kwargs):
    """Alias for ``self.stack``; forwards every argument unchanged."""
    stack_fn = self.stack
    return stack_fn(*args, **kwargs)
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例7: stack

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def stack(self, sequence, axis=0):
    """Stack ``sequence`` along the leading axis.

    Only ``axis == 0`` is supported; any other value raises RuntimeError.
    """
    if axis != 0:
        raise RuntimeError('stack axis!=0')
    return np.stack(sequence)
开发者ID:scikit-hep,项目名称:pyhf,代码行数:6,代码来源:jax_backend.py

示例8: simulate_data

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def simulate_data(rng_key, num_categories, num_words, num_supervised_data, num_unsupervised_data):
    """Simulate hidden-Markov-model data with randomly drawn parameters.

    Transition rows are drawn from a Dirichlet with all-ones concentration.
    Emission rows are drawn from a Dirichlet with concentration 0.1 per word.
    A chain of ``num_supervised_data + num_unsupervised_data`` steps is then
    sampled.  The hidden category is re-drawn from a uniform start
    distribution at step 0 and again at step ``num_supervised_data``, so the
    unsupervised segment starts a fresh chain.

    :return: ``(transition_prior, emission_prior, transition_prob,
        emission_prob, supervised_categories, supervised_words,
        unsupervised_words)``.
    """
    rng_key, key_trans, key_emit = random.split(rng_key, 3)

    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)

    transition_prob = dist.Dirichlet(transition_prior).sample(
        key=key_trans, sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(
        key=key_emit, sample_shape=(num_categories,))

    start_prob = jnp.repeat(1. / num_categories, num_categories)
    restart_steps = (0, num_supervised_data)
    category_seq, word_seq = [], []
    for step in range(num_supervised_data + num_unsupervised_data):
        rng_key, key_trans, key_emit = random.split(rng_key, 3)
        if step in restart_steps:
            state_probs = start_prob
        else:
            state_probs = transition_prob[category]
        category = dist.Categorical(state_probs).sample(key=key_trans)
        word = dist.Categorical(emission_prob[category]).sample(key=key_emit)
        category_seq.append(category)
        word_seq.append(word)

    # Split the stacked sequences into supervised and unsupervised parts.
    all_categories = jnp.stack(category_seq)
    all_words = jnp.stack(word_seq)
    return (transition_prior, emission_prior, transition_prob, emission_prob,
            all_categories[:num_supervised_data],
            all_words[:num_supervised_data],
            all_words[num_supervised_data:])
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:32,代码来源:hmm.py

示例9: dz_dt

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def dz_dt(z, t, theta):
    """
    Right-hand side of the Lotka–Volterra equations for two interacting
    species, with populations ``u = z[0]`` and ``v = z[1]``.

    The real positive parameters ``alpha, beta, gamma, delta`` are read from
    the last axis of ``theta``.  ``t`` is unused; it is presumably required
    by the ODE solver's calling convention.
    """
    u, v = z[0], z[1]
    alpha, beta, gamma, delta = (theta[..., k] for k in range(4))
    return jnp.stack([
        (alpha - beta * v) * u,    # du/dt
        (-gamma + delta * u) * v,  # dv/dt
    ])
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:13,代码来源:ode.py

示例10: _laxmap

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def _laxmap(f, xs):
    """Apply ``f`` once per index ``i`` of ``xs`` in a Python loop, then stack.

    ``n`` is the leading dimension of the first leaf of the pytree ``xs``.
    For each ``i in range(n)``, ``f`` is called on
    ``_get_value_from_index(xs, i)`` (presumably the ``i``-th slice of every
    leaf).  The per-call outputs are stacked leaf-wise into a single pytree.
    Assumes ``n >= 1``; with no outputs there is nothing to stack.
    """
    # ``tree_multimap`` was removed from JAX; ``tree_map`` accepts several trees.
    from jax.tree_util import tree_map

    n = tree_flatten(xs)[0][0].shape[0]
    # Wrap the indexer with ``jit`` once instead of on every iteration.
    get_value = jit(_get_value_from_index)
    ys = [f(get_value(xs, i)) for i in range(n)]
    return tree_map(lambda *args: jnp.stack(args), *ys)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:11,代码来源:mcmc.py

示例11: parametric

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def parametric(subposteriors, diagonal=False):
    """
    Merges subposteriors following (embarrassingly parallel) parametric Monte Carlo algorithm.

    Each subposterior is summarised by its sample mean and its unbiased
    sample (co)variance.  The merged precision is the sum of the
    subposterior precisions.  The merged mean is the precision-weighted
    average of the submeans.

    **References:**

    1. *Asymptotically Exact, Embarrassingly Parallel MCMC*,
       Willie Neiswanger, Chong Wang, Eric Xing

    :param list subposteriors: a list in which each element is a collection of samples
        (a pytree whose leaves presumably share a leading sample axis).
    :param bool diagonal: whether to compute weights using variance or covariance, defaults to
        `False` (using covariance).
    :return: the estimated mean and variance/covariance parameters of the joined posterior.
        With ``diagonal=False`` the covariance is a 2-D ``(dim, dim)`` array, also when ``dim == 1``.
    """
    # ``tree_multimap`` was removed from JAX; ``tree_map`` accepts several trees.
    from jax.tree_util import tree_map

    # Stack the subposteriors leaf-wise, then flatten every sample into a
    # vector, so that joined[i, j] is the raveled j-th sample of subposterior i.
    joined_subposteriors = tree_map(lambda *args: jnp.stack(args), *subposteriors)
    joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(joined_subposteriors)

    submeans = jnp.mean(joined_subposteriors, axis=1)
    if diagonal:
        # Per-coordinate precision of each subposterior.
        weights = vmap(lambda x: 1 / jnp.var(x, ddof=1, axis=0))(joined_subposteriors)
        var = 1 / jnp.sum(weights, axis=0)
        normalized_weights = var * weights

        # comparing to consensus implementation, we compute weighted mean here
        mean = jnp.einsum('ij,ij->j', normalized_weights, submeans)
        return mean, var
    else:
        # jnp.cov squeezes a single-variable result to 0-d.  atleast_2d restores
        # the (1, 1) matrix so that the inverse and matmul stay well-defined.
        weights = vmap(lambda x: jnp.linalg.inv(jnp.atleast_2d(jnp.cov(x.T))))(joined_subposteriors)
        cov = jnp.linalg.inv(jnp.sum(weights, axis=0))
        normalized_weights = jnp.matmul(cov, weights)

        # comparing to consensus implementation, we compute weighted mean here
        mean = jnp.einsum('ijk,ik->j', normalized_weights, submeans)
        return mean, cov
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:36,代码来源:hmc_util.py

示例12: log_prob

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def log_prob(self, value):
    """Evaluate ``self._dirichlet.log_prob`` on the pair ``(value, 1 - value)``.

    The pair is stacked along a new last axis.
    """
    complement = 1. - value
    pair = jnp.stack([value, complement], axis=-1)
    return self._dirichlet.log_prob(pair)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:4,代码来源:continuous.py

示例13: kinetic_fn

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def kinetic_fn(m_inv, p):
    """Return ``0.5 * dot(m_inv, z ** 2)``, where ``z`` stacks ``p['x']`` and ``p['y']``."""
    momentum = jnp.stack([p['x'], p['y']], axis=-1)
    squared = momentum ** 2
    return 0.5 * jnp.dot(m_inv, squared)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:5,代码来源:test_hmc_util.py

示例14: test_seed

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def test_seed():
    """Seeding in a Python loop and seeding under ``vmap`` must give the same draws."""
    def _sample():
        draws = [numpyro.sample('x', dist.Normal(0., 1.)),
                 numpyro.sample('y', dist.Normal(1., 2.))]
        return jnp.stack(draws)

    def _sample_with_seed(seed):
        with handlers.seed(rng_seed=seed):
            return _sample()

    looped = jnp.stack([_sample_with_seed(seed) for seed in range(100)])
    vmapped = vmap(lambda rng_key: handlers.seed(lambda: _sample(), rng_key)())(jnp.arange(100))
    assert_allclose(looped, vmapped, atol=1e-6)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:16,代码来源:test_handlers.py

示例15: test_nested_seeding

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def test_nested_seeding():
    """An inner ``handlers.seed`` context takes precedence over an outer one."""
    def fn(rng_key_1, rng_key_2, rng_key_3):
        draws = []
        with handlers.seed(rng_seed=rng_key_1), handlers.seed(rng_seed=rng_key_2):
            draws.append(numpyro.sample('x', dist.Normal(0., 1.)))
            with handlers.seed(rng_seed=rng_key_3):
                draws.append(numpyro.sample('y', dist.Normal(0., 1.)))
        return jnp.stack(draws)

    # Changing only the outermost seed leaves both draws unchanged.
    first, second = fn(0, 1, 2), fn(3, 1, 2)
    assert_allclose(first, second)

    # Changing the innermost seed changes 'y' but not 'x'.
    first, second = fn(0, 1, 2), fn(3, 1, 4)
    assert_allclose(first[0], second[0])
    assert_raises(AssertionError, assert_allclose, first[1], second[1])
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:17,代码来源:test_handlers.py


注:本文中的jax.numpy.stack方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。