当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.expand_dims方法代码示例

本文整理汇总了Python中jax.numpy.expand_dims方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.expand_dims方法的具体用法是什么？jax.numpy.expand_dims怎么用？有哪些使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的其他用法示例。


下文共展示了jax.numpy.expand_dims方法的15个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。

示例1: _multinomial

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _multinomial(key, p, n, n_max, shape=()):
    """Draw multinomial counts by sampling categorical indices and histogramming them.

    A fixed number ``n_max`` of categorical draws is taken for every batch element,
    so all shapes are static. When ``n`` is batched, the draws beyond each element's
    own ``n`` are masked out and the resulting over-count is subtracted again.

    Args:
        key: PRNG key forwarded to ``categorical``.
        p: event probabilities; the last axis indexes categories.
        n: number of trials; broadcast against ``p.shape[:-1]``.
        n_max: number of draws per batch element. It is used inside static shapes
            (``(n_max,) + shape``), so it must be a concrete int. It is assumed to be
            an upper bound on ``n``.
        shape: batch shape of the result; defaults to ``p.shape[:-1]`` when empty.

    Returns:
        Counts of shape ``shape + p.shape[-1:]``.
    """
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        # give n and the batch part of p one common batch shape
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result;
    # the n_max draws live on the leading axis (see the (n_max, -1) reshape below)
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        # True while draw number k < n; promote_shapes presumably aligns it to shape + (n_max,)
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        # move the draw axis to the front so it lines up with `indices`
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws become index 0, so category 0 is over-counted by exactly (n_max - n).
        # `excess` holds that correction in category 0 and zeros for the other categories.
        # NOTE(review): jnp.zeros uses the default float dtype here, so the returned counts
        # may be promoted to float in this branch -- confirm this is intended.
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        # NOTE(review): with a scalar n, all n_max draws are counted, so this path presumably
        # relies on callers passing n_max == n -- confirm against callers.
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
    # Per batch row, start from a zero count vector of length p.shape[-1] and scatter-add a one
    # for every drawn index (_scatter_add_one is defined elsewhere; assumed to do exactly this).
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
                                                             dtype=indices.dtype),
                                                   jnp.expand_dims(indices_2D, axis=-1),
                                                   jnp.ones(indices_2D.shape, dtype=indices.dtype))
    # restore the batch shape and remove the category-0 over-count from masked draws
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:25,代码来源:util.py

示例2: _onion

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _onion(self, key, size):
        """Sample Cholesky factors of correlation matrices with the onion method.

        Follows Algorithm 3.2 of reference [1]. The citation itself lives outside this
        snippet. For every row below the first, the off-diagonal entries are
        ``sqrt(y) * u``, where ``y`` is drawn from ``self._beta`` and ``u`` is a uniformly
        random direction. The diagonal entry is then chosen so that each row has unit
        Euclidean norm, which gives ``L @ L.T`` a unit diagonal.

        Args:
            key: PRNG key; split into one key for the beta draws and one for the
                normal draws.
            size: sample shape (a tuple), prepended to ``self.batch_shape``.

        Returns:
            Array of shape ``size + self.batch_shape + self.event_shape``. The event
            shape is presumably ``(self.dimension, self.dimension)``; confirm against
            the enclosing class.
        """
        key_beta, key_normal = random.split(key)
        # Now we generate w term in Algorithm 3.2 of [1].
        beta_sample = self._beta.sample(key_beta, size)
        # The following Normal distribution is used to create a uniform distribution on
        # a hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
        normal_sample = random.normal(
            key_normal,
            shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,)
        )
        # Lay the draws out as a lower-triangular (dimension-1) x (dimension-1) matrix, so
        # row i has i + 1 nonzero entries. Normalizing each row gives a uniform direction.
        normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
        u_hypershere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
        w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypershere

        # put w into the off-diagonal triangular part (rows 1.., columns ..-1)
        # NOTE(review): jax.ops.index_add / ops.index were removed in later JAX releases;
        # `.at[..., 1:, :-1].add(w)` is the replacement -- confirm against the pinned JAX version.
        cholesky = ops.index_add(jnp.zeros(size + self.batch_shape + self.event_shape),
                                 ops.index[..., 1:, :-1], w)
        # correct the diagonal so every row has unit norm: diag = sqrt(1 - sum of squared off-diagonals)
        # NB: we clip due to numerical precision
        diag = jnp.sqrt(jnp.clip(1 - jnp.sum(cholesky ** 2, axis=-1), a_min=0.))
        cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
        return cholesky
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:continuous.py

示例3: __init__

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_tril=None,
                 validate_args=None):
        """Multivariate normal parameterized by ``loc`` and one of three matrix forms.

        Only one matrix parameter is used. They are checked in the order
        ``covariance_matrix``, ``precision_matrix``, ``scale_tril``, and the first one
        that is not None wins. In every case ``self.scale_tril`` is set:
        - from a covariance matrix, as its Cholesky factor;
        - from a precision matrix, via ``cholesky_of_inverse`` (presumably the Cholesky
          factor of its inverse);
        - from ``scale_tril``, as given.

        Args:
            loc: mean. A scalar is first turned into a length-1 vector. The result is
                broadcast to ``batch_shape + event_shape``.
            covariance_matrix: covariance; the matrix axes are the last two axes.
            precision_matrix: inverse covariance; the matrix axes are the last two axes.
            scale_tril: lower Cholesky factor of the covariance.
            validate_args: forwarded to the base-class constructor.

        Raises:
            ValueError: if all three matrix parameters are None.
        """
        if jnp.isscalar(loc):
            loc = jnp.expand_dims(loc, axis=-1)
        # Temporarily append a trailing axis so loc has matrix rank for promote_shapes.
        # It is squeezed off again when self.loc is set below.
        loc = loc[..., jnp.newaxis]
        if covariance_matrix is not None:
            loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
            self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
        elif precision_matrix is not None:
            loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
            self.scale_tril = cholesky_of_inverse(self.precision_matrix)
        elif scale_tril is not None:
            loc, self.scale_tril = promote_shapes(loc, scale_tril)
        else:
            raise ValueError('One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
                             ' must be specified.')
        # Batch shape: broadcast of everything left of the two matrix axes.
        # Event shape: the size of the last matrix axis.
        batch_shape = lax.broadcast_shapes(jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2])
        event_shape = jnp.shape(self.scale_tril)[-1:]
        self.loc = jnp.broadcast_to(jnp.squeeze(loc, axis=-1), batch_shape + event_shape)
        super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
                                                 event_shape=event_shape,
                                                 validate_args=validate_args)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:25,代码来源:continuous.py

示例4: logprob_from_conditional_params

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def logprob_from_conditional_params(images, means, inv_scales, logit_probs):
    """Log-likelihood of ``images`` under a mixture of discretized logistics.

    ``images`` gains a mixture axis at position 1 and is compared against
    ``means`` / ``inv_scales``. Each pixel's probability is the logistic CDF mass
    within half a bin (1/255) of its value. This presumably assumes pixels are
    rescaled to [-1, 1]; confirm against the caller. The values -1 and 1 take the
    whole lower or upper tail. Per-pixel log-probabilities are summed over the last
    axis (channels), mixed over axis -3 using softmax(``logit_probs``) weights, and
    summed over the last two remaining axes.
    """
    half_bin = 1 / 255
    x = jnp.expand_dims(images, 1)
    centered = x - means

    def _logistic_cdf(shift):
        return sigmoid((centered + shift) * inv_scales)

    # The edge bins absorb the whole tail on their open side.
    upper = jnp.where(x == 1, 1, _logistic_cdf(half_bin))
    lower = jnp.where(x == -1, 0, _logistic_cdf(-half_bin))
    # Floor the bin mass before taking the log so that empty bins stay finite.
    per_component = jnp.sum(jnp.log(jnp.maximum(upper - lower, 1e-12)), -1)
    log_weights = logit_probs - logsumexp(logit_probs, -3, keepdims=True)
    mixed = logsumexp(log_weights + per_component, axis=-3)
    return jnp.sum(mixed, axis=(-2, -1))
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:10,代码来源:pixelcnn.py

示例5: _extract_signal_patches

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _extract_signal_patches(signal, window_length, hop=1, data_format="NCW"):
    assert not hasattr(window_length, "__len__")
    assert signal.ndim == 3
    if data_format == "NCW":
        N = (signal.shape[2] - window_length) // hop + 1
        indices = jnp.arange(window_length) + jnp.expand_dims(jnp.arange(N) * hop, 1)
        indices = jnp.reshape(indices, [1, 1, N * window_length])
        patches = jnp.take_along_axis(signal, indices, 2)
        return jnp.reshape(patches, signal.shape[:2] + (N, window_length))
    else:
        error 
开发者ID:SymJAX,项目名称:SymJAX,代码行数:13,代码来源:ops_special.py

示例6: _extract_image_patches

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _extract_image_patches(
    image, window_shape, hop=1, data_format="NCHW", mode="valid"
):
    if mode == "same":
        p1 = window_shape[0] - 1
        p2 = window_shape[1] - 1
        image = jnp.pad(
            image, [(0, 0), (0, 0), (p1 // 2, p1 - p1 // 2), (p2 // 2, p2 - p2 // 2)]
        )
    if not hasattr(hop, "__len__"):
        hop = (hop, hop)
    if data_format == "NCHW":

        # compute the number of windows in both dimensions
        N = (
            (image.shape[2] - window_shape[0]) // hop[0] + 1,
            (image.shape[3] - window_shape[1]) // hop[1] + 1,
        )

        # compute the base indices of a 2d patch
        patch = jnp.arange(numpy.prod(window_shape)).reshape(window_shape)
        offset = jnp.expand_dims(jnp.arange(window_shape[0]), 1)
        patch_indices = patch + offset * (image.shape[3] - window_shape[1])

        # create all the shifted versions of it
        ver_shifts = jnp.reshape(
            jnp.arange(N[0]) * hop[0] * image.shape[3], (-1, 1, 1, 1)
        )
        hor_shifts = jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
        all_cols = patch_indices + jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
        indices = patch_indices + ver_shifts + hor_shifts

        # now extract shape (1, 1, H'W'a'b')
        flat_indices = jnp.reshape(indices, [1, 1, -1])
        # shape is now (N, C, W*H)
        flat_image = jnp.reshape(image, (image.shape[0], image.shape[1], -1))
        # shape is now (N, C)
        patches = jnp.take_along_axis(flat_image, flat_indices, 2)
        return jnp.reshape(patches, image.shape[:2] + N + tuple(window_shape))
    else:
        error 
开发者ID:SymJAX,项目名称:SymJAX,代码行数:43,代码来源:ops_special.py

示例7: expand_dims

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def expand_dims(self, x, dim=-1):
    """Insert a new length-1 axis into ``x`` at position ``dim``.

    ``dim`` is forwarded as the ``axis`` argument: NumPy/JAX ``expand_dims`` has
    no ``dim`` keyword, so ``np.expand_dims(x, dim=dim)`` raised ``TypeError``.
    """
    return np.expand_dims(x, axis=dim)
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例8: forward_one_step

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def forward_one_step(prev_log_prob, curr_word, transition_log_prob, emission_log_prob):
    """Advance the HMM forward recursion by one observation, in log space.

    Args:
        prev_log_prob: log forward values from the previous step, indexed by the
            source state (axis 0 of ``transition_log_prob``).
        curr_word: index of the current observation (a column of
            ``emission_log_prob``).
        transition_log_prob: log transition matrix, ``[source, destination]``.
        emission_log_prob: log emission matrix, ``[state, observation]``.

    Returns:
        Log forward values for the current step, indexed by destination state.
    """
    emit = emission_log_prob[:, curr_word]
    # scores[i, j] = prev[i] + trans[i, j] + emit[j]; then marginalize the source state i.
    scores = prev_log_prob[:, None] + transition_log_prob + emit
    return logsumexp(scores, axis=0)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:6,代码来源:hmm.py

示例9: __call__

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def __call__(self, x):
    """Map an unconstrained vector to a lower-triangular matrix with a positive diagonal.

    The last axis of ``x`` has length ``n * (n + 1) / 2``. Its final ``n`` entries,
    passed through ``exp``, form the diagonal. The leading entries fill the strictly
    lower triangle via ``vec_to_tril_matrix(..., diagonal=-1)``. All other axes are
    batch axes.
    """
    length = x.shape[-1]
    # Invert length = n (n + 1) / 2 to recover the matrix size.
    n = round((math.sqrt(1 + 8 * length) - 1) / 2)
    off_diag, raw_diag = x[..., :-n], x[..., -n:]
    lower = vec_to_tril_matrix(off_diag, diagonal=-1)
    eye = jnp.identity(n)
    return lower + eye * jnp.exp(raw_diag)[..., None]
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:7,代码来源:transforms.py

示例10: vec_to_tril_matrix

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def vec_to_tril_matrix(t, diagonal=0):
    """Scatter the last axis of ``t`` into the lower triangle of a square matrix.

    Entries are written in ``jnp.tril_indices`` order (row by row) into the
    positions on or below the ``diagonal``-th diagonal. Every other entry is zero.
    Leading axes of ``t`` are treated as batch axes.

    Args:
        t: array whose last axis has length ``m * (m + 1) / 2`` for some ``m``.
        diagonal: bounding diagonal (0 = main diagonal, -1 = strictly below it,
            and so on). Must be ``<= 0``.

    Returns:
        Array of shape ``t.shape[:-1] + (n, n)`` with ``n = m - diagonal`` and
        the same dtype as ``t``.

    Raises:
        ValueError: if ``diagonal > 0``.
    """
    if diagonal > 0:
        # The size formula below inverts m = n' (n' + 1) / 2 with n' = n + diagonal,
        # which only counts triangles on or below the main diagonal.
        raise ValueError(f"vec_to_tril_matrix only supports diagonal <= 0, got {diagonal}")
    n = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    n2 = n * n
    # flat (row-major) positions of the target triangle inside an n x n matrix
    idx = jnp.reshape(jnp.arange(n2), (n, n))[jnp.tril_indices(n, diagonal)]
    # Build the accumulator in t's dtype: lax.scatter_add requires operand and updates to
    # share a dtype, so a default-float zeros array breaks integer or float64 inputs.
    x = lax.scatter_add(jnp.zeros(t.shape[:-1] + (n2,), dtype=t.dtype),
                        jnp.expand_dims(idx, axis=-1), t,
                        lax.ScatterDimensionNumbers(update_window_dims=tuple(range(t.ndim - 1)),
                                                    inserted_window_dims=(t.ndim - 1,),
                                                    scatter_dims_to_operand_dims=(t.ndim - 1,)))
    return jnp.reshape(x, x.shape[:-1] + (n, n))
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:12,代码来源:util.py

示例11: sample

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def sample(self, key, sample_shape=()):
    """Draw random-walk trajectories.

    Standard-normal increments of shape
    ``sample_shape + batch_shape + event_shape`` are accumulated along the last
    axis. The result is multiplied by ``self.scale``, which gets a trailing axis so
    it broadcasts over that last axis.
    """
    increments = random.normal(key, shape=sample_shape + self.batch_shape + self.event_shape)
    positions = jnp.cumsum(increments, axis=-1)
    return positions * jnp.expand_dims(self.scale, -1)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:6,代码来源:continuous.py

示例12: variance

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def variance(self):
    """Variance ``scale**2 * t`` at each step ``t = 1, ..., num_steps``.

    The result is broadcast to ``batch_shape + event_shape``.
    """
    step_counts = jnp.arange(1, self.num_steps + 1)
    per_step = jnp.expand_dims(self.scale, -1) ** 2 * step_counts
    target_shape = self.batch_shape + self.event_shape
    return jnp.broadcast_to(per_step, target_shape)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:5,代码来源:continuous.py

示例13: _batch_capacitance_tril

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _batch_capacitance_tril(W, D):
    r"""
    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
    and a batch of vectors :math:`D`.
    """
    Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
    K = jnp.matmul(Wt_Dinv, W)
    # could be inefficient
    return jnp.linalg.cholesky(jnp.add(K, jnp.identity(K.shape[-1]))) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:11,代码来源:continuous.py

示例14: _batch_lowrank_mahalanobis

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Squared Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x` for a
    low-rank-plus-diagonal covariance, evaluated with the Woodbury identity::

        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),

    where :math:`C = I + W.T @ inv(D) @ W` is the capacitance matrix. Only
    :math:`C` is inverted, through its Cholesky factor ``capacitance_tril`` (for
    example as returned by ``_batch_capacitance_tril``). ``D`` is the diagonal,
    stored as a vector.
    """
    # First Woodbury term: x.T @ inv(D) @ x for a diagonal D.
    diagonal_part = jnp.sum(jnp.square(x) / D, axis=-1)
    # W.T @ inv(D) @ x, with inv(D) applied by dividing the columns of W.T by D.
    projected = _batch_mv(jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2), x)
    # Second Woodbury term, presumably projected.T @ inv(C) @ projected via the factor.
    low_rank_part = _batch_mahalanobis(capacitance_tril, projected)
    return diagonal_part - low_rank_part
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:14,代码来源:continuous.py

示例15: scale_tril

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import expand_dims [as 别名]
def scale_tril(self):
    """Lower Cholesky factor of ``cov_factor @ cov_factor.T + diag(cov_diag)``.

    Instead of factoring the covariance directly, it uses the identity
    ``W @ W.T + D = D^1/2 @ (I + D^-1/2 @ W @ W.T @ D^-1/2) @ D^1/2``
    (see http://www.gaussianprocess.org/gpml/, Section 3.4.3). The bracketed
    matrix has eigenvalues bounded below by 1, so it is well-conditioned and
    safe to factor.
    """
    sqrt_d = jnp.sqrt(self.cov_diag)[..., None]
    whitened = self.cov_factor / sqrt_d
    inner = whitened @ jnp.swapaxes(whitened, -1, -2)
    inner = jnp.identity(inner.shape[-1]) + inner
    # Scaling row i of the inner factor by sqrt(D_i) keeps it lower-triangular.
    return sqrt_d * jnp.linalg.cholesky(inner)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:14,代码来源:continuous.py


注：本文中的jax.numpy.expand_dims方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。