本文整理汇总了Python中jax.numpy.expand_dims方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.expand_dims方法的具体用法？Python jax.numpy.expand_dims怎么用？Python jax.numpy.expand_dims使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。
在下文中一共展示了jax.numpy.expand_dims方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _multinomial
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _multinomial(key, p, n, n_max, shape=()):
    """Draw multinomial counts by sampling ``n_max`` categorical draws and tallying them.

    Args:
        key: PRNG key forwarded to ``categorical``.
        p: category probabilities; the last axis indexes categories.
        n: number of trials, broadcast against the batch shape of ``p``.
        n_max: static upper bound on ``n``; this many draws are always taken.
        shape: output batch shape; defaults to the (broadcast) batch shape of ``p``.

    Returns:
        Counts with shape ``shape + (num_categories,)``.
    """
    p_batch = jnp.shape(p)[:-1]
    if jnp.shape(n) != p_batch:
        common = lax.broadcast_shapes(jnp.shape(n), p_batch)
        n = jnp.broadcast_to(n, common)
        p = jnp.broadcast_to(p, common + jnp.shape(p)[-1:])
    num_categories = p.shape[-1]
    shape = shape or p.shape[:-1]

    # Category index for each of the n_max draws, batched behind the leading draw axis.
    draws = categorical(key, p, (n_max,) + shape)

    if jnp.ndim(n) == 0:
        mask, excess = 1, 0
    else:
        # Heterogeneous counts: zero out draws beyond each element's own n. Zeroed
        # draws land in category 0, so `excess` removes (n_max - n) from that column.
        keep = jnp.arange(n_max) < jnp.expand_dims(n, -1)
        mask = promote_shapes(keep, shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(draws.dtype)
        padding = jnp.zeros(jnp.shape(n) + (num_categories - 1,))
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), padding], -1)

    # Transpose so every row holds all n_max draws of one batch element.
    flat_draws = jnp.reshape(draws * mask, (n_max, -1)).T
    counts = vmap(_scatter_add_one, (0, 0, 0))(
        jnp.zeros((flat_draws.shape[0], num_categories), dtype=draws.dtype),
        jnp.expand_dims(flat_draws, axis=-1),
        jnp.ones(flat_draws.shape, dtype=draws.dtype),
    )
    return jnp.reshape(counts, shape + (num_categories,)) - excess
示例2: _onion
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _onion(self, key, size):
    """Sample lower Cholesky factors of correlation matrices (onion method).

    Follows Algorithm 3.2 of [1] (the reference list is not visible here;
    presumably the LKJ "onion" paper — confirm).

    Args:
        key: PRNG key, split into one key for the beta draws and one for the normals.
        size: leading sample shape (a tuple).

    Returns:
        Array of shape ``size + self.batch_shape + self.event_shape``. It is lower
        triangular and each row has unit Euclidean norm, up to the clipping below.
    """
    key_beta, key_normal = random.split(key)
    # The w term of Algorithm 3.2: squared radii come from self._beta.
    beta_sample = self._beta.sample(key_beta, size)
    # Normalised standard normals are uniform on a hypersphere
    # (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html).
    normal_sample = random.normal(
        key_normal,
        shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,),
    )
    normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
    u_hypersphere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
    w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypersphere
    # Place w into the strictly lower-triangular part. `jax.ops.index_add` has been
    # removed from JAX; `.at[...].add(...)` is its supported replacement.
    cholesky = jnp.zeros(size + self.batch_shape + self.event_shape).at[..., 1:, :-1].add(w)
    # Choose the diagonal so each row has unit norm. Floor at 0 because rounding
    # can push 1 - sum(row**2) slightly negative.
    diag = jnp.sqrt(jnp.maximum(1 - jnp.sum(cholesky ** 2, axis=-1), 0.))
    cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
    return cholesky
示例3: __init__
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_tril=None,
             validate_args=None):
    """Build a multivariate normal from a mean and one matrix parameterisation.

    The first of ``covariance_matrix``, ``precision_matrix``, ``scale_tril``
    (checked in that order) that is not ``None`` is used. The matrix-valued
    attribute and ``self.scale_tril`` are stored, and ``self.loc`` is broadcast
    to ``batch_shape + event_shape``.

    Raises:
        ValueError: if none of the three matrix parameters is given.
    """
    loc_column = jnp.expand_dims(loc, axis=-1) if jnp.isscalar(loc) else loc
    # Temporary trailing unit axis so loc can be shape-promoted against a matrix
    # (presumably promote_shapes left-pads ranks — confirm against its definition).
    loc_column = loc_column[..., jnp.newaxis]

    if covariance_matrix is not None:
        loc_column, self.covariance_matrix = promote_shapes(loc_column, covariance_matrix)
        self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
    elif precision_matrix is not None:
        loc_column, self.precision_matrix = promote_shapes(loc_column, precision_matrix)
        self.scale_tril = cholesky_of_inverse(self.precision_matrix)
    elif scale_tril is not None:
        loc_column, self.scale_tril = promote_shapes(loc_column, scale_tril)
    else:
        raise ValueError('One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
                         ' must be specified.')

    tril_shape = jnp.shape(self.scale_tril)
    event_shape = tril_shape[-1:]
    batch_shape = lax.broadcast_shapes(jnp.shape(loc_column)[:-2], tril_shape[:-2])
    # Drop the temporary axis again before broadcasting to the full shape.
    self.loc = jnp.broadcast_to(jnp.squeeze(loc_column, axis=-1), batch_shape + event_shape)
    super(MultivariateNormal, self).__init__(
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
    )
示例4: logprob_from_conditional_params
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def logprob_from_conditional_params(images, means, inv_scales, logit_probs):
    """Log-likelihood of discretised images under a mixture of logistics.

    Each pixel value is treated as a bin of half-width 1/255. The bins at -1 and
    1 are open-ended: their lower or upper CDF is pinned to 0 or 1.

    Args:
        images: pixel values; a mixture axis is inserted at position 1.
        means, inv_scales: per-component logistic parameters broadcastable
            against the expanded images.
        logit_probs: unnormalised mixture logits; the mixture axis is axis -3.

    Returns:
        Log-probabilities summed over the last two axes of the per-pixel result
        (presumably height and width — confirm with the caller).
    """
    x = jnp.expand_dims(images, 1)
    half_bin = 1 / 255

    def cdf(offset):
        return sigmoid((x - means + offset) * inv_scales)

    upper = jnp.where(x == 1, 1, cdf(half_bin))
    lower = jnp.where(x == -1, 0, cdf(-half_bin))
    # Floor the bin mass so log() never sees zero.
    bin_mass = jnp.maximum(upper - lower, 1e-12)
    component_logprob = jnp.sum(jnp.log(bin_mass), -1)

    log_weights = logit_probs - logsumexp(logit_probs, -3, keepdims=True)
    per_pixel = logsumexp(log_weights + component_logprob, axis=-3)
    return jnp.sum(per_pixel, axis=(-2, -1))
示例5: _extract_signal_patches
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _extract_signal_patches(signal, window_length, hop=1, data_format="NCW"):
assert not hasattr(window_length, "__len__")
assert signal.ndim == 3
if data_format == "NCW":
N = (signal.shape[2] - window_length) // hop + 1
indices = jnp.arange(window_length) + jnp.expand_dims(jnp.arange(N) * hop, 1)
indices = jnp.reshape(indices, [1, 1, N * window_length])
patches = jnp.take_along_axis(signal, indices, 2)
return jnp.reshape(patches, signal.shape[:2] + (N, window_length))
else:
error
示例6: _extract_image_patches
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _extract_image_patches(
image, window_shape, hop=1, data_format="NCHW", mode="valid"
):
if mode == "same":
p1 = window_shape[0] - 1
p2 = window_shape[1] - 1
image = jnp.pad(
image, [(0, 0), (0, 0), (p1 // 2, p1 - p1 // 2), (p2 // 2, p2 - p2 // 2)]
)
if not hasattr(hop, "__len__"):
hop = (hop, hop)
if data_format == "NCHW":
# compute the number of windows in both dimensions
N = (
(image.shape[2] - window_shape[0]) // hop[0] + 1,
(image.shape[3] - window_shape[1]) // hop[1] + 1,
)
# compute the base indices of a 2d patch
patch = jnp.arange(numpy.prod(window_shape)).reshape(window_shape)
offset = jnp.expand_dims(jnp.arange(window_shape[0]), 1)
patch_indices = patch + offset * (image.shape[3] - window_shape[1])
# create all the shifted versions of it
ver_shifts = jnp.reshape(
jnp.arange(N[0]) * hop[0] * image.shape[3], (-1, 1, 1, 1)
)
hor_shifts = jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
all_cols = patch_indices + jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
indices = patch_indices + ver_shifts + hor_shifts
# now extract shape (1, 1, H'W'a'b')
flat_indices = jnp.reshape(indices, [1, 1, -1])
# shape is now (N, C, W*H)
flat_image = jnp.reshape(image, (image.shape[0], image.shape[1], -1))
# shape is now (N, C)
patches = jnp.take_along_axis(flat_image, flat_indices, 2)
return jnp.reshape(patches, image.shape[:2] + N + tuple(window_shape))
else:
error
示例7: expand_dims
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def expand_dims(self, x, dim=-1):
    """Insert a new length-one axis into ``x`` at position ``dim``.

    Args:
        x: input array.
        dim: position of the new axis (negative values count from the end).

    Returns:
        ``x`` with one extra axis of size 1.
    """
    # NumPy-style expand_dims names this parameter `axis`; the previous
    # `dim=` keyword raised TypeError on every call.
    return np.expand_dims(x, axis=dim)
示例8: forward_one_step
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def forward_one_step(prev_log_prob, curr_word, transition_log_prob, emission_log_prob):
    """One step of the HMM forward algorithm in log space.

    Computes ``out[j] = logsumexp_i(prev[i] + T[i, j] + E[j, curr_word])``.

    Args:
        prev_log_prob: log forward probabilities from the previous step, one per state.
        curr_word: index of the observed symbol (column of ``emission_log_prob``).
        transition_log_prob: log transition matrix indexed ``[from_state, to_state]``.
        emission_log_prob: log emission matrix indexed ``[state, symbol]``.

    Returns:
        Updated log forward probabilities, one per state.
    """
    emission_for_word = emission_log_prob[:, curr_word]
    scores = jnp.expand_dims(prev_log_prob, axis=1) + transition_log_prob + emission_for_word
    # marginalise over the previous state (axis 0)
    return logsumexp(scores, axis=0)
示例9: __call__
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def __call__(self, x):
    """Map a flat vector to a lower-triangular matrix with a positive diagonal.

    The last axis of ``x`` has length ``n * (n + 1) / 2``. The first
    ``n * (n - 1) / 2`` entries fill the strictly lower triangle. The final
    ``n`` entries are exponentiated and placed on the diagonal.
    """
    size = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
    off_diagonal, log_diagonal = x[..., :-size], x[..., -size:]
    strictly_lower = vec_to_tril_matrix(off_diagonal, diagonal=-1)
    diagonal_part = jnp.expand_dims(jnp.exp(log_diagonal), axis=-1) * jnp.identity(size)
    return strictly_lower + diagonal_part
示例10: vec_to_tril_matrix
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def vec_to_tril_matrix(t, diagonal=0):
    """Scatter the last axis of ``t`` into the lower triangle of square matrices.

    Args:
        t: array whose last axis holds the triangle entries in row-major order.
            Leading axes are batch axes.
        diagonal: upper boundary of the triangle. ``0`` includes the main
            diagonal, ``-1`` starts just below it, and so on. Must be ``<= 0``.

    Returns:
        Array of shape ``t.shape[:-1] + (n, n)`` with the dtype of ``t``.
        Entries outside the triangle are zero.

    Raises:
        ValueError: if ``diagonal > 0``; the size formula below does not hold there.
    """
    if diagonal > 0:
        raise ValueError(f"diagonal must be <= 0, got {diagonal}")
    # Solve m = k (k + 1) / 2 for k; the matrix has -diagonal extra rows/columns.
    n = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    n2 = n * n
    # Row-major flat positions of the triangle inside an n x n matrix.
    idx = jnp.reshape(jnp.arange(n2), (n, n))[jnp.tril_indices(n, diagonal)]
    # Allocate the buffer in t's dtype. A default float32 buffer makes the scatter
    # fail for integer or float64 inputs. Positions in idx are unique, so add == set.
    x = jnp.zeros(t.shape[:-1] + (n2,), dtype=t.dtype).at[..., idx].add(t)
    return jnp.reshape(x, x.shape[:-1] + (n, n))
示例11: sample
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def sample(self, key, sample_shape=()):
    """Draw random-walk trajectories.

    Standard-normal increments are drawn for shape
    ``sample_shape + batch_shape + event_shape``. They are cumulatively summed
    along the last (step) axis, and each trajectory is scaled by ``self.scale``.
    """
    increments = random.normal(key, shape=sample_shape + self.batch_shape + self.event_shape)
    positions = jnp.cumsum(increments, axis=-1)
    step_scale = jnp.expand_dims(self.scale, axis=-1)
    return positions * step_scale
示例12: variance
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def variance(self):
    """Marginal variance of each step of the walk.

    Step ``k`` (1-based) has variance ``scale**2 * k``. The result is broadcast
    to ``batch_shape + event_shape``.
    """
    step_numbers = jnp.arange(1, self.num_steps + 1)
    per_step = jnp.expand_dims(self.scale, -1) ** 2 * step_numbers
    return jnp.broadcast_to(per_step, self.batch_shape + self.event_shape)
示例13: _batch_capacitance_tril
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _batch_capacitance_tril(W, D):
r"""
Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
and a batch of vectors :math:`D`.
"""
Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
K = jnp.matmul(Wt_Dinv, W)
# could be inefficient
return jnp.linalg.cholesky(jnp.add(K, jnp.identity(K.shape[-1])))
示例14: _batch_lowrank_mahalanobis
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""Squared Mahalanobis distance :math:`x^T (W W^T + D)^{-1} x` via Woodbury.

    Uses the "Woodbury matrix identity"::

        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),

    where :math:`C = I + W^T D^{-1} W` is the capacitance matrix, whose
    Cholesky factor is passed in as ``capacitance_tril``.
    """
    Dinv_Wt = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
    # x^T inv(D) x, computed elementwise since D is diagonal
    diagonal_term = jnp.sum(jnp.square(x) / D, axis=-1)
    correction = _batch_mahalanobis(capacitance_tril, _batch_mv(Dinv_Wt, x))
    return diagonal_term - correction
示例15: scale_tril
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def scale_tril(self):
    """Lower Cholesky factor of ``cov_factor @ cov_factor.T + diag(cov_diag)``.

    Uses the identity (http://www.gaussianprocess.org/gpml/, Section 3.4.3)::

        W @ W.T + D = D^1/2 @ (I + D^-1/2 @ W @ W.T @ D^-1/2) @ D^1/2

    The bracketed matrix has eigenvalues bounded below by 1. It is therefore
    well-conditioned, which keeps the Cholesky decomposition numerically stable.
    """
    sqrt_diag_col = jnp.expand_dims(jnp.sqrt(self.cov_diag), axis=-1)
    whitened_W = self.cov_factor / sqrt_diag_col
    inner = jnp.matmul(whitened_W, jnp.swapaxes(whitened_W, -1, -2))
    well_conditioned = inner + jnp.identity(inner.shape[-1])
    # Row-scaling by D^1/2 keeps the factor lower-triangular.
    return sqrt_diag_col * jnp.linalg.cholesky(well_conditioned)