当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.where方法代码示例

本文整理汇总了Python中jax.numpy.where方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.where方法的具体用法？Python jax.numpy.where怎么用？Python jax.numpy.where使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.numpy的用法示例。


在下文中一共展示了jax.numpy.where方法的15个代码示例（其中个别示例只是从jax.numpy导入，并未直接调用where），这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: set_host_device_count

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def set_host_device_count(n):
    """
    By default, XLA considers all CPU cores as one device. This utility tells XLA
    that there are `n` host (CPU) devices available to use. As a consequence, this
    allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.

    .. note:: This utility only takes effect at the beginning of your program.
        Under the hood, this sets the environment variable
        `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
        `[num_devices]` is the desired number of CPU devices `n`. Any other
        flags already present in `XLA_FLAGS` are kept unchanged; a previous
        `xla_force_host_platform_device_count` flag is replaced.

    .. warning:: Our understanding of the side effects of using the
        `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
        observe some strange phenomenon when using this utility, please let us
        know through our issue or forum page. More information is available in this
        `JAX issue <https://github.com/google/jax/issues/1408>`_.

    :param int n: number of CPU devices to use.
    """
    xla_flags = os.getenv('XLA_FLAGS', '')
    # Remove any existing device-count flag, with or without its leading `--`.
    # `\S*` stops at the end of that flag's value. Neighbouring flags are never
    # swallowed, and a flag at the very end of the string is still removed.
    # The other flags keep their own `--` prefixes.
    xla_flags = re.sub(r'(?:--)?xla_force_host_platform_device_count=\S*', '',
                       xla_flags).split()
    os.environ['XLA_FLAGS'] = ' '.join(
        ['--xla_force_host_platform_device_count={}'.format(n)] + xla_flags)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:25,代码来源:util.py

示例2: _setup_prototype

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def _setup_prototype(self, *args, **kwargs):
    """Run the model once through ``initialize_model`` to set up guide state.

    Draws an RNG key from the ``"_{prefix}_rng_key_setup"`` sample site and calls
    ``initialize_model`` inside ``handlers.block()``. It then sets these
    attributes on ``self``:

    * ``_postprocess_fn`` and ``prototype_trace``, taken from the
      ``initialize_model`` result;
    * ``_init_latent``, the ``ravel_pytree`` of ``init_params[0]``;
    * ``_unpack_latent``, an ``UnpackTransform`` around the matching unravel
      function;
    * ``latent_dim``, the ``jnp.size`` of ``_init_latent``.

    :raises RuntimeError: if ``latent_dim`` is 0 (the model has no latent variables).
    """
    setup_key = numpyro.sample("_{}_rng_key_setup".format(self.prefix),
                               dist.PRNGIdentity())
    with handlers.block():
        (init_params, _unused,
         self._postprocess_fn, self.prototype_trace) = initialize_model(
            setup_key,
            self.model,
            init_strategy=self.init_strategy,
            dynamic_args=False,
            model_args=args,
            model_kwargs=kwargs,
        )

    flat_latent, unravel_fn = ravel_pytree(init_params[0])
    self._init_latent = flat_latent
    # Wrapped in UnpackTransform to match Pyro, where unpack_latent can be
    # applied to a whole batch of samples.
    self._unpack_latent = UnpackTransform(unravel_fn)
    self.latent_dim = jnp.size(self._init_latent)

    if self.latent_dim == 0:
        guide_name = type(self).__name__
        raise RuntimeError('{} found no latent variables; Use an empty guide instead'
                           .format(guide_name))
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:20,代码来源:autoguide.py

示例3: get_transform

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def get_transform(self, params):
    """Build the Laplace-approximation transform around ``params``' location.

    The precision matrix is the Hessian of ``self._loss_fn`` with respect to the
    ``"{prefix}_loc"`` entry of ``params``, with all other entries held fixed.
    ``cholesky_of_inverse`` of that precision is the scale of the returned
    ``LowerCholeskyAffine``.

    :param dict params: guide parameters; must contain ``"{prefix}_loc"``.
    :return: ``LowerCholeskyAffine(loc, scale_tril)``. NaN entries of
        ``scale_tril`` are replaced with 0.
    """
    loc_name = '{}_loc'.format(self.prefix)

    def loss_fn(z):
        # Work on a shallow copy so the caller's params dict is not mutated.
        params1 = params.copy()
        params1[loc_name] = z
        return self._loss_fn(params1)

    loc = params[loc_name]
    precision = hessian(loss_fn)(loc)
    scale_tril = cholesky_of_inverse(precision)
    # The Python-level NaN check needs concrete values, so it is skipped while
    # tracing.
    if not_jax_tracer(scale_tril) and jnp.any(jnp.isnan(scale_tril)):
        warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                      " samples from AutoLaplaceApproximation will be constant (equal to"
                      " the MAP point).")
    # Zero out NaN entries (the case the warning reports as a singular Hessian).
    scale_tril = jnp.where(jnp.isnan(scale_tril), 0., scale_tril)
    return LowerCholeskyAffine(loc, scale_tril)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:18,代码来源:autoguide.py

示例4: stirling_approx_tail

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def stirling_approx_tail(k):
    """Return the tail (correction) term of Stirling's approximation at ``k``.

    For ``k < 10`` the value comes from a precomputed table. Otherwise the series
    ``(1/12 - (1/360 - (1/1260) / (k+1)**2) / (k+1)**2) / (k+1)`` is used.
    ``k`` indexes the table, so it is expected to be integer-valued.
    """
    table = jnp.array([
        0.08106146679532726,
        0.04134069595540929,
        0.02767792568499834,
        0.02079067210376509,
        0.01664469118982119,
        0.01387612882307075,
        0.01189670994589177,
        0.01041126526197209,
        0.009255462182712733,
        0.008330563433362871,
    ])
    next_k = k + 1
    next_k_squared = next_k ** 2
    series = (1. / 12 - (1. / 360 - (1. / 1260) / next_k_squared) / next_k_squared) / next_k
    use_table = k < 10
    return jnp.where(use_table, table[k], series)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:20,代码来源:util.py

示例5: _binomial_dispatch

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def _binomial_dispatch(key, p, n):
    """Sample from Binomial(n, p), choosing the sampler from the mean ``n * p``.

    If ``X ~ Bin(n, p)`` then ``n - X ~ Bin(n, 1 - p)``. This symmetry means the
    underlying samplers are only called with a success probability ``<= 0.5``.
    Degenerate parameters are handled without calling any sampler.

    NOTE(review): ``p`` and ``n`` are presumably scalars here, since
    ``lax.cond`` needs a scalar predicate. Confirm with the caller.

    :param key: JAX PRNG key passed to the selected sampler.
    :param p: success probability.
    :param n: number of trials.
    :return: the sampled count, with these special cases:

        * ``0`` if ``p`` is nan/inf, ``p <= 0`` or ``n <= 0``;
        * ``n`` if ``p >= 1`` (with ``p`` finite and ``n > 0``).
    """
    def dispatch(key, p, n):
        # pq = min(p, 1 - p); the reflection is undone on the result below.
        is_le_mid = p <= 0.5
        pq = jnp.where(is_le_mid, p, 1 - p)
        mu = n * pq
        # Legacy 5-argument lax.cond form:
        # (pred, true_operand, true_fun, false_operand, false_fun).
        # mu < 10 -> _binomial_inversion, otherwise _binomial_btrs.
        k = lax.cond(mu < 10,
                     (key, pq, n),
                     lambda x: _binomial_inversion(*x),
                     (key, pq, n),
                     lambda x: _binomial_btrs(*x))
        # Map a Bin(n, 1 - p) draw back to Bin(n, p) when p was reflected.
        return jnp.where(is_le_mid, k, n - k)

    # Degenerate parameters bypass the samplers, since nan values are not
    # allowed for integer types:
    #   - return 0 when `p` is nan/inf, `p <= 0` or `n <= 0`;
    #   - return `n` when `p >= 1`.
    cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
    return lax.cond(cond0 & (p < 1),
                    (key, p, n),
                    lambda x: dispatch(*x),
                    (),
                    lambda _: jnp.where(cond0, n, 0))
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:util.py

示例6: test_mask

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def test_mask(mask_last, use_jit):
    """Check that ``handlers.mask`` drops masked sites from the log joint.

    One entry of a length-``N`` boolean mask (index ``-mask_last``) is set to
    False. The expected log joint is the prior on ``x`` plus, for unmasked
    entries only, the ``y`` Delta term and twice the ``obs`` likelihood.
    The factor 2 is because ``obs`` is sampled under ``handlers.scale(scale=2)``.
    """
    N = 10
    # `np.bool` was removed in NumPy 1.24 (AttributeError); the builtin `bool`
    # gives the same boolean dtype on every NumPy version.
    mask = np.ones(N, dtype=bool)
    # Mask out a single entry, at index -mask_last.
    mask[-mask_last] = 0

    def model(data, mask):
        with numpyro.plate('N', N):
            x = numpyro.sample('x', dist.Normal(0, 1))
            with handlers.mask(mask_array=mask):
                numpyro.sample('y', dist.Delta(x, log_density=1.))
                with handlers.scale(scale=2):
                    numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (N,))
    x = random.normal(random.PRNGKey(1), (N,))
    if use_jit:
        # The model (argument 0) is a Python callable, so it is marked static.
        log_joint = jit(lambda *args: log_density(*args)[0], static_argnums=(0,))(
            model, (data, mask), {}, {'x': x, 'y': x})
    else:
        log_joint = log_density(model, (data, mask), {}, {'x': x, 'y': x})[0]
    log_prob_x = dist.Normal(0, 1).log_prob(x)
    # The test expects Delta(..., log_density=1.) to contribute 1 per site.
    # `mask` is only read below where it is True, i.e. where it equals 1.
    log_prob_y = mask
    log_prob_z = dist.Normal(x, 1).log_prob(data)
    expected = (log_prob_x + jnp.where(mask, log_prob_y + 2 * log_prob_z, 0.)).sum()
    assert_allclose(log_joint, expected, atol=1e-4)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:27,代码来源:test_handlers.py

示例7: solve_implicit

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None):
    """Build masked tridiagonal systems from ``a``, ``b``, ``c``, ``d`` and solve them.

    ``ks`` is indexed as 2-D and compared against the level index along the
    last axis of ``a``:

    * positions with ``ks < 0`` are masked out entirely;
    * in the remaining positions, levels ``>= ks`` are active (``water_mask``);
    * the level ``== ks`` is the edge level.

    Outside ``water_mask`` the coefficients ``a``, ``c``, ``d`` are set to 0 and
    ``b`` to 1. At the edge level, ``a`` is zeroed. ``b_edge`` / ``d_edge``, if
    given, override ``b`` / ``d`` there.

    Returns ``(solve_tridiag(...), water_mask)``.
    """
    ks_col = ks[:, :, np.newaxis]
    levels = np.arange(a.shape[2])[np.newaxis, np.newaxis, :]
    valid_column = (ks >= 0)[:, :, np.newaxis]
    edge_mask = valid_column & (levels == ks_col)
    water_mask = valid_column & (levels >= ks_col)

    lower = water_mask * a * np.logical_not(edge_mask)
    upper = water_mask * c
    diag = where(water_mask, b, 1.)
    rhs = water_mask * d

    if b_edge is not None:
        diag = where(edge_mask, b_edge, diag)
    if d_edge is not None:
        rhs = where(edge_mask, d_edge, rhs)

    solution = solve_tridiag(lower, diag, upper, rhs)
    return solution, water_mask
开发者ID:dionhaefner,项目名称:pyhpc-benchmarks,代码行数:19,代码来源:tke_jax.py

示例8: serialize

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def serialize(self, data):
  """Encode a batch of continuous values as base-``self._vocab_size`` digits.

  Values are rescaled with ``self._space.low`` / ``self._space.high`` and
  clipped to [0, 1]. Each value is then expanded into ``self._precision``
  digits, most significant first. A digit equal to ``self._vocab_size`` (the
  ``x == high`` corner case) is lowered to ``self._vocab_size - 1``.

  Returns the digits reshaped to ``(batch_size, -1)``, where ``batch_size`` is
  ``data.shape[0]``.
  """
  batch_size = data.shape[0]
  span = self._space.high - self._space.low
  remainder = np.clip((data - self._space.low) / span, 0, 1)
  columns = []
  for position in range(1, self._precision + 1):
    place_value = self._vocab_size ** -position
    digit = np.array(remainder / place_value).astype(np.int32)
    # For the corner case of x == high.
    digit = np.where(digit == self._vocab_size, digit - 1, digit)
    columns.append(digit)
    remainder -= digit * place_value
  stacked = np.stack(columns, axis=-1)
  return np.reshape(stacked, (batch_size, -1))
开发者ID:google,项目名称:trax,代码行数:17,代码来源:space_serializer.py

示例9: Dropout

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def Dropout(rate, test_mode=False):
    """Constructor for a dropout function with given rate.

    The returned ``dropout`` layer returns ``inputs`` unchanged when
    ``test_mode`` is true or ``rate == 0``. Otherwise each element is kept with
    probability ``1 - rate`` and scaled by ``1 / (1 - rate)``; dropped elements
    become 0.
    """
    rate = np.array(rate)

    @parametrized
    def dropout(inputs):
        if not test_mode and rate != 0:
            keep_prob = 1 - rate
            kept = random.bernoulli(random_key(), keep_prob, inputs.shape)
            return np.where(kept, inputs / keep_prob, 0)
        return inputs

    return dropout
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:16,代码来源:modules.py

示例10: bernoulli_logpdf

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def bernoulli_logpdf(logits, x):
    """Bernoulli log pdf of data x given logits.

    Entries where ``x`` is truthy contribute ``-logaddexp(0, -logits)``; falsy
    entries contribute ``-logaddexp(0, logits)``. All contributions are summed.
    """
    sign = np.where(x, -1., 1.)
    softplus_terms = np.logaddexp(0., sign * logits)
    return -np.sum(softplus_terms)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:5,代码来源:mnist_vae.py

示例11: logprob_from_conditional_params

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def logprob_from_conditional_params(images, means, inv_scales, logit_probs):
    """Log-likelihood of ``images`` under a mixture of discretized logistics.

    ``images`` gets a new axis at position 1 so it broadcasts against
    ``means`` / ``inv_scales``. Each value is scored as a bin of half-width
    ``1/255``; the bin at value 1 extends to CDF 1 and the bin at -1 down to
    CDF 0.

    Bin masses are floored at ``1e-12``, logged and summed over the last axis.
    They are then combined over axis -3 with log-softmax weights from
    ``logit_probs``. The result is summed over the last two axes.
    """
    expanded = jnp.expand_dims(images, 1)

    def logistic_cdf(offset):
        return sigmoid((expanded - means + offset) * inv_scales)

    upper = jnp.where(expanded == 1, 1, logistic_cdf(1 / 255))
    lower = jnp.where(expanded == -1, 0, logistic_cdf(-1 / 255))
    bin_mass = jnp.maximum(upper - lower, 1e-12)
    per_component = jnp.sum(jnp.log(bin_mass), -1)
    log_weights = logit_probs - logsumexp(logit_probs, -3, keepdims=True)
    mixed = logsumexp(log_weights + per_component, axis=-3)
    return jnp.sum(mixed, axis=(-2, -1))
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:10,代码来源:pixelcnn.py

示例12: dropout

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def dropout(self, x, p, seed=None):
    """Apply inverted dropout to ``x``.

    :param x: input array.
    :param p: probability of dropping (zeroing) each element.
    :param seed: optional PRNG key. When ``None`` (the default), the next key
        is drawn from ``self.rng``. A key passed explicitly is used as-is and
        ``self.rng`` is left untouched.
    :return: ``x`` with dropped entries set to 0 and kept entries scaled by
        ``1 / (1 - p)``.
    """
    # The original always overwrote `seed`, silently ignoring a caller's key.
    if seed is None:
        seed = next(self.rng)
    keep_prob = 1 - p
    keep = random.bernoulli(seed, keep_prob, x.shape)
    return np.where(keep, x / keep_prob, 0)
开发者ID:sharadmv,项目名称:deepx,代码行数:7,代码来源:jax.py

示例13: where

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def where(self, mask, tensor_in_1, tensor_in_2):
    """Elementwise select from ``tensor_in_1`` where ``mask`` is true, else ``tensor_in_2``.

    A thin wrapper over ``np.where``; ``self`` is not used.
    """
    selected = np.where(mask, tensor_in_1, tensor_in_2)
    return selected
开发者ID:scikit-hep,项目名称:pyhf,代码行数:4,代码来源:jax_backend.py

示例14: _numpy_delete

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def _numpy_delete(x, idx):
    """
    Gets the subarray from `x` where data from index `idx` on the first axis is removed.
    """
    # NB: numpy.delete is not yet available in JAX
    mask = jnp.arange(x.shape[0] - 1) < idx
    return jnp.where(mask.reshape((-1,) + (1,) * (x.ndim - 1)), x[:-1], x[1:])


# TODO: consider to expose this functional style 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:12,代码来源:mcmc.py

示例15: _biased_transition_kernel

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def _biased_transition_kernel(current_tree, new_tree):
    # This function computes transition prob for main trees (ref [2], section A.3.2).
    transition_prob = jnp.exp(new_tree.weight - current_tree.weight)
    # If new tree is turning or diverging, we won't move the proposal
    # to the new tree.
    transition_prob = jnp.where(new_tree.turning | new_tree.diverging,
                                0.0, jnp.clip(transition_prob, a_max=1.0))
    return transition_prob 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:10,代码来源:hmc_util.py


注:本文中的jax.numpy.where方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。