当前位置: 首页>>代码示例>>Python>>正文


Python numpy.newaxis方法代码示例

本文整理汇总了Python中jax.numpy.newaxis方法的典型用法代码示例。如果您正苦于以下问题:Python numpy.newaxis方法的具体用法?Python numpy.newaxis怎么用?Python numpy.newaxis使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.numpy的用法示例。


在下文中一共展示了numpy.newaxis方法的13个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: approximate_kl

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def approximate_kl(log_prob_new, log_prob_old, mask):
  """Computes the approximate KL divergence between the old and new log-probs.

  Args:
    log_prob_new: (B, T+1, A) log probs new
    log_prob_old: (B, T+1, A) log probs old
    mask: (B, T)

  Returns:
    Approximate KL.
  """
  diff = log_prob_old - log_prob_new
  # Cut the last time-step out.
  diff = diff[:, :-1]
  # Mask out the irrelevant part.
  diff *= mask[:, :, np.newaxis]  # make mask (B, T, 1)
  # Average on non-masked part.
  return np.sum(diff) / np.sum(mask) 
开发者ID:yyht,项目名称:BERT,代码行数:20,代码来源:ppo.py

示例2: masked_entropy

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def masked_entropy(log_probs, mask):
  """Computes the entropy for the given log-probs.

  Args:
    log_probs: (B, T+1, A) log probs
    mask: (B, T) mask.

  Returns:
    Entropy.
  """
  # Cut the last time-step out.
  lp = log_probs[:, :-1]
  # Mask out the irrelevant part.
  lp *= mask[:, :, np.newaxis]  # make mask (B, T, 1)
  p = np.exp(lp) * mask[:, :, np.newaxis]  # (B, T, 1)
  # Average on non-masked part and take negative.
  return -(np.sum(lp * p) / np.sum(mask)) 
开发者ID:yyht,项目名称:BERT,代码行数:19,代码来源:ppo.py

示例3: glmm

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def glmm(dept, male, applications, admit=None):
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))

    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    scale_tril = sigma[..., jnp.newaxis] * L_Rho
    # non-centered parameterization
    num_dept = len(jnp.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    v = jnp.dot(scale_tril, z.T).T

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    if admit is None:
        # we use a Delta site to record probs for predictive distribution
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:19,代码来源:ucbadmit.py

示例4: __init__

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_tril=None,
                 validate_args=None):
        if jnp.isscalar(loc):
            loc = jnp.expand_dims(loc, axis=-1)
        # temporary append a new axis to loc
        loc = loc[..., jnp.newaxis]
        if covariance_matrix is not None:
            loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
            self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
        elif precision_matrix is not None:
            loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
            self.scale_tril = cholesky_of_inverse(self.precision_matrix)
        elif scale_tril is not None:
            loc, self.scale_tril = promote_shapes(loc, scale_tril)
        else:
            raise ValueError('One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
                             ' must be specified.')
        batch_shape = lax.broadcast_shapes(jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2])
        event_shape = jnp.shape(self.scale_tril)[-1:]
        self.loc = jnp.broadcast_to(jnp.squeeze(loc, axis=-1), batch_shape + event_shape)
        super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
                                                 event_shape=event_shape,
                                                 validate_args=validate_args) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:25,代码来源:continuous.py

示例5: solve_implicit

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None):
    land_mask = (ks >= 0)[:, :, np.newaxis]
    edge_mask = land_mask & (np.arange(a.shape[2])[np.newaxis, np.newaxis, :]
                             == ks[:, :, np.newaxis])
    water_mask = land_mask & (np.arange(a.shape[2])[np.newaxis, np.newaxis, :]
                              >= ks[:, :, np.newaxis])

    a_tri = water_mask * a * np.logical_not(edge_mask)
    b_tri = where(water_mask, b, 1.)
    if b_edge is not None:
        b_tri = where(edge_mask, b_edge, b_tri)
    c_tri = water_mask * c
    d_tri = water_mask * d
    if d_edge is not None:
        d_tri = where(edge_mask, d_edge, d_tri)

    return solve_tridiag(a_tri, b_tri, c_tri, d_tri), water_mask 
开发者ID:dionhaefner,项目名称:pyhpc-benchmarks,代码行数:19,代码来源:tke_jax.py

示例6: _adv_superbee

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def _adv_superbee(vel, var, mask, dx, axis, cost, cosu, dt_tracer):
    velfac = 1
    if axis == 0:
        sm1, s, sp1, sp2 = ((slice(1 + n, -2 + n or None), slice(2, -2), slice(None))
                            for n in range(-1, 3))
        dx = cost[np.newaxis, 2:-2, np.newaxis] * \
            dx[1:-2, np.newaxis, np.newaxis]
    elif axis == 1:
        sm1, s, sp1, sp2 = ((slice(2, -2), slice(1 + n, -2 + n or None), slice(None))
                            for n in range(-1, 3))
        dx = (cost * dx)[np.newaxis, 1:-2, np.newaxis]
        velfac = cosu[np.newaxis, 1:-2, np.newaxis]
    elif axis == 2:
        vel, var, mask = (pad_z_edges(a) for a in (vel, var, mask))
        sm1, s, sp1, sp2 = ((slice(2, -2), slice(2, -2), slice(1 + n, -2 + n or None))
                            for n in range(-1, 3))
        dx = dx[np.newaxis, np.newaxis, :-1]
    else:
        raise ValueError('axis must be 0, 1, or 2')
    uCFL = np.abs(velfac * vel[s] * dt_tracer / dx)
    rjp = (var[sp2] - var[sp1]) * mask[sp1]
    rj = (var[sp1] - var[s]) * mask[s]
    rjm = (var[s] - var[sm1]) * mask[sm1]
    cr = limiter(_calc_cr(rjp, rj, rjm, vel[s]))
    return velfac * vel[s] * (var[sp1] + var[s]) * 0.5 - np.abs(velfac * vel[s]) * ((1. - cr) + uCFL * cr) * rj * 0.5 
开发者ID:dionhaefner,项目名称:pyhpc-benchmarks,代码行数:27,代码来源:tke_jax.py

示例7: _normalize_by_window_size

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def _normalize_by_window_size(dims, spatial_strides, padding):  # pylint: disable=invalid-name
  def rescale(outputs, inputs):
    one = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype)
    window_sizes = lax.reduce_window(
        one, 0., lax.add, dims, spatial_strides, padding)
    return outputs / window_sizes[..., jnp.newaxis]
  return rescale 
开发者ID:yyht,项目名称:BERT,代码行数:9,代码来源:backend.py

示例8: _normalize_by_window_size

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def _normalize_by_window_size(dims, strides, padding):
    def rescale(outputs, inputs):
        one = np.ones(inputs.shape[1:-1], dtype=inputs.dtype)
        window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
        return outputs / window_sizes[..., np.newaxis]

    return rescale 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:9,代码来源:modules.py

示例9: __call__

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def __call__(self, x):
        return self.loc + jnp.squeeze(jnp.matmul(self.scale_tril, x[..., jnp.newaxis]), axis=-1) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:4,代码来源:transforms.py

示例10: sum_rightmost

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def sum_rightmost(x, dim):
    """
    Sum out ``dim`` many rightmost dimensions of a given tensor.
    """
    out_dim = jnp.ndim(x) - dim
    x = jnp.reshape(x[..., jnp.newaxis], jnp.shape(x)[:out_dim] + (-1,))
    return jnp.sum(x, axis=-1) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:9,代码来源:util.py

示例11: sample

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def sample(self, key, sample_shape=()):
        eps = random.normal(key, shape=sample_shape + self.batch_shape + self.event_shape)
        return self.loc + jnp.squeeze(jnp.matmul(self.scale_tril, eps[..., jnp.newaxis]), axis=-1) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:5,代码来源:continuous.py

示例12: covariance_matrix

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def covariance_matrix(self):
        # TODO: find a better solution to create a diagonal matrix
        new_diag = self.cov_diag[..., jnp.newaxis] * jnp.identity(self.loc.shape[-1])
        covariance_matrix = new_diag + jnp.matmul(
            self.cov_factor, jnp.swapaxes(self.cov_factor, -1, -2)
            )
        return covariance_matrix 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:9,代码来源:continuous.py

示例13: precision_matrix

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import newaxis [as 别名]
def precision_matrix(self):
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (jnp.swapaxes(self.cov_factor, -1, -2)
                   / jnp.expand_dims(self.cov_diag, axis=-2))
        A = solve_triangular(Wt_Dinv, self._capacitance_tril, lower=True)
        # TODO: find a better solution to create a diagonal matrix
        inverse_cov_diag = jnp.reciprocal(self.cov_diag)
        diag_embed = inverse_cov_diag[..., jnp.newaxis] * jnp.identity(self.loc.shape[-1])
        return diag_embed - jnp.matmul(jnp.swapaxes(A, -1, -2), A) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:13,代码来源:continuous.py


注:本文中的jax.numpy.newaxis方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。