本文整理汇总了Python中jax.numpy.stack方法的典型用法代码示例。如果您正苦于以下问题:Python jax.numpy.stack方法的具体用法?Python jax.numpy.stack怎么用?Python jax.numpy.stack使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。
在下文中一共展示了jax.numpy.stack方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_get_proposal_loc_and_scale
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def test_get_proposal_loc_and_scale(dense_mass):
    """Check ``_get_proposal_loc_and_scale`` against a leave-one-out reference.

    The reference is built by deleting each of the first ``N - 1`` rows from
    the full sample set in turn and recomputing the mean and the scale of what
    remains.

    :param bool dense_mass: when true the scale is the Cholesky factor of the
        (biased) covariance; otherwise it is the per-dimension standard
        deviation.
    """
    N, dim = 10, 3

    def _scale_of(batch):
        # Same scale definition is used for the inputs and for the reference.
        if dense_mass:
            return jnp.linalg.cholesky(jnp.cov(batch, rowvar=False, bias=True))
        return jnp.std(batch, 0)

    samples = random.normal(random.PRNGKey(0), (N, dim))
    head, new_sample = samples[:-1], samples[-1]
    loc = jnp.mean(head, 0)
    scale = _scale_of(head)
    actual_loc, actual_scale = _get_proposal_loc_and_scale(head, loc, scale, new_sample)

    # Leave-one-out statistics: drop row i from the *full* sample set.
    leave_one_out = [np.delete(samples, i, axis=0) for i in range(N - 1)]
    expected_loc = jnp.stack([jnp.mean(subset, 0) for subset in leave_one_out])
    expected_scale = jnp.stack([_scale_of(subset) for subset in leave_one_out])

    assert_allclose(actual_loc, expected_loc, rtol=1e-4)
    assert_allclose(actual_scale, expected_scale, atol=1e-6, rtol=0.05)
示例2: _stack
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def _stack(dim, *x):
    """Stack the positional arrays ``x`` along a new axis at position ``dim``."""
    arrays = x
    return np.stack(arrays, axis=dim)
示例3: serialize
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def serialize(self, data):
    """Encode a batch of values as fixed-length digit sequences.

    Each value is rescaled to ``[0, 1]`` using ``self._space.low`` and
    ``self._space.high``, then expanded into ``self._precision`` digits in
    base ``self._vocab_size``, most significant digit first.

    :param data: array whose leading dimension is the batch dimension.
    :return: int32 array of shape ``(batch_size, -1)`` holding the digits.
    """
    batch_size = data.shape[0]
    low, high = self._space.low, self._space.high
    base = self._vocab_size

    # Normalise into the unit interval; out-of-range inputs are clipped.
    remainder = np.clip((data - low) / (high - low), 0, 1)

    digits = []
    for exponent in range(-1, -self._precision - 1, -1):
        threshold = base ** exponent
        digit = np.array(remainder / threshold).astype(np.int32)
        # For the corner case of x == high: the quotient equals the base, so
        # fold it back onto the largest valid digit.
        digit = np.where(digit == base, digit - 1, digit)
        digits.append(digit)
        # Strip the part already encoded; kept in place so the dtype of the
        # running remainder does not change between iterations.
        remainder -= digit * threshold

    return np.reshape(np.stack(digits, axis=-1), (batch_size, -1))
示例4: conditional_params_from_outputs
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def conditional_params_from_outputs(image, theta):
    """
    Maps image and model output theta to conditional parameters for a mixture
    of nr_mix logistics. If the input shapes are

        image.shape == (h, w, c)
        theta.shape == (h, w, 10 * nr_mix)

    the output shapes will be

        means.shape == inv_scales.shape == (nr_mix, h, w, c)
        logit_probs.shape == (nr_mix, h, w)
    """
    num_outputs = theta.shape[2]
    assert num_outputs % 10 == 0
    nr_mix = num_outputs // 10

    # The first nr_mix channels are the mixture logits; the remainder holds
    # the per-component parameters.
    logit_probs, params = jnp.split(theta, [nr_mix], axis=-1)
    logit_probs = jnp.moveaxis(logit_probs, -1, 0)

    params = jnp.reshape(params, image.shape + (3 * nr_mix,))
    params = jnp.moveaxis(params, -1, 0)
    unconditioned_means, log_scales, coeffs = jnp.split(params, 3)
    coeffs = jnp.tanh(coeffs)

    # now condition the means for the last 2 channels
    red, green = image[..., 0], image[..., 1]
    mean_red = unconditioned_means[..., 0]
    mean_green = unconditioned_means[..., 1] + coeffs[..., 0] * red
    mean_blue = (unconditioned_means[..., 2] + coeffs[..., 1] * red
                 + coeffs[..., 2] * green)

    means = jnp.stack((mean_red, mean_green, mean_blue), axis=-1)
    inv_scales = softplus(log_scales)
    return means, inv_scales, logit_probs
示例5: stack
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def stack(self, values, axis=0, name='stack'):
    """Stack a sequence of arrays along a new axis.

    :param values: sequence of equally shaped arrays.
    :param int axis: position of the new axis in the result, defaults to 0.
    :param str name: accepted for signature compatibility; not used.
    :return: the stacked array.
    """
    # ``np.stack`` selects the new axis through the ``axis`` keyword; it has
    # no ``dim`` keyword, so passing one raises ``TypeError``.
    return np.stack(values, axis=axis)
示例6: pack
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def pack(self, *args, **kwargs):
    """Alias for ``self.stack``: forwards every argument unchanged."""
    stacked = self.stack(*args, **kwargs)
    return stacked
示例7: stack
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def stack(self, sequence, axis=0):
    """Stack ``sequence`` along a new leading axis.

    :param sequence: sequence of equally shaped arrays.
    :param int axis: must be 0; any other value is rejected.
    :raises RuntimeError: if ``axis`` is not 0.
    """
    # Guard clause: only stacking along the leading axis is supported here.
    if axis != 0:
        raise RuntimeError('stack axis!=0')
    return np.stack(sequence)
示例8: simulate_data
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def simulate_data(rng_key, num_categories, num_words, num_supervised_data, num_unsupervised_data):
    """Draw model parameters and a synthetic category/word sequence.

    Transition and emission probability rows are sampled from Dirichlet
    priors, then categories and words are generated step by step. The chain
    is restarted from the uniform start distribution at ``t == 0`` and at
    ``t == num_supervised_data``; otherwise the next category is drawn from
    the transition row of the previous category.

    :return: tuple ``(transition_prior, emission_prior, transition_prob,
        emission_prob, supervised_categories, supervised_words,
        unsupervised_words)``.
    """
    rng_key, key_trans, key_emit = random.split(rng_key, 3)

    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)
    per_category = (num_categories,)
    transition_prob = dist.Dirichlet(transition_prior).sample(
        key=key_trans, sample_shape=per_category)
    emission_prob = dist.Dirichlet(emission_prior).sample(
        key=key_emit, sample_shape=per_category)
    start_prob = jnp.repeat(1. / num_categories, num_categories)

    total_steps = num_supervised_data + num_unsupervised_data
    categories, words = [], []
    for t in range(total_steps):
        rng_key, key_trans, key_emit = random.split(rng_key, 3)
        # Both the supervised and the unsupervised segment start afresh.
        if t in (0, num_supervised_data):
            category_probs = start_prob
        else:
            category_probs = transition_prob[category]
        category = dist.Categorical(category_probs).sample(key=key_trans)
        word = dist.Categorical(emission_prob[category]).sample(key=key_emit)
        categories.append(category)
        words.append(word)

    # split into supervised data and unsupervised data
    categories = jnp.stack(categories)
    words = jnp.stack(words)
    split = num_supervised_data
    return (transition_prior, emission_prior, transition_prob, emission_prob,
            categories[:split], words[:split], words[split:])
示例9: dz_dt
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def dz_dt(z, t, theta):
    """
    Lotka–Volterra equations. Real positive parameters `alpha`, `beta`, `gamma`, `delta`
    describe the interaction of two species.

    :param z: state; ``z[0]`` and ``z[1]`` are the two populations.
    :param t: time; not used by the right-hand side.
    :param theta: parameters, read as ``theta[..., 0]`` to ``theta[..., 3]``.
    :return: stacked time derivatives of the two populations.
    """
    u, v = z[0], z[1]
    alpha, beta, gamma, delta = (theta[..., k] for k in range(4))
    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    derivatives = [du_dt, dv_dt]
    return jnp.stack(derivatives)
示例10: _laxmap
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def _laxmap(f, xs):
    """Apply ``f`` to every slice of ``xs`` along the leading axis.

    The number of slices is read from the leading dimension of the first
    leaf of ``xs``; the per-slice results are stacked leaf by leaf.
    """
    first_leaf = tree_flatten(xs)[0][0]
    n = first_leaf.shape[0]
    take_slice = jit(_get_value_from_index)
    ys = [f(take_slice(xs, i)) for i in range(n)]
    # NOTE(review): ``tree_multimap`` is absent from recent JAX releases, where
    # ``jax.tree_util.tree_map`` accepts several trees - confirm against the
    # JAX version this file targets.
    return tree_multimap(lambda *leaves: jnp.stack(leaves), *ys)
示例11: parametric
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def parametric(subposteriors, diagonal=False):
    """
    Merges subposteriors following (embarrassingly parallel) parametric Monte Carlo algorithm.

    **References:**

    1. *Asymptotically Exact, Embarrassingly Parallel MCMC*,
       Willie Neiswanger, Chong Wang, Eric Xing

    :param list subposteriors: a list in which each element is a collection of samples.
    :param bool diagonal: whether to compute weights using variance or covariance, defaults to
        `False` (using covariance).
    :return: the estimated mean and variance/covariance parameters of the joined posterior
    """
    def _flatten(sample):
        return ravel_pytree(sample)[0]

    # Stack the subposteriors leaf by leaf, then flatten every single sample
    # into a vector: the result is indexed (subposterior, sample, dimension).
    stacked = tree_multimap(lambda *leaves: jnp.stack(leaves), *subposteriors)
    flat_samples = vmap(vmap(_flatten))(stacked)
    submeans = jnp.mean(flat_samples, axis=1)

    if not diagonal:
        weights = vmap(lambda x: jnp.linalg.inv(jnp.cov(x.T)))(flat_samples)
        cov = jnp.linalg.inv(jnp.sum(weights, axis=0))
        normalized_weights = jnp.matmul(cov, weights)
        # comparing to consensus implementation, we compute weighted mean here
        mean = jnp.einsum('ijk,ik->j', normalized_weights, submeans)
        return mean, cov

    weights = vmap(lambda x: 1 / jnp.var(x, ddof=1, axis=0))(flat_samples)
    var = 1 / jnp.sum(weights, axis=0)
    normalized_weights = var * weights
    # comparing to consensus implementation, we compute weighted mean here
    mean = jnp.einsum('ij,ij->j', normalized_weights, submeans)
    return mean, var
示例12: log_prob
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def log_prob(self, value):
    """Return ``self._dirichlet.log_prob`` of the pair ``(value, 1 - value)``.

    The pair is stacked along a new trailing axis before being passed on.
    """
    complement = 1. - value
    pair = jnp.stack([value, complement], -1)
    return self._dirichlet.log_prob(pair)
示例13: kinetic_fn
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def kinetic_fn(m_inv, p):
    """Return half the ``m_inv``-weighted sum of the squared momenta.

    :param m_inv: weights applied to the squared momenta via ``jnp.dot``.
    :param dict p: momenta, read from the keys ``'x'`` and ``'y'``.
    """
    momenta = [p['x'], p['y']]
    z = jnp.stack(momenta, axis=-1)
    squared = z ** 2
    return 0.5 * jnp.dot(m_inv, squared)
示例14: test_seed
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def test_seed():
    """Seeding inside a Python loop must match seeding under ``vmap``.

    The same two-site sampler is run once per integer seed with the ``seed``
    handler used as a context manager, and once with the handler wrapped
    around the sampler and mapped over the same seeds.
    """
    num_seeds = 100

    def _sample():
        x = numpyro.sample('x', dist.Normal(0., 1.))
        y = numpyro.sample('y', dist.Normal(1., 2.))
        return jnp.stack([x, y])

    def _seeded_sample(rng_key):
        return handlers.seed(lambda: _sample(), rng_key)()

    looped = []
    for seed in range(num_seeds):
        with handlers.seed(rng_seed=seed):
            looped.append(_sample())
    xs = jnp.stack(looped)

    ys = vmap(_seeded_sample)(jnp.arange(num_seeds))
    assert_allclose(xs, ys, atol=1e-6)
示例15: test_nested_seeding
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import stack [as 别名]
def test_nested_seeding():
    """Check which seed handler takes effect when handlers are nested.

    Site ``'x'`` is sampled under handlers seeded with ``rng_key_1`` and
    ``rng_key_2``; site ``'y'`` is sampled under an additional handler seeded
    with ``rng_key_3``. The assertions expect ``'x'`` to be unaffected by
    ``rng_key_1`` and ``'y'`` to change with ``rng_key_3``.
    """
    # NOTE(review): the source had lost its indentation; the third handler is
    # assumed to be nested inside the second - confirm against the original.
    def fn(rng_key_1, rng_key_2, rng_key_3):
        draws = []
        with handlers.seed(rng_seed=rng_key_1), handlers.seed(rng_seed=rng_key_2):
            draws.append(numpyro.sample('x', dist.Normal(0., 1.)))
            with handlers.seed(rng_seed=rng_key_3):
                draws.append(numpyro.sample('y', dist.Normal(0., 1.)))
        return jnp.stack(draws)

    baseline = fn(0, 1, 2)

    # Changing only the outermost seed leaves both sites unchanged.
    assert_allclose(baseline, fn(3, 1, 2))

    # Changing the innermost seed as well alters 'y' but not 'x'.
    s1, s2 = fn(0, 1, 2), fn(3, 1, 4)
    assert_allclose(s1[0], s2[0])
    assert_raises(AssertionError, assert_allclose, s1[1], s2[1])