本文整理汇总了Python中jax.numpy.log方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.log方法的具体用法？Python jax.numpy.log怎么用？Python jax.numpy.log使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。
在下文中一共展示了jax.numpy.log方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: astensor
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def astensor(self, tensor_in, dtype='float'):
    """
    Convert to a JAX ndarray.

    Scalar (0-d) inputs are promoted to shape ``(1,)`` so that every returned
    tensor has at least one dimension.

    Args:
        tensor_in (Number or Tensor): Tensor object
        dtype (str): Key into ``self.dtypemap`` selecting the element type
            (the logged message names ``float``, ``int`` and ``bool``).

    Returns:
        JAX array (``jax.Array``; older JAX releases called this
        ``jax.interpreters.xla.DeviceArray``): A multi-dimensional,
        fixed-size homogenous array.

    Raises:
        KeyError: If ``dtype`` is not a key of ``self.dtypemap``.
    """
    try:
        dtype = self.dtypemap[dtype]
    except KeyError:
        # `log` is presumably a module-level logger defined outside this block.
        log.error('Invalid dtype: dtype must be float, int, or bool.')
        raise
    tensor = np.asarray(tensor_in, dtype=dtype)
    # Ensure non-empty tensor shape for consistency: test the rank directly
    # rather than probing shape[0] and catching IndexError.
    if tensor.ndim == 0:
        tensor = np.reshape(tensor, (1,))
    # reshape keeps the dtype, so no second asarray conversion is needed.
    return tensor
示例2: model
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def model(N, y=None):
    """
    Probabilistic model for two interacting populations observed at N times.

    The latent trajectory comes from integrating ``dz_dt`` with ``odeint``
    (presumably predator/prey dynamics, given the hare/lynx naming below --
    confirm against ``dz_dt``). Observations are Normal around the log of that
    trajectory.

    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations (on log scale) with shape
        (N, 2); with ``obs=None`` numpyro samples the "y" site instead of
        conditioning on it.
    """
    # Prior on both initial population sizes: log-normal with median 10.
    init_pop = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1), sample_shape=(2,))
    # Measurement times 0, 1, ..., N-1.
    times = jnp.arange(float(N))
    # Non-negative (truncated at 0) priors on the alpha, beta, gamma, delta of dz_dt.
    rate_loc = jnp.array([0.5, 0.05, 1.5, 0.05])
    rate_scale = jnp.array([0.5, 0.05, 0.5, 0.05])
    rates = numpyro.sample(
        "theta",
        dist.TruncatedNormal(low=0., loc=rate_loc, scale=rate_scale),
    )
    # Integrated trajectory, expected shape N x 2.
    trajectory = odeint(dz_dt, init_pop, times, rates, rtol=1e-5, atol=1e-3, mxstep=500)
    # Observation noise: we expect measured hare to have larger error than
    # measured lynx, hence the smaller rate (larger mean) for the first column.
    noise = numpyro.sample("sigma", dist.Exponential(jnp.array([1, 2])))
    # Likelihood of the measured populations, in log scale.
    log_trajectory = jnp.log(trajectory)
    numpyro.sample("y", dist.Normal(log_trajectory, noise), obs=y)
示例3: Tanh
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def Tanh():
    """
    Tanh nonlinearity that also accumulates its log-Jacobian.

    :return: an (`init_fun`, `apply_fun`) pair following the stax layer convention.
    """
    def init_fun(rng, input_shape):
        # Parameter-free layer: the output shape equals the input shape.
        return input_shape, ()

    def apply_fun(params, inputs, **kwargs):
        x, logdet = inputs
        # Numerically stable log(1 - tanh(x)^2):
        #   1 - tanh(x)^2 = 4 / (e^x + e^-x)^2  =>  -2 * (x + softplus(-2x) - log 2)
        log_jac = -2 * (x + softplus(-2 * x) - jnp.log(2.))
        # logdet has shape batch_shape + (num_blocks, in_factor, out_factor), while
        # log_jac has shape batch_shape + (num_blocks x out_factor,); reshape it to
        # batch_shape + (num_blocks, 1, out_factor) so it broadcasts over in_factor.
        target_shape = logdet.shape[:-2] + (1, logdet.shape[-1])
        return jnp.tanh(x), logdet + log_jac.reshape(target_shape)

    return init_fun, apply_fun
示例4: FanInResidualNormal
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def FanInResidualNormal():
    """
    Residual fan-in ``f(x) + x`` that also tracks the log determinant of the Jacobian.

    Behaves like ``stax.FanInSum`` over two branches, but the second branch is
    required to be the identity. Each branch yields a ``(value, logdet)`` pair.
    Only the first branch's logdet is used, and it is passed through
    ``softplus``. Presumably that logdet holds elementwise log-diagonal Jacobian
    terms, so ``log(exp(l) + 1) = softplus(l)`` accounts for the added
    identity -- confirm with the producing layer.

    :return: an (`init_fun`, `apply_fun`) pair following the stax layer convention.
    """
    def init_fun(rng, input_shape):
        # Output shape is the first branch's shape; the layer has no parameters.
        first_shape = input_shape[0]
        return first_shape, ()

    def apply_fun(params, inputs, **kwargs):
        (fx, fx_logdet), (x, _identity_logdet) = inputs
        residual = fx + x  # f(x) + x
        return residual, softplus(fx_logdet)

    return init_fun, apply_fun
示例5: _get_tr_params
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def _get_tr_params(n, p):
    """Precompute the (n, p)-dependent constants used by a transformed-rejection sampler.

    NOTE(review): "Table 1" and "page 5" below refer to the paper this sampler
    follows (presumably Hormann 1993, "The generation of binomial random
    variates", algorithm BTRS) -- confirm.

    Args:
        n: number of trials. Must expose a ``dtype`` attribute (e.g. a JAX
            array), because ``m`` is cast to ``n.dtype``.
        p: success probability. ``log(p)`` and ``log1p(-p)`` are taken, so
            presumably 0 < p < 1 -- confirm with the caller.

    Returns:
        ``_tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)``
        (``_tr_params`` is defined outside this block).
    """
    # See Table 1. Additionally, we pre-compute log(p), log1p(-p) and the
    # constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
    mu = n * p  # n * p, the binomial mean
    spq = jnp.sqrt(mu * (1 - p))  # sqrt(n p (1 - p)), the binomial standard deviation
    # Table 1 constants of the sampler, derived from mu / spq / p.
    c = mu + 0.5
    b = 1.15 + 2.53 * spq
    a = -0.0873 + 0.0248 * b + 0.01 * p
    alpha = (2.83 + 5.1 / b) * spq
    u_r = 0.43
    v_r = 0.92 - 4.2 / b
    # floor((n + 1) p) is the mode of Binomial(n, p); cast back to n's dtype.
    m = jnp.floor((n + 1) * p).astype(n.dtype)
    log_p = jnp.log(p)
    log1_p = jnp.log1p(-p)  # log(1 - p), computed accurately for small p
    # Constant part of log f(k) that depends only on (n, p, m). stirling_approx_tail
    # is defined elsewhere (presumably the Stirling-series tail correction).
    log_h = (m + 0.5) * (jnp.log((m + 1.) / (n - m + 1.)) + log1_p - log_p) + \
        (stirling_approx_tail(m) + stirling_approx_tail(n - m))
    return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
示例6: _log
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def _log(x):
return np.log(x)
示例7: gaussian_kl
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def gaussian_kl(mu, sigmasq):
    """KL divergence from a diagonal Gaussian to the standard Gaussian.

    Computes KL(N(mu, diag(sigmasq)) || N(0, I)), summed over every element:
    ``0.5 * sum(sigmasq + mu**2 - 1 - log(sigmasq))``.
    """
    per_dim = 1. + np.log(sigmasq) - mu ** 2. - sigmasq
    total = np.sum(per_dim)
    return -0.5 * total
示例8: bernoulli_logpdf
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def bernoulli_logpdf(logits, x):
    """Bernoulli log pdf of data x given logits.

    Per element, the log-probability is ``-log(1 + exp(-logits))`` where ``x``
    is truthy and ``-log(1 + exp(logits))`` otherwise. The result is summed
    over all elements. ``logaddexp(0, z)`` evaluates ``log(1 + exp(z))``
    stably.
    """
    signs = np.where(x, -1., 1.)
    neg_log_probs = np.logaddexp(0., signs * logits)
    return -np.sum(neg_log_probs)
示例9: test_ocr_rnn
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def test_ocr_rnn():
    """Smoke-test a three-layer GRU network: initialisation, parameter shapes, and RMSProp updates."""
    seq_len = 5
    hidden = 3
    num_classes = 4
    batch = jnp.zeros((1, seq_len, 4))

    def gru_layer():
        return Rnn(*GRUCell(hidden, zeros))

    net = Sequential(
        gru_layer(),
        gru_layer(),
        gru_layer(),
        lambda h: jnp.reshape(h, (-1, hidden)),  # -> same weights for all time steps
        Dense(num_classes, zeros, zeros),
        softmax,
        lambda probs: jnp.reshape(probs, (-1, seq_len, num_classes)))

    params = net.init_parameters(batch, key=PRNGKey(0))
    assert len(params) == 4
    cell = params.rnn0.gru_cell
    assert len(cell) == 3
    # All three GRU kernels are (7, 3) and zero-initialised; 7 is presumably
    # input features (4) + carry size (3).
    zero_kernel = jnp.zeros((7, 3))
    for kernel in (cell.update_kernel, cell.reset_kernel, cell.compute_kernel):
        assert jnp.array_equal(zero_kernel, kernel)

    pseudo_targets = net.apply(params, batch)

    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        return jnp.mean(-jnp.sum(targets * jnp.log(prediction), (1, 2)))

    opt = optimizers.RmsProp(0.003)
    state = opt.init(cross_entropy.init_parameters(batch, pseudo_targets, key=PRNGKey(0)))
    state = opt.update(cross_entropy.apply, state, batch, pseudo_targets)
    # Exercise the jitted update path as well; its result is not inspected.
    opt.update(cross_entropy.apply, state, batch, pseudo_targets, jit=True)
示例10: log
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def log(self, x):
    """Elementwise natural logarithm of ``x``; ``self`` is not used."""
    result = np.log(x)
    return result
示例11: categorical_crossentropy
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def categorical_crossentropy(self, output, target, from_logits=False):
    """Cross-entropy between predictions and targets, reduced over the last (class) axis.

    Args:
        output: predicted class probabilities, or unnormalised logits when
            ``from_logits`` is True; classes lie along the last axis.
        target: target distribution (e.g. one-hot), broadcastable to ``output``.
        from_logits: if True, ``output`` is normalised with a numerically
            stable log-softmax over the last axis before use.

    Returns:
        ``-mean(log_probs * target, axis=-1)``.
    """
    # NOTE(review): the class-axis reduction is a *mean*, i.e. the usual
    # cross-entropy (a sum) divided by the number of classes. Kept unchanged
    # because callers may rely on this scale -- confirm intent.
    if from_logits:
        # Stable log-softmax: subtract the per-row max before exponentiating
        # so large logits cannot overflow exp().
        shifted = output - np.max(output, axis=-1, keepdims=True)
        log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
    else:
        log_probs = np.log(output)
    return -np.mean(log_probs * target, axis=-1)
示例12: logdet
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def logdet(self, A, **kwargs):
    """Log-determinant of a symmetric positive-definite matrix, or a batch of them.

    ``A`` is first symmetrised as ``(A + A^T) / 2``. Then
    ``log det(A) = 2 * sum(log(diag(L)))`` with ``L = self.cholesky(A)``.

    Args:
        A: array of shape ``(..., n, n)`` (batching assumes ``self.cholesky`` and
            ``self.matrix_transpose`` act on the last two axes -- confirm).
        **kwargs: forwarded to ``self.cholesky``.

    Returns:
        Array of shape ``A.shape[:-2]``.
    """
    A = (A + self.matrix_transpose(A)) / 2.
    chol = self.cholesky(A, **kwargs)
    # np.diag only accepts 1-D/2-D input. np.diagonal over the last two axes
    # matches it for a single matrix and also works for batched input.
    diag = np.diagonal(chol, axis1=-2, axis2=-1)
    return 2 * np.sum(np.log(diag), axis=-1)
示例13: multigammaln
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def multigammaln(self, a, p):
    """Log of the multivariate gamma function, log Gamma_p(a).

    log Gamma_p(a) = p (p - 1) / 4 * log(pi) + sum_{i=1..p} gammaln(a - (i - 1) / 2),
    evaluated elementwise over ``a``. ``p`` is cast to int32 to size
    ``self.range``, so it presumably has to be a concrete (non-traced) value.
    """
    dim = self.to_float(p)
    offsets = self.to_float(self.range(1, self.cast(dim, 'int32') + 1))
    log_pi_term = dim * (dim - 1) / 4. * self.log(np.pi)
    # Trailing axis added to `a` broadcasts against the p offsets, then summed away.
    gamma_terms = self.gammaln(a[..., None] - (offsets - 1) / 2.)
    return log_pi_term + self.sum(gamma_terms, axis=-1)
示例14: log
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def log(self, tensor_in):
    """
    Element-wise natural logarithm.

    Args:
        tensor_in (Tensor): input tensor (``self`` is not used)

    Returns:
        Tensor of the same shape holding ``log(tensor_in)``.
    """
    elementwise_log = np.log
    return elementwise_log(tensor_in)
示例15: poisson_logpdf
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def poisson_logpdf(self, n, lam):
    r"""
    Log of the Poisson probability of observing ``n`` given rate ``lam``.

    .. math::

        \log P(n|\lambda) = n \log\lambda - \lambda - \log\Gamma(n + 1)

    Args:
        n (Number or Tensor): observed counts
        lam (Number or Tensor): expected rate

    Returns:
        Tensor: the log-probability, broadcast over ``n`` and ``lam``.
    """
    # Local import keeps this block self-contained; jax is already a dependency.
    from jax.scipy.special import xlogy

    n = np.asarray(n)
    lam = np.asarray(lam)
    # xlogy(n, lam) equals n * log(lam) but is defined as 0 when n == 0, so
    # n = lam = 0 yields log-probability 0 instead of 0 * -inf = nan.
    return xlogy(n, lam) - lam - gammaln(n + 1.0)