当前位置: 首页>>代码示例>>Python>>正文


Python numpy.log方法代码示例

本文整理汇总了Python中jax.numpy.log方法的典型用法代码示例。如果您正苦于以下问题:Python numpy.log方法的具体用法?Python numpy.log怎么用?Python numpy.log使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.numpy的用法示例。


在下文中一共展示了numpy.log方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: astensor

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def astensor(self, tensor_in, dtype='float'):
        """
        Convert to a JAX ndarray.

        Args:
            tensor_in (Number or Tensor): Tensor object

        Returns:
            `jax.interpreters.xla.DeviceArray`: A multi-dimensional, fixed-size homogenous array.
        """
        try:
            dtype = self.dtypemap[dtype]
        except KeyError:
            log.error('Invalid dtype: dtype must be float, int, or bool.')
            raise
        tensor = np.asarray(tensor_in, dtype=dtype)
        # Ensure non-empty tensor shape for consistency
        try:
            tensor.shape[0]
        except IndexError:
            tensor = np.reshape(tensor, [1])
        return np.asarray(tensor, dtype=dtype) 
开发者ID:scikit-hep,项目名称:pyhf,代码行数:24,代码来源:jax_backend.py

示例2: model

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def model(N, y=None):
    """
    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations with shape (N, 2)
    """
    # initial population
    z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1), sample_shape=(2,))
    # measurement times
    ts = jnp.arange(float(N))
    # parameters alpha, beta, gamma, delta of dz_dt
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(low=0., loc=jnp.array([0.5, 0.05, 1.5, 0.05]),
                             scale=jnp.array([0.5, 0.05, 0.5, 0.05])))
    # integrate dz/dt, the result will have shape N x 2
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)
    # measurement errors, we expect that measured hare has larger error than measured lynx
    sigma = numpyro.sample("sigma", dist.Exponential(jnp.array([1, 2])))
    # measured populations (in log scale)
    numpyro.sample("y", dist.Normal(jnp.log(z), sigma), obs=y) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:22,代码来源:ode.py

示例3: Tanh

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def Tanh():
    """
    Tanh nonlinearity with its log jacobian.

    :return: an (`init_fn`, `update_fn`) pair.
    """
    def init_fun(rng, input_shape):
        return input_shape, ()

    def apply_fun(params, inputs, **kwargs):
        x, logdet = inputs
        out = jnp.tanh(x)
        tanh_logdet = -2 * (x + softplus(-2 * x) - jnp.log(2.))
        # logdet.shape = batch_shape + (num_blocks, in_factor, out_factor)
        # tanh_logdet.shape = batch_shape + (num_blocks x out_factor,)
        # so we need to reshape tanh_logdet to: batch_shape + (num_blocks, 1, out_factor)
        tanh_logdet = tanh_logdet.reshape(logdet.shape[:-2] + (1, logdet.shape[-1]))
        return out, logdet + tanh_logdet

    return init_fun, apply_fun 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:22,代码来源:block_neural_arn.py

示例4: FanInResidualNormal

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def FanInResidualNormal():
    """
    Similar to stax.FanInSum but also keeps track of log determinant of Jacobian.
    It is required that the second fan-in branch is identity.

    :return: an (`init_fn`, `update_fn`) pair.
    """
    def init_fun(rng, input_shape):
        return input_shape[0], ()

    def apply_fun(params, inputs, **kwargs):
        # f(x) + x
        (fx, logdet), (x, _) = inputs
        return fx + x, softplus(logdet)

    return init_fun, apply_fun 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:18,代码来源:block_neural_arn.py

示例5: _get_tr_params

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def _get_tr_params(n, p):
    # See Table 1. Additionally, we pre-compute log(p), log1(-p) and the
    # constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
    mu = n * p
    spq = jnp.sqrt(mu * (1 - p))
    c = mu + 0.5
    b = 1.15 + 2.53 * spq
    a = -0.0873 + 0.0248 * b + 0.01 * p
    alpha = (2.83 + 5.1 / b) * spq
    u_r = 0.43
    v_r = 0.92 - 4.2 / b
    m = jnp.floor((n + 1) * p).astype(n.dtype)
    log_p = jnp.log(p)
    log1_p = jnp.log1p(-p)
    log_h = (m + 0.5) * (jnp.log((m + 1.) / (n - m + 1.)) + log1_p - log_p) + \
            (stirling_approx_tail(m) + stirling_approx_tail(n - m))
    return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:19,代码来源:util.py

示例6: _log

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def _log(x):
    return np.log(x) 
开发者ID:pyro-ppl,项目名称:funsor,代码行数:4,代码来源:ops.py

示例7: gaussian_kl

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def gaussian_kl(mu, sigmasq):
    """KL divergence from a diagonal Gaussian to the standard Gaussian."""
    return -0.5 * np.sum(1. + np.log(sigmasq) - mu ** 2. - sigmasq) 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:5,代码来源:mnist_vae.py

示例8: bernoulli_logpdf

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def bernoulli_logpdf(logits, x):
    """Bernoulli log pdf of data x given logits."""
    return -np.sum(np.logaddexp(0., np.where(x, -1., 1.) * logits)) 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:5,代码来源:mnist_vae.py

示例9: test_ocr_rnn

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def test_ocr_rnn():
    length = 5
    carry_size = 3
    class_count = 4
    inputs = jnp.zeros((1, length, 4))

    def rnn(): return Rnn(*GRUCell(carry_size, zeros))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: jnp.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(class_count, zeros, zeros),
        softmax,
        lambda x: jnp.reshape(x, (-1, length, class_count)))

    params = net.init_parameters(inputs, key=PRNGKey(0))

    assert len(params) == 4
    cell = params.rnn0.gru_cell
    assert len(cell) == 3
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.update_kernel)
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.reset_kernel)
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.compute_kernel)

    out = net.apply(params, inputs)

    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        return jnp.mean(-jnp.sum(targets * jnp.log(prediction), (1, 2)))

    opt = optimizers.RmsProp(0.003)
    state = opt.init(cross_entropy.init_parameters(inputs, out, key=PRNGKey(0)))
    state = opt.update(cross_entropy.apply, state, inputs, out)
    opt.update(cross_entropy.apply, state, inputs, out, jit=True) 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:39,代码来源:test_examples.py

示例10: log

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def log(self, x):
        return np.log(x) 
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例11: categorical_crossentropy

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def categorical_crossentropy(self, output, target, from_logits=False):
        if from_logits:
            raise NotImplementedError
        return -np.mean(np.log(output) * target, axis=-1) 
开发者ID:sharadmv,项目名称:deepx,代码行数:6,代码来源:jax.py

示例12: logdet

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def logdet(self, A, **kwargs):
        A = (A + self.matrix_transpose(A)) / 2.
        term = np.log(np.diag(self.cholesky(A, **kwargs)))
        return 2 * np.sum(term, axis=-1) 
开发者ID:sharadmv,项目名称:deepx,代码行数:6,代码来源:jax.py

示例13: multigammaln

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def multigammaln(self, a, p):
        p = self.to_float(p)
        p_ = self.cast(p, 'int32')
        a = a[..., None]
        i = self.to_float(self.range(1, p_ + 1))
        term1 = p * (p - 1) / 4. * self.log(np.pi)
        term2 = self.gammaln(a - (i - 1) / 2.)
        return term1 + self.sum(term2, axis=-1) 
开发者ID:sharadmv,项目名称:deepx,代码行数:10,代码来源:jax.py

示例14: log

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def log(self, tensor_in):
        return np.log(tensor_in) 
开发者ID:scikit-hep,项目名称:pyhf,代码行数:4,代码来源:jax_backend.py

示例15: poisson_logpdf

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import log [as 别名]
def poisson_logpdf(self, n, lam):
        n = np.asarray(n)
        lam = np.asarray(lam)
        return n * np.log(lam) - lam - gammaln(n + 1.0) 
开发者ID:scikit-hep,项目名称:pyhf,代码行数:6,代码来源:jax_backend.py


注:本文中的jax.numpy.log方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。