本文整理汇总了Python中jax.numpy.where方法的典型用法代码示例。如果您想了解jax.numpy.where的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的其他用法示例。
下文共展示了jax.numpy.where方法的15个代码示例，这些示例默认按受欢迎程度排序（其中示例1和示例2只导入了jax.numpy，并未直接调用where）。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。
示例1: set_host_device_count
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def set_host_device_count(n):
    """
    By default, XLA considers all CPU cores as one device. This utility tells XLA
    that there are `n` host (CPU) devices available to use. As a consequence, this
    allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.

    .. note:: This utility only takes effect at the beginning of your program.
        Under the hood, this sets the environment variable
        `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
        `[num_devices]` is the desired number of CPU devices `n`. Other flags
        already present in `XLA_FLAGS` are kept unchanged (after the new flag);
        any earlier `xla_force_host_platform_device_count` setting is replaced.

    .. warning:: Our understanding of the side effects of using the
        `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
        observe some strange phenomenon when using this utility, please let us
        know through our issue or forum page. More information is available in this
        `JAX issue <https://github.com/google/jax/issues/1408>`_.

    :param int n: number of CPU devices to use.
    """
    flag_name = 'xla_force_host_platform_device_count'
    # Work token by token instead of stripping/regex-editing the raw string:
    # every unrelated flag keeps its original spelling (including its `--`),
    # and an existing device-count flag is dropped wherever it appears,
    # with or without leading dashes.
    kept_flags = [flag for flag in os.getenv('XLA_FLAGS', '').split()
                  if not flag.lstrip('-').startswith(flag_name + '=')]
    os.environ['XLA_FLAGS'] = ' '.join(['--{}={}'.format(flag_name, n)] + kept_flags)
示例2: _setup_prototype
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def _setup_prototype(self, *args, **kwargs):
    """Trace the model once and record the flattened latent layout on ``self``.

    A setup PRNG key is drawn from the sample site ``_<prefix>_rng_key_setup``.
    The model is then initialized with ``initialize_model`` inside
    ``handlers.block()``. Afterwards the following attributes are set:

    * ``_postprocess_fn`` and ``prototype_trace``, as returned by ``initialize_model``;
    * ``_init_latent``: ``init_params[0]`` raveled into one flat array
      (presumably the initial unconstrained values; confirm with ``initialize_model``);
    * ``_unpack_latent``: an ``UnpackTransform`` around the matching unravel function;
    * ``latent_dim``: the number of elements of ``_init_latent``.

    :raises RuntimeError: if the model exposes no latent variables.
    """
    key_site = "_{}_rng_key_setup".format(self.prefix)
    setup_key = numpyro.sample(key_site, dist.PRNGIdentity())
    # Presumably blocked so the sites touched while tracing the model are not
    # seen by outer handlers.
    with handlers.block():
        init_params, _, self._postprocess_fn, self.prototype_trace = initialize_model(
            setup_key,
            self.model,
            init_strategy=self.init_strategy,
            dynamic_args=False,
            model_args=args,
            model_kwargs=kwargs,
        )

    flat_init, unravel_fn = ravel_pytree(init_params[0])
    self._init_latent = flat_init
    # Wrapping the unravel function mirrors Pyro, where the unpacking can be
    # applied to a batch of samples.
    self._unpack_latent = UnpackTransform(unravel_fn)
    self.latent_dim = jnp.size(flat_init)
    if self.latent_dim == 0:
        guide_name = type(self).__name__
        raise RuntimeError('{} found no latent variables; Use an empty guide instead'
                           .format(guide_name))
示例3: get_transform
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def get_transform(self, params):
    """Build the Laplace-approximation transform centred at the location parameter.

    The precision is the Hessian of ``self._loss_fn`` with respect to the
    ``'<prefix>_loc'`` entry of ``params``. The returned transform uses that
    location as its shift and ``cholesky_of_inverse(precision)`` as its lower
    Cholesky scale.

    :param params: mapping of guide parameters; must contain ``'<prefix>_loc'``.
    :return: ``LowerCholeskyAffine(loc, scale_tril)``.
    """
    loc_name = '{}_loc'.format(self.prefix)

    def loss_fn(z):
        # Re-evaluate the loss with only the location swapped out; copy so the
        # caller's mapping is not mutated.
        params1 = params.copy()
        params1[loc_name] = z
        return self._loss_fn(params1)

    loc = params[loc_name]
    precision = hessian(loss_fn)(loc)
    scale_tril = cholesky_of_inverse(precision)
    # The NaN check needs concrete values, so the warning is only issued
    # outside of tracing.
    if not_jax_tracer(scale_tril) and jnp.any(jnp.isnan(scale_tril)):
        warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                      " samples from AutoLaplaceApproximation will be constant (equal to"
                      " the MAP point).")
    # NaN entries (singular Hessian) are zeroed so samples collapse onto `loc`.
    # This is applied unconditionally so it also happens under tracing; it is
    # a no-op when there are no NaNs.
    scale_tril = jnp.where(jnp.isnan(scale_tril), 0., scale_tril)
    return LowerCholeskyAffine(loc, scale_tril)
示例4: stirling_approx_tail
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def stirling_approx_tail(k):
    """Return the tail (remainder) term of Stirling's approximation of ``log(k!)``.

    That is ``log(k!) - [(k + 1/2) log(k + 1) - (k + 1) + log(2*pi) / 2]``.
    For ``k < 10`` the value comes from a precomputed table. Otherwise the
    asymptotic series ``1/(12 m) - 1/(360 m^3) + 1/(1260 m^5)`` with
    ``m = k + 1`` is used.

    :param k: non-negative integer-valued count(s); a Python number or an
        integer/float array (float values are truncated for the table lookup).
    :return: array broadcast to the shape of ``k``.
    """
    precomputed = jnp.array([
        0.08106146679532726,
        0.04134069595540929,
        0.02767792568499834,
        0.02079067210376509,
        0.01664469118982119,
        0.01387612882307075,
        0.01189670994589177,
        0.01041126526197209,
        0.009255462182712733,
        0.008330563433362871,
    ])
    kp1 = k + 1
    kp1sq = kp1 * kp1
    # JAX only accepts integer indexers, and a static out-of-range Python int
    # raises. So the index is cast to int32 and clipped into the table. For
    # k >= 10 the looked-up value is discarded by the `where` below.
    table_idx = jnp.clip(jnp.asarray(k).astype(jnp.int32), 0, precomputed.shape[0] - 1)
    series = (1. / 12 - (1. / 360 - (1. / 1260) / kp1sq) / kp1sq) / kp1
    return jnp.where(k < 10, precomputed[table_idx], series)
示例5: _binomial_dispatch
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def _binomial_dispatch(key, p, n):
    """Draw one Binomial(n, p) count, short-circuiting degenerate parameters.

    Regular case (finite ``0 < p < 1`` and ``n > 0``): the draw uses the success
    probability ``q = min(p, 1 - p)`` and is reflected to ``n - draw`` when
    ``q = 1 - p``. ``_binomial_inversion`` is used when ``n * q < 10`` and
    ``_binomial_btrs`` otherwise.

    Degenerate case: returns ``n`` when ``p >= 1`` (finite ``p``, ``n > 0``) and
    ``0`` for non-finite ``p``, ``p <= 0`` or ``n <= 0``. NaN cannot be
    returned, since it is not allowed for integer types.
    """
    def _sample_regular(key, p, n):
        low_half = p <= 0.5
        q = jnp.where(low_half, p, 1 - p)
        mean = n * q
        draw = lax.cond(mean < 10,
                        (key, q, n), lambda args: _binomial_inversion(*args),
                        (key, q, n), lambda args: _binomial_btrs(*args))
        # Undo the p -> 1 - p swap for the upper half.
        return jnp.where(low_half, draw, n - draw)

    valid = jnp.isfinite(p) & (n > 0) & (p > 0)
    regular = valid & (p < 1)
    return lax.cond(regular,
                    (key, p, n), lambda args: _sample_regular(*args),
                    (), lambda _: jnp.where(valid, n, 0))
示例6: test_mask
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def test_mask(mask_last, use_jit):
    """Check that ``handlers.mask`` drops masked log-density terms.

    A plate of ``N = 10`` entries has one entry masked out. The model's log
    joint (optionally computed under ``jit``) is compared with an expectation
    assembled by hand: ``x`` is never masked, while ``y`` and the doubly-scaled
    ``obs`` terms only count where the mask is true.
    """
    N = 10
    # Builtin `bool`: the `np.bool` alias was deprecated in NumPy 1.20 and
    # removed in 1.24.
    mask = np.ones(N, dtype=bool)
    # NOTE(review): this masks a single element (index -mask_last). If the
    # intent was to mask the last `mask_last` entries it should be a slice.
    mask[-mask_last] = 0

    def model(data, mask):
        with numpyro.plate('N', N):
            x = numpyro.sample('x', dist.Normal(0, 1))
            with handlers.mask(mask_array=mask):
                numpyro.sample('y', dist.Delta(x, log_density=1.))
                with handlers.scale(scale=2):
                    numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (N,))
    x = random.normal(random.PRNGKey(1), (N,))
    if use_jit:
        log_joint = jit(lambda *args: log_density(*args)[0], static_argnums=(0,))(
            model, (data, mask), {}, {'x': x, 'y': x})
    else:
        log_joint = log_density(model, (data, mask), {}, {'x': x, 'y': x})[0]
    log_prob_x = dist.Normal(0, 1).log_prob(x)
    # The test expects each unmasked `y` (a Delta at `x` with log_density=1)
    # to contribute 1, which the boolean mask supplies directly.
    log_prob_y = mask
    log_prob_z = dist.Normal(x, 1).log_prob(data)
    expected = (log_prob_x + jnp.where(mask, log_prob_y + 2 * log_prob_z, 0.)).sum()
    assert_allclose(log_joint, expected, atol=1e-4)
示例7: solve_implicit
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None):
    """Assemble a masked tridiagonal system per column and solve it.

    ``ks`` is indexed as a 2-D array and compared against the level indices
    ``0..a.shape[2]-1``. Axis 2 is presumably the vertical axis; TODO confirm.
    Only columns with ``ks >= 0`` take part. Within them, levels ``>= ks`` are
    active, and the level equal to ``ks`` is the edge row.

    * Inactive cells get diagonal 1 and zero off-diagonals and right-hand side.
    * The edge row has its lower diagonal zeroed.
    * ``b_edge`` / ``d_edge``, when given, override the diagonal / right-hand
      side on the edge row.

    :return: ``(solve_tridiag(...) result, water_mask)``, where ``water_mask``
        marks the active cells.
    """
    level = np.arange(a.shape[2])[np.newaxis, np.newaxis, :]
    bottom = ks[:, :, np.newaxis]
    column_active = (ks >= 0)[:, :, np.newaxis]
    edge_mask = column_active & (level == bottom)
    water_mask = column_active & (level >= bottom)

    lower = water_mask * a * np.logical_not(edge_mask)
    diag = where(water_mask, b, 1.)
    upper = water_mask * c
    rhs = water_mask * d
    if b_edge is not None:
        diag = where(edge_mask, b_edge, diag)
    if d_edge is not None:
        rhs = where(edge_mask, d_edge, rhs)
    return solve_tridiag(lower, diag, upper, rhs), water_mask
示例8: serialize
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def serialize(self, data):
    """Quantize box-space values into base-``self._vocab_size`` digits.

    Values are rescaled from ``[self._space.low, self._space.high]`` to
    ``[0, 1]`` and clipped. Each value then produces ``self._precision``
    digits, most significant first.

    :param data: array whose first axis is the batch axis.
    :return: int32 array of shape ``(batch_size, -1)``; each value's digits
        are stored contiguously.
    """
    low, high = self._space.low, self._space.high
    n_rows = data.shape[0]
    # Fresh normalised array; it is reduced in place below, leaving `data` intact.
    residual = np.clip((data - low) / (high - low), 0, 1)
    digit_planes = []
    for power in range(1, self._precision + 1):
        place_value = self._vocab_size ** -power
        digit = np.array(residual / place_value).astype(np.int32)
        # A value equal to `high` yields `vocab_size` here; fold it into the
        # largest representable digit.
        digit = np.where(digit == self._vocab_size, digit - 1, digit)
        digit_planes.append(digit)
        residual -= digit * place_value
    return np.reshape(np.stack(digit_planes, axis=-1), (n_rows, -1))
示例9: Dropout
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def Dropout(rate, test_mode=False):
    """Constructor for a dropout function with given rate.

    The returned ``parametrized`` layer passes ``inputs`` through unchanged when
    ``test_mode`` is true or ``rate == 0``. Otherwise it keeps each element
    with probability ``1 - rate`` (drawing a fresh ``random_key()``), divides
    survivors by ``1 - rate`` and zeroes the rest.
    """
    rate = np.array(rate)

    @parametrized
    def dropout(inputs):
        # Identity at evaluation time or when nothing would be dropped.
        if test_mode or rate == 0:
            return inputs
        survive_prob = 1 - rate
        survivors = random.bernoulli(random_key(), survive_prob, inputs.shape)
        rescaled = inputs / survive_prob
        return np.where(survivors, rescaled, 0)

    return dropout
示例10: bernoulli_logpdf
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def bernoulli_logpdf(logits, x):
    """Bernoulli log pdf of data x given logits.

    Uses ``log p(x) = -log(1 + exp(s * logits))`` with ``s = -1`` where ``x`` is
    truthy and ``s = +1`` elsewhere, summed over all elements.
    """
    signs = np.where(x, -1., 1.)
    # logaddexp(0, t) == log(1 + exp(t)), computed stably.
    neg_log_lik = np.logaddexp(0., signs * logits)
    return -np.sum(neg_log_lik)
示例11: logprob_from_conditional_params
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def logprob_from_conditional_params(images, means, inv_scales, logit_probs):
    """Log-likelihood of images under a mixture of discretised logistic components.

    Appears to assume pixel values in ``[-1, 1]`` quantised into bins of
    half-width ``1/255`` (TODO confirm with callers). A mixture-component axis
    is inserted at axis 1 of ``images`` to broadcast against ``means`` /
    ``inv_scales``.

    * Per-component log-probabilities are summed over the last axis.
    * They are mixed with weights ``softmax(logit_probs)`` over axis ``-3``.
    * The result is summed over the last two axes.
    """
    x = jnp.expand_dims(images, 1)
    half_bin = 1 / 255

    def cdf(offset):
        return sigmoid((x - means + offset) * inv_scales)

    # The extreme bins absorb the tails: the top bin's upper CDF is 1 and the
    # bottom bin's lower CDF is 0.
    upper = jnp.where(x == 1, 1, cdf(half_bin))
    lower = jnp.where(x == -1, 0, cdf(-half_bin))
    bin_mass = jnp.maximum(upper - lower, 1e-12)
    component_logprobs = jnp.sum(jnp.log(bin_mass), -1)
    log_weights = logit_probs - logsumexp(logit_probs, -3, keepdims=True)
    per_pixel = logsumexp(log_weights + component_logprobs, axis=-3)
    return jnp.sum(per_pixel, axis=(-2, -1))
示例12: dropout
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def dropout(self, x, p, seed=None):
    """Apply inverted dropout to ``x``.

    :param x: input array.
    :param p: probability of dropping (zeroing) each element.
    :param seed: PRNG key to use. When ``None`` (the default), the next key
        is drawn from ``self.rng``.
    :return: array shaped like ``x``. Kept elements are scaled by
        ``1 / (1 - p)`` and dropped elements are 0.
    """
    # Honour an explicitly supplied key; only consume the internal key stream
    # when none was given.
    if seed is None:
        seed = next(self.rng)
    keep_prob = 1 - p
    keep = random.bernoulli(seed, keep_prob, x.shape)
    return np.where(keep, x / keep_prob, 0)
示例13: where
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def where(self, mask, tensor_in_1, tensor_in_2):
    """Select elementwise between two tensors using a boolean mask.

    Thin wrapper over ``np.where``. The result takes ``tensor_in_1`` where
    ``mask`` is true and ``tensor_in_2`` elsewhere, with ``np.where``
    broadcasting rules.
    """
    chosen = np.where(mask, tensor_in_1, tensor_in_2)
    return chosen
示例14: _numpy_delete
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def _numpy_delete(x, idx):
"""
Gets the subarray from `x` where data from index `idx` on the first axis is removed.
"""
# NB: numpy.delete is not yet available in JAX
mask = jnp.arange(x.shape[0] - 1) < idx
return jnp.where(mask.reshape((-1,) + (1,) * (x.ndim - 1)), x[:-1], x[1:])
# TODO: consider exposing this functional style
示例15: _biased_transition_kernel
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import where [as 别名]
def _biased_transition_kernel(current_tree, new_tree):
# This function computes transition prob for main trees (ref [2], section A.3.2).
transition_prob = jnp.exp(new_tree.weight - current_tree.weight)
# If new tree is turning or diverging, we won't move the proposal
# to the new tree.
transition_prob = jnp.where(new_tree.turning | new_tree.diverging,
0.0, jnp.clip(transition_prob, a_max=1.0))
return transition_prob