本文整理汇总了 Python 中 jax.numpy.arange 函数的典型用法代码示例。如果您想了解 jax.numpy.arange 的具体用法、使用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该函数所在模块 jax.numpy 的其他用法示例。
下文共展示了 jax.numpy.arange 函数的 15 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。
示例1: chosen_probabs
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def chosen_probabs(probab_observations, actions):
    """Select the log-probability of the chosen action at every (batch, step).

    Args:
        probab_observations: array of shape `[B, T+1, A]`; entry `[b, t, i]`
            is the log-probability of action `i` at step `t` of trajectory `b`.
        actions: integer array of shape `[B, T]` with values in `[0, A)`; entry
            `[b, t]` is the action taken at step `t` of trajectory `b`.

    Returns:
        Array of shape `[B, T]` with the log-probabilities of the chosen
        actions. The last (T+1-th) step of `probab_observations` is not read.
    """
    batch_size, num_steps = actions.shape
    assert (batch_size, num_steps + 1) == probab_observations.shape[:2]
    # Advanced indexing: the three index arrays broadcast to [B, T], so each
    # (b, t) picks probab_observations[b, t, actions[b, t]].
    batch_idx = np.arange(batch_size)[:, None]
    step_idx = np.arange(num_steps)
    return probab_observations[batch_idx, step_idx, actions]
示例2: get_data
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
    """Generate a synthetic 1-D regression dataset with polynomial features.

    Inputs are `N` evenly spaced points on [-1, 1], expanded into the features
    x**0, x**1, ..., x**(D_X - 1). Targets are a random linear function of the
    features plus a nonlinear term in x (taken from feature column 1, so
    D_X >= 2 is required). Gaussian noise with std `sigma_obs` is added, and
    the targets are standardized. NumPy's global RNG is reseeded with 0, so the
    output is deterministic.

    Returns:
        (X, Y, X_test) with shapes (N, D_X), (N, 1) and (N_test, D_X); the test
        inputs span the wider interval [-1.3, 1.3].
    """
    D_Y = 1  # targets are one-dimensional
    np.random.seed(0)
    powers = jnp.arange(D_X)

    grid = jnp.linspace(-1, 1, N)
    X = jnp.power(grid[:, np.newaxis], powers)
    W = 0.5 * np.random.randn(D_X)
    x1 = X[:, 1]
    Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + x1, 2.0) * jnp.sin(4.0 * x1)
    Y += sigma_obs * np.random.randn(N)
    Y = Y[:, np.newaxis]
    # Standardize to zero mean and unit standard deviation.
    Y -= jnp.mean(Y)
    Y /= jnp.std(Y)
    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    test_grid = jnp.linspace(-1.3, 1.3, N_test)
    X_test = jnp.power(test_grid[:, np.newaxis], powers)
    return X, Y, X_test
示例3: model
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def model(N, y=None):
    """NumPyro model of two populations evolving under the ODE `dz_dt`.

    `dz_dt` is defined elsewhere (its parameters are alpha, beta, gamma, delta;
    presumably the Lotka-Volterra equations; confirm against its definition).

    :param int N: number of measurement times; the ODE is evaluated at
        t = 0.0, 1.0, ..., N - 1
    :param numpy.ndarray y: measured populations with shape (N, 2), or None to
        leave the "y" site unobserved
    """
    # Prior over the two initial populations (log-normal with median 10).
    z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1), sample_shape=(2,))
    ts = jnp.arange(float(N))  # measurement times

    # Non-negative (truncated at 0) priors for alpha, beta, gamma, delta.
    theta_loc = jnp.array([0.5, 0.05, 1.5, 0.05])
    theta_scale = jnp.array([0.5, 0.05, 0.5, 0.05])
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(low=0., loc=theta_loc, scale=theta_scale))

    # Integrate dz/dt; the trajectory has shape (N, 2).
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)

    # Observation noise: the rate-1 prior on the first component allows larger
    # noise than the rate-2 prior on the second (hare vs. lynx, per the
    # original comments).
    sigma = numpyro.sample("sigma", dist.Exponential(jnp.array([1, 2])))

    # Likelihood of the measured populations on the log scale.
    numpyro.sample("y", dist.Normal(jnp.log(z), sigma), obs=y)
示例4: _multinomial
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def _multinomial(key, p, n, n_max, shape=()):
if jnp.shape(n) != jnp.shape(p)[:-1]:
broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
n = jnp.broadcast_to(n, broadcast_shape)
p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
shape = shape or p.shape[:-1]
# get indices from categorical distribution then gather the result
indices = categorical(key, p, (n_max,) + shape)
# mask out values when counts is heterogeneous
if jnp.ndim(n) > 0:
mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
else:
mask = 1
excess = 0
# NB: we transpose to move batch shape to the front
indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
dtype=indices.dtype),
jnp.expand_dims(indices_2D, axis=-1),
jnp.ones(indices_2D.shape, dtype=indices.dtype))
return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
示例5: test_chain
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def test_chain(use_init_params, chain_method):
    """Multi-chain NUTS on a synthetic logistic regression recovers the coefficients."""
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)  # [1., 2., 3.]
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(random.PRNGKey(1))

    def model(obs_labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=obs_labels)

    mcmc = MCMC(NUTS(model=model), num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method
    if use_init_params:
        # Start every chain at coefs = 1.
        init_params = {'coefs': jnp.ones((num_chains, dim))}
    else:
        init_params = None
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)

    samples_flat = mcmc.get_samples()
    assert samples_flat['coefs'].shape[0] == num_chains * num_samples
    samples = mcmc.get_samples(group_by_chain=True)
    assert samples['coefs'].shape[:2] == (num_chains, num_samples)
    assert_allclose(jnp.mean(samples_flat['coefs'], 0), true_coefs, atol=0.21)
示例6: test_gaussian_subposterior
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def test_gaussian_subposterior(method, diagonal):
    """Combining Gaussian subposterior samples with `method` recovers the full posterior moments."""
    D = 10
    n_samples, n_draws, n_subs = 10000, 9000, 8
    mean = jnp.arange(D)
    # Unit variances with correlation 0.9 between every pair of coordinates.
    cov = jnp.ones((D, D)) * 0.9 + jnp.identity(D) * 0.1
    # Covariance of each subposterior: n_subs times the full covariance.
    subcov = n_subs * cov
    sub_samples = dist.MultivariateNormal(mean, subcov).sample(
        random.PRNGKey(1), (n_subs, n_samples))
    subposteriors = list(sub_samples)

    draws = method(subposteriors, n_draws, diagonal=diagonal)
    assert draws.shape == (n_draws, D)
    assert_allclose(jnp.mean(draws, axis=0), mean, atol=0.03)
    if not diagonal:
        assert_allclose(jnp.cov(draws.T), cov, atol=0.05)
    else:
        assert_allclose(jnp.var(draws, axis=0), jnp.diag(cov), atol=0.05)
示例7: test_categorical_log_prob_grad
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def test_categorical_log_prob_grad():
    """Value and gradient agree whether a Categorical is built from softmax probs or from logits."""
    data = jnp.repeat(jnp.arange(3), 10)
    scale = jnp.arange(1, 4)

    def from_probs(x):
        probs = jax.nn.softmax(x * scale)
        return dist.Categorical(probs).log_prob(data).sum()

    def from_logits(x):
        return dist.Categorical(logits=x * scale).log_prob(data).sum()

    x = 0.5
    fx, grad_fx = jax.value_and_grad(from_probs)(x)
    gx, grad_gx = jax.value_and_grad(from_logits)(x)
    assert_allclose(fx, gx)
    assert_allclose(grad_fx, grad_gx, atol=1e-4)
########################################
# Tests for constraints and transforms #
########################################
示例8: solve_implicit
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None):
    """Build a masked tridiagonal system along axis 2 and solve it.

    `ks` gives, for each (i, j) column, the first active level index along
    axis 2. Columns with a negative `ks` are fully inactive (presumably land;
    confirm with callers). Levels with index < ks are inactive. The level with
    index == ks is the "edge": there `a` is zeroed, and `b_edge` / `d_edge`
    (when given) replace `b` / `d`. Inactive levels get the trivial row
    a = 0, b = 1, c = 0, d = 0.

    Returns:
        (solution, water_mask): the result of `solve_tridiag` and the boolean
        mask of active levels.
    """
    ks_col = ks[:, :, np.newaxis]
    levels = np.arange(a.shape[2])[np.newaxis, np.newaxis, :]
    column_mask = (ks >= 0)[:, :, np.newaxis]
    edge_mask = column_mask & (levels == ks_col)
    water_mask = column_mask & (levels >= ks_col)

    a_tri = water_mask * a * np.logical_not(edge_mask)
    b_tri = where(water_mask, b, 1.)
    if b_edge is not None:
        b_tri = where(edge_mask, b_edge, b_tri)
    c_tri = water_mask * c
    d_tri = water_mask * d
    if d_edge is not None:
        d_tri = where(edge_mask, d_edge, d_tri)
    return solve_tridiag(a_tri, b_tri, c_tri, d_tri), water_mask
示例9: _new_arange
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def _new_arange(x, start, stop, step):
return np.arange(start, stop, step)
示例10: significance_map
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def significance_map(self):
    """Return the precision-axis index of every entry, flattened to 1-D.

    Conceptually builds an array of shape `self._space.shape + (precision,)`
    whose last axis counts 0 .. precision - 1, then flattens it. The result is
    `arange(precision)` repeated once per element of `self._space.shape`.
    """
    digit_ids = np.arange(self._precision)
    full_shape = self._space.shape + (self._precision,)
    tiled = np.broadcast_to(digit_ids, full_shape)
    return np.reshape(tiled, -1)
示例11: main
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def main():
    """Train ResNet50 for a few momentum steps on synthetic data (demo).

    Random images and random one-hot labels are generated, parameters are
    initialized from the first batch, and `num_steps` optimizer updates run.
    """
    key = PRNGKey(0)
    batch_size = 8
    num_classes = 1001
    # Batch axis is last here; presumably this matches the layout ResNet50
    # expects. TODO confirm.
    input_shape = (224, 224, 3, batch_size)
    step_size = 0.1
    num_steps = 10
    resnet = ResNet50(num_classes)

    @parametrized
    def loss(inputs, targets):
        logits = resnet(inputs)
        # The optimizer MINIMIZES this value. It must therefore be the
        # negative target score; the previous `+sum` drove the target-class
        # outputs down. Assumes resnet outputs log-probabilities (TODO confirm),
        # which makes this the cross-entropy.
        return -np.sum(logits * targets)

    @parametrized
    def accuracy(inputs, targets):
        # NOTE(review): defined for illustration; not evaluated below.
        target_class = np.argmax(targets, axis=-1)
        predicted_class = np.argmax(resnet(inputs), axis=-1)
        return np.mean(predicted_class == target_class)

    def synth_batches():
        rng = npr.RandomState(0)
        while True:
            images = rng.rand(*input_shape).astype('float32')
            labels = rng.randint(num_classes, size=(batch_size, 1))
            # (batch, 1) == (num_classes,) broadcasts to a boolean one-hot
            # matrix of shape (batch, num_classes).
            onehot_labels = labels == np.arange(num_classes)
            yield images, onehot_labels

    opt = optimizers.Momentum(step_size, mass=0.9)
    batches = synth_batches()
    print("\nInitializing parameters.")
    state = opt.init(loss.init_parameters(*next(batches), key=key))
    for i in range(num_steps):
        print(f'Training on batch {i}.')
        state = opt.update(loss.apply, state, *next(batches))
    # Final parameters (not used further in this demo).
    trained_params = opt.get_parameters(state)
示例12: _one_hot
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def _one_hot(x, k, dtype=np.float32):
"""Create a one-hot encoding of x of size k."""
return np.array(x[:, None] == np.arange(k), dtype)
示例13: _extract_signal_patches
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def _extract_signal_patches(signal, window_length, hop=1, data_format="NCW"):
assert not hasattr(window_length, "__len__")
assert signal.ndim == 3
if data_format == "NCW":
N = (signal.shape[2] - window_length) // hop + 1
indices = jnp.arange(window_length) + jnp.expand_dims(jnp.arange(N) * hop, 1)
indices = jnp.reshape(indices, [1, 1, N * window_length])
patches = jnp.take_along_axis(signal, indices, 2)
return jnp.reshape(patches, signal.shape[:2] + (N, window_length))
else:
error
示例14: _extract_image_patches
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def _extract_image_patches(
image, window_shape, hop=1, data_format="NCHW", mode="valid"
):
if mode == "same":
p1 = window_shape[0] - 1
p2 = window_shape[1] - 1
image = jnp.pad(
image, [(0, 0), (0, 0), (p1 // 2, p1 - p1 // 2), (p2 // 2, p2 - p2 // 2)]
)
if not hasattr(hop, "__len__"):
hop = (hop, hop)
if data_format == "NCHW":
# compute the number of windows in both dimensions
N = (
(image.shape[2] - window_shape[0]) // hop[0] + 1,
(image.shape[3] - window_shape[1]) // hop[1] + 1,
)
# compute the base indices of a 2d patch
patch = jnp.arange(numpy.prod(window_shape)).reshape(window_shape)
offset = jnp.expand_dims(jnp.arange(window_shape[0]), 1)
patch_indices = patch + offset * (image.shape[3] - window_shape[1])
# create all the shifted versions of it
ver_shifts = jnp.reshape(
jnp.arange(N[0]) * hop[0] * image.shape[3], (-1, 1, 1, 1)
)
hor_shifts = jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
all_cols = patch_indices + jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
indices = patch_indices + ver_shifts + hor_shifts
# now extract shape (1, 1, H'W'a'b')
flat_indices = jnp.reshape(indices, [1, 1, -1])
# shape is now (N, C, W*H)
flat_image = jnp.reshape(image, (image.shape[0], image.shape[1], -1))
# shape is now (N, C)
patches = jnp.take_along_axis(flat_image, flat_indices, 2)
return jnp.reshape(patches, image.shape[:2] + N + tuple(window_shape))
else:
error
示例15: one_hot
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def one_hot(i, N, dtype="float32"):
    """Create a one-hot encoding of `i` with `N` classes.

    If `i` has a `shape` attribute, it is treated as a 1-D array of indices and
    the result has shape (len(i), N). Otherwise `i` is a single index, and a
    length-`N` vector with `index_add` applied at position `i` is returned
    (presumably a 1 at index `i`; confirm `index_add` semantics).
    """
    if hasattr(i, "shape"):
        # The previous code used the undefined names `x` and `k` here.
        # (len(i), 1) == (N,) broadcasts to (len(i), N).
        return (i[:, None] == arange(N)).astype(dtype)
    z = T.zeros(N, dtype)
    return index_add(z, i, 1)