本文整理汇总了Python中jax.numpy.newaxis的典型用法代码示例。注意：jax.numpy.newaxis并不是方法，而是None的别名，用在索引表达式中插入一个长度为1的新轴（例如x[:, jnp.newaxis]会把形状为(N,)的数组变为(N, 1)）。如果您正苦于以下问题：Python jax.numpy.newaxis的具体用法？jax.numpy.newaxis怎么用？jax.numpy.newaxis使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解jax.numpy模块中其他成员的用法示例。
在下文中一共展示了jax.numpy.newaxis的13个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: approximate_kl
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def approximate_kl(log_prob_new, log_prob_old, mask):
    """Estimate the KL divergence between old and new log-probs.

    The element-wise log-ratio ``log_prob_old - log_prob_new`` is truncated to
    the first T time steps, zeroed wherever ``mask`` is 0, summed over every
    remaining element and divided by ``sum(mask)``.

    Args:
        log_prob_new: (B, T+1, A) new log probs.
        log_prob_old: (B, T+1, A) old log probs.
        mask: (B, T) mask.

    Returns:
        Approximate KL.
    """
    # Drop the final time step so the (B, T+1, A) difference lines up with
    # the (B, T) mask.
    log_ratio = (log_prob_old - log_prob_new)[:, :-1]
    # Lift the mask to (B, T, 1) so it broadcasts across the A axis.
    log_ratio *= mask[:, :, np.newaxis]
    # Normalise by the number of unmasked positions.
    total = np.sum(log_ratio)
    return total / np.sum(mask)
示例2: masked_entropy
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def masked_entropy(log_probs, mask):
    """Computes the masked mean entropy for the given log-probs.

    Args:
        log_probs: (B, T+1, A) log probs; the last time step is ignored.
        mask: (B, T) mask.

    Returns:
        ``-sum(p * log p) / sum(mask)``. The sum runs over the first T time
        steps and all A entries; masked-out positions contribute 0.
        ``log_probs`` itself is never modified.
    """
    # (B, T) -> (B, T, 1) so the mask broadcasts over the A axis.
    mask3 = mask[:, :, np.newaxis]
    # Cut the last time step and apply the mask in one out-of-place product.
    # An in-place `*=` on the slice `log_probs[:, :-1]` would, when `np` is
    # plain NumPy, write through the view into the caller's array.
    lp = log_probs[:, :-1] * mask3  # (B, T, A)
    # Masked entries of lp are 0 and exp(0) == 1, so re-apply the mask.
    p = np.exp(lp) * mask3  # (B, T, A)
    # Average on the non-masked part and take the negative.
    return -(np.sum(lp * p) / np.sum(mask))
示例3: glmm
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def glmm(dept, male, applications, admit=None):
    """Binomial-logit model with correlated per-department intercept/slope offsets.

    Args:
        dept: per-observation department index, used to index the offset
            matrix ``v`` (presumably integers 0..num_dept-1 — confirm).
        male: per-observation covariate multiplying the slope term.
        applications: Binomial total count per observation.
        admit: observed admissions, or ``None`` to sample them. When ``None``,
            the admission probabilities are also recorded at the ``probs``
            site.
    """
    # Population-level intercept (v_mu[0]) and slope (v_mu[1]).
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))
    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    # Row i of L_Rho is scaled by sigma[i], i.e. diag(sigma) @ L_Rho.
    scale_tril = sigma[..., jnp.newaxis] * L_Rho

    # Non-centered parameterization: standard-normal z, then scale/correlate.
    n_depts = len(jnp.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((n_depts, 2)), 1))
    v = jnp.dot(scale_tril, z.T).T  # per-department (intercept, slope) offsets

    intercept = v_mu[0] + v[dept, 0]
    slope = v_mu[1] + v[dept, 1]
    logits = intercept + slope * male

    if admit is None:
        # A Delta site records the probabilities for the predictive distribution.
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
示例4: __init__
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_tril=None,
             validate_args=None):
    """Build the distribution from ``loc`` and one matrix parameter.

    If several matrix parameters are given, the first non-None one in the
    order ``covariance_matrix``, ``precision_matrix``, ``scale_tril`` is used.
    ``self.scale_tril`` is always set: it is given directly, or derived by
    Cholesky of the covariance, or by ``cholesky_of_inverse`` of the precision.

    Raises:
        ValueError: if none of the three matrix parameters is given.
    """
    if jnp.isscalar(loc):
        loc = jnp.expand_dims(loc, axis=-1)
    # Temporarily give loc a trailing unit axis so it can be shape-promoted
    # together with a matrix parameter; it is squeezed away again below.
    loc = loc[..., jnp.newaxis]

    no_matrix_given = (covariance_matrix is None and precision_matrix is None
                       and scale_tril is None)
    if no_matrix_given:
        raise ValueError('One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
                         ' must be specified.')

    if covariance_matrix is not None:
        loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
        self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
    elif precision_matrix is not None:
        loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
        self.scale_tril = cholesky_of_inverse(self.precision_matrix)
    else:
        loc, self.scale_tril = promote_shapes(loc, scale_tril)

    # Batch dims are everything before the trailing (D, D) of the matrix and
    # the trailing (D, 1) of the temporarily expanded loc.
    matrix_shape = jnp.shape(self.scale_tril)
    batch_shape = lax.broadcast_shapes(jnp.shape(loc)[:-2], matrix_shape[:-2])
    event_shape = matrix_shape[-1:]
    self.loc = jnp.broadcast_to(jnp.squeeze(loc, axis=-1), batch_shape + event_shape)
    super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
                                             event_shape=event_shape,
                                             validate_args=validate_args)
示例5: solve_implicit
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None):
    """Assemble masked tridiagonal systems along axis 2 and solve them.

    For each (i, j) column, indices ``k >= ks[i, j]`` are active, and columns
    with ``ks < 0`` have no active cells. Inactive rows become identity rows
    (``b = 1``; ``a``, ``c``, ``d`` = 0). ``a`` is also zeroed at the edge
    index ``k == ks``. The optional ``b_edge`` / ``d_edge`` replace ``b`` / ``d``
    at that edge index.

    Returns:
        ``(solve_tridiag(a_tri, b_tri, c_tri, d_tri), water_mask)``.
    """
    first_active = ks[:, :, np.newaxis]
    level = np.arange(a.shape[2])[np.newaxis, np.newaxis, :]
    land_mask = (ks >= 0)[:, :, np.newaxis]
    edge_mask = land_mask & (level == first_active)
    water_mask = land_mask & (level >= first_active)

    a_tri = water_mask * a * np.logical_not(edge_mask)
    c_tri = water_mask * c
    b_tri = where(water_mask, b, 1.)
    d_tri = water_mask * d
    if b_edge is not None:
        b_tri = where(edge_mask, b_edge, b_tri)
    if d_edge is not None:
        d_tri = where(edge_mask, d_edge, d_tri)
    return solve_tridiag(a_tri, b_tri, c_tri, d_tri), water_mask
示例6: _adv_superbee
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def _adv_superbee(vel, var, mask, dx, axis, cost, cosu, dt_tracer):
    """Flux-limited advective flux of ``var`` along ``axis`` (superbee, per the name).

    The flux uses a four-point stencil at offsets -1, 0, +1, +2 along
    ``axis``, expressed as the slice tuples ``sm1``, ``s``, ``sp1``, ``sp2``.
    For axes 0 and 1 the other horizontal index is trimmed with
    ``slice(2, -2)`` and index 2 is kept whole. For axis 2, indices 0 and 1
    are trimmed with ``slice(2, -2)``, and the inputs are first passed
    through ``pad_z_edges``.

    Args:
        vel: velocity component used along ``axis``.
        var: advected field.
        mask: mask multiplied into the finite differences.
        dx: grid spacing along ``axis``; combined with ``cost`` for axes 0/1.
        axis: 0, 1 or 2.
        cost, cosu: metric factors (presumably cosines of latitude at T/U
            points — confirm). ``cosu`` is only used for axis 1.
        dt_tracer: time step.

    Returns:
        ``0.5 * u * (var[sp1] + var[s]) - 0.5 * |u| * ((1 - cr) + uCFL * cr) * rj``
        with ``u = velfac * vel[s]``. The second term is a limiter-weighted
        dissipative term.

    Raises:
        ValueError: if ``axis`` is not 0, 1 or 2.
    """
    # Velocity scale factor; only axis 1 replaces it (with cosu).
    velfac = 1
    if axis == 0:
        # n = -1, 0, 1, 2 -> sm1, s, sp1, sp2. `-2 + n or None` turns an end
        # index of 0 into None so sp2 runs to the end of the axis.
        sm1, s, sp1, sp2 = ((slice(1 + n, -2 + n or None), slice(2, -2), slice(None))
                            for n in range(-1, 3))
        dx = cost[np.newaxis, 2:-2, np.newaxis] * \
            dx[1:-2, np.newaxis, np.newaxis]
    elif axis == 1:
        sm1, s, sp1, sp2 = ((slice(2, -2), slice(1 + n, -2 + n or None), slice(None))
                            for n in range(-1, 3))
        dx = (cost * dx)[np.newaxis, 1:-2, np.newaxis]
        velfac = cosu[np.newaxis, 1:-2, np.newaxis]
    elif axis == 2:
        # pad_z_edges is defined elsewhere; presumably it pads axis 2 so the
        # same 4-point stencil fits — confirm.
        vel, var, mask = (pad_z_edges(a) for a in (vel, var, mask))
        sm1, s, sp1, sp2 = ((slice(2, -2), slice(2, -2), slice(1 + n, -2 + n or None))
                            for n in range(-1, 3))
        dx = dx[np.newaxis, np.newaxis, :-1]
    else:
        raise ValueError('axis must be 0, 1, or 2')
    # Courant number |u * dt / dx| at the s points.
    uCFL = np.abs(velfac * vel[s] * dt_tracer / dx)
    # Masked forward differences at sp1, s and sm1.
    rjp = (var[sp2] - var[sp1]) * mask[sp1]
    rj = (var[sp1] - var[s]) * mask[s]
    rjm = (var[s] - var[sm1]) * mask[sm1]
    # _calc_cr and limiter are defined elsewhere; presumably they form the
    # upwind gradient ratio and apply the limiter function to it.
    cr = limiter(_calc_cr(rjp, rj, rjm, vel[s]))
    return velfac * vel[s] * (var[sp1] + var[s]) * 0.5 - np.abs(velfac * vel[s]) * ((1. - cr) + uCFL * cr) * rj * 0.5
示例7: _normalize_by_window_size
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def _normalize_by_window_size(dims, spatial_strides, padding):  # pylint: disable=invalid-name
    """Return ``rescale(outputs, inputs)``, which divides pooled sums by window sizes.

    The window size at each output position is the number of non-padding
    input elements that the (``dims``, ``spatial_strides``, ``padding``)
    window covers there. ``inputs.shape[1:-1]`` is treated as the spatial
    shape, so the first and last axes of ``inputs`` are assumed to be
    non-spatial (e.g. batch and channel — confirm layout).
    """
    def rescale(outputs, inputs):
        spatial_shape = inputs.shape[1:-1]
        ones = jnp.ones(spatial_shape, dtype=inputs.dtype)
        # Summing a field of ones (padding contributes the init value 0)
        # counts how many real elements each window covered.
        counts = lax.reduce_window(ones, 0., lax.add, dims, spatial_strides, padding)
        # Trailing unit axis so the counts broadcast over outputs' last axis.
        return outputs / counts[..., jnp.newaxis]
    return rescale
示例8: _normalize_by_window_size
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def _normalize_by_window_size(dims, strides, padding):
    """Make a callback that turns windowed sums into windowed means.

    The returned ``rescale(outputs, inputs)`` divides ``outputs`` by the count
    of non-padding elements under each (``dims``, ``strides``, ``padding``)
    window. The count is taken over the spatial shape ``inputs.shape[1:-1]``.
    """
    def rescale(outputs, inputs):
        # Pool a field of ones with the same window to get per-position counts.
        unit_field = np.ones(inputs.shape[1:-1], dtype=inputs.dtype)
        window_counts = lax.reduce_window(
            unit_field, 0., lax.add, dims, strides, padding)
        # Broadcast the counts over the trailing axis of outputs.
        denominator = window_counts[..., np.newaxis]
        return outputs / denominator
    return rescale
示例9: __call__
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def __call__(self, x):
    """Map ``x`` to ``loc + scale_tril @ x``, acting on the last axis of ``x``."""
    # Treat each trailing vector of x as a column for the matrix product.
    as_columns = x[..., jnp.newaxis]
    transformed = jnp.matmul(self.scale_tril, as_columns)
    return self.loc + jnp.squeeze(transformed, axis=-1)
示例10: sum_rightmost
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def sum_rightmost(x, dim):
    """
    Sum out ``dim`` many rightmost dimensions of a given tensor.

    ``dim == 0`` returns the values of ``x`` unchanged: a trailing unit axis
    is appended before flattening, so there is always one axis to sum over.
    """
    kept_shape = jnp.shape(x)[:jnp.ndim(x) - dim]
    # Append a unit axis so that dim == 0 (and scalar x) still has an axis
    # to collapse after the reshape.
    flattened = jnp.reshape(x[..., jnp.newaxis], kept_shape + (-1,))
    return jnp.sum(flattened, axis=-1)
示例11: sample
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def sample(self, key, sample_shape=()):
    """Draw ``loc + scale_tril @ eps`` with standard-normal ``eps`` from ``key``.

    ``eps`` has shape ``sample_shape + batch_shape + event_shape``.
    """
    noise_shape = sample_shape + self.batch_shape + self.event_shape
    eps = random.normal(key, shape=noise_shape)
    # Column-vector view of eps so scale_tril acts along the event axis.
    mixed = jnp.matmul(self.scale_tril, eps[..., jnp.newaxis])
    return self.loc + jnp.squeeze(mixed, axis=-1)
示例12: covariance_matrix
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def covariance_matrix(self):
    """Return ``cov_factor @ cov_factor.T + diag(cov_diag)``."""
    event_size = self.loc.shape[-1]
    # TODO: find a better solution to create a diagonal matrix
    # Broadcasting the diagonal entries against the identity embeds them on
    # the diagonal, including for batched cov_diag.
    diagonal_part = self.cov_diag[..., jnp.newaxis] * jnp.identity(event_size)
    factor_t = jnp.swapaxes(self.cov_factor, -1, -2)
    return diagonal_part + jnp.matmul(self.cov_factor, factor_t)
示例13: precision_matrix
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def precision_matrix(self):
    """Return ``inv(W @ W.T + D)`` using the Woodbury matrix identity.

    Here ``W = self.cov_factor`` and ``D = diag(self.cov_diag)``::

        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)

    where C is the capacitance matrix. This assumes
    ``self._capacitance_tril`` is the lower Cholesky factor L of
    ``C = I + W.T @ inv(D) @ W`` (computed elsewhere — confirm). With
    ``A = inv(L) @ W.T @ inv(D)``, the correction term is ``A.T @ A``.
    """
    # W.T @ inv(D): divide column j of W.T by cov_diag[j].
    Wt_Dinv = (jnp.swapaxes(self.cov_factor, -1, -2)
               / jnp.expand_dims(self.cov_diag, axis=-2))
    # Solve L @ A = W.T @ inv(D). solve_triangular(a, b) solves a @ x = b
    # (SciPy/JAX convention, matching the `lower=` keyword), so the triangular
    # factor must come first; passing Wt_Dinv first solves the wrong system
    # and fails outright unless W is square.
    A = solve_triangular(self._capacitance_tril, Wt_Dinv, lower=True)
    # TODO: find a better solution to create a diagonal matrix
    inverse_cov_diag = jnp.reciprocal(self.cov_diag)
    diag_embed = inverse_cov_diag[..., jnp.newaxis] * jnp.identity(self.loc.shape[-1])
    return diag_embed - jnp.matmul(jnp.swapaxes(A, -1, -2), A)