本文整理汇总了Python中jax.numpy.abs方法的典型用法代码示例。如果您正苦于以下问题:Python numpy.abs方法的具体用法?Python numpy.abs怎么用?Python numpy.abs使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类jax.numpy
的用法示例。
在下文中一共展示了numpy.abs方法的11个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _adv_superbee
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import abs [as 别名]
def _adv_superbee(vel, var, mask, dx, axis, cost, cosu, dt_tracer):
velfac = 1
if axis == 0:
sm1, s, sp1, sp2 = ((slice(1 + n, -2 + n or None), slice(2, -2), slice(None))
for n in range(-1, 3))
dx = cost[np.newaxis, 2:-2, np.newaxis] * \
dx[1:-2, np.newaxis, np.newaxis]
elif axis == 1:
sm1, s, sp1, sp2 = ((slice(2, -2), slice(1 + n, -2 + n or None), slice(None))
for n in range(-1, 3))
dx = (cost * dx)[np.newaxis, 1:-2, np.newaxis]
velfac = cosu[np.newaxis, 1:-2, np.newaxis]
elif axis == 2:
vel, var, mask = (pad_z_edges(a) for a in (vel, var, mask))
sm1, s, sp1, sp2 = ((slice(2, -2), slice(2, -2), slice(1 + n, -2 + n or None))
for n in range(-1, 3))
dx = dx[np.newaxis, np.newaxis, :-1]
else:
raise ValueError('axis must be 0, 1, or 2')
uCFL = np.abs(velfac * vel[s] * dt_tracer / dx)
rjp = (var[sp2] - var[sp1]) * mask[sp1]
rj = (var[sp1] - var[s]) * mask[s]
rjm = (var[s] - var[sm1]) * mask[sm1]
cr = limiter(_calc_cr(rjp, rj, rjm, vel[s]))
return velfac * vel[s] * (var[sp1] + var[s]) * 0.5 - np.abs(velfac * vel[s]) * ((1. - cr) + uCFL * cr) * rj * 0.5
示例2: abs
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import abs [as 别名]
def abs(self, x):
return np.abs(x)
示例3: abs
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import abs [as 别名]
def abs(self, tensor):
return np.abs(tensor)
示例4: main
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import abs [as 别名]
def main(args):
# Generate some data.
data = random.normal(PRNGKey(0), shape=(100,)) + 3.0
# Construct an SVI object so we can do variational inference on our
# model/guide pair.
adam = optim.Adam(args.learning_rate)
svi = SVI(model, guide, adam, ELBO(num_particles=100))
svi_state = svi.init(PRNGKey(0), data)
# Training loop
def body_fn(i, val):
svi_state, loss = svi.update(val, data)
return svi_state
svi_state = fori_loop(0, args.num_steps, body_fn, svi_state)
# Report the final values of the variational parameters
# in the guide after training.
params = svi.get_params(svi_state)
for name, value in params.items():
print("{} = {}".format(name, value))
# For this simple (conjugate) model we know the exact posterior. In
# particular we know that the variational distribution should be
# centered near 3.0. So let's check this explicitly.
assert jnp.abs(params["guide_loc"] - 3.0) < 0.1
示例5: __call__
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import abs [as 别名]
def __call__(self, x):
return jnp.abs(x)
示例6: log_abs_det_jacobian
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import abs [as 别名]
def log_abs_det_jacobian(self, x, y, intermediates=None):
return sum_rightmost(jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), jnp.shape(x)), self.event_dim)
示例7: binary_cross_entropy_with_logits
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import abs [as 别名]
def binary_cross_entropy_with_logits(x, y):
# compute -y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x))
# Ref: https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
return jnp.clip(x, 0) + jnp.log1p(jnp.exp(-jnp.abs(x))) - x * y
示例8: sample
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import abs [as 别名]
def sample(self, key, sample_shape=()):
return jnp.abs(self._cauchy.sample(key, sample_shape))
示例9: log_prob
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import abs [as 别名]
def log_prob(self, value):
normalize_term = jnp.log(2 * self.scale)
value_scaled = jnp.abs(value - self.loc) / self.scale
return -value_scaled - normalize_term
示例10: log_prob
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import abs [as 别名]
def log_prob(self, value):
log_factorial_n = gammaln(self.total_count + 1)
log_factorial_k = gammaln(value + 1)
log_factorial_nmk = gammaln(self.total_count - value + 1)
normalize_term = (self.total_count * jnp.clip(self.logits, 0) +
xlog1py(self.total_count, jnp.exp(-jnp.abs(self.logits))) -
log_factorial_n)
return value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
示例11: _calc_cr
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import abs [as 别名]
def _calc_cr(rjp, rj, rjm, vel):
"""
Calculates cr value used in superbee advection scheme
"""
eps = 1e-20 # prevent division by 0
return where(vel > 0., rjm, rjp) / where(np.abs(rj) < eps, eps, rj)