This article collects typical usage examples of jax.numpy.inf in Python. If you are wondering what numpy.inf is in JAX, how to call it, or what working code that uses it looks like, the curated examples below should help; for broader context, see the jax.numpy module itself.
Eleven code examples of numpy.inf are shown below, drawn from open-source projects and ordered by popularity.
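Before the examples, a quick sketch of what jax.numpy.inf is: the same IEEE-754 positive infinity as NumPy's np.inf. It compares greater than every finite float, which makes -jnp.inf the identity element for max-style reductions and a convenient "masked out" value. The values below are illustrative only:

import jax.numpy as jnp

x = jnp.array([1.0, jnp.inf, -jnp.inf, jnp.nan])
print(jnp.isinf(x))                 # [False  True  True False]
print(jnp.maximum(-jnp.inf, 3.0))   # 3.0 -- -inf never wins a max
# Common masking idiom: send disallowed entries to -inf, e.g. before a softmax.
logits = jnp.where(jnp.array([True, False, True]),
                   jnp.array([1.0, 2.0, 3.0]), -jnp.inf)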
Example 1: test_numpy_backend_delegation
# Required import: from jax import numpy [as alias]
# Alternatively: from jax.numpy import inf [as alias]
def test_numpy_backend_delegation(self):
# Assert that we are getting JAX's numpy backend.
backend = backend_lib.backend()
numpy = backend_lib.numpy
self.assertEqual(jnp, backend["np"])
# Assert that `numpy` calls the appropriate gin configured functions and
# properties.
self.assertTrue(numpy.isinf(numpy.inf))
self.assertEqual(jnp.isinf, numpy.isinf)
self.assertEqual(jnp.inf, numpy.inf)
# Assert that we will now get the pure numpy backend.
self.override_gin("backend.name = 'numpy'")
backend = backend_lib.backend()
numpy = backend_lib.numpy
self.assertEqual(onp, backend["np"])
# Assert that `numpy` calls the appropriate gin configured functions and
# properties.
self.assertTrue(numpy.isinf(numpy.inf))
self.assertEqual(onp.isinf, numpy.isinf)
self.assertEqual(onp.inf, numpy.inf)
Example 2: test_numpy_backend_delegation
def test_numpy_backend_delegation(self):
# Assert that we are getting JAX's numpy backend.
backend = fastmath.backend()
numpy = fastmath.numpy
self.assertEqual(jnp, backend['np'])
# Assert that `numpy` calls the appropriate gin configured functions and
# properties.
self.assertTrue(numpy.isinf(numpy.inf))
self.assertEqual(jnp.isinf, numpy.isinf)
self.assertEqual(jnp.inf, numpy.inf)
# Assert that we will now get the pure numpy backend.
self.override_gin("backend.name = 'numpy'")
backend = fastmath.backend()
numpy = fastmath.numpy
self.assertEqual(onp, backend['np'])
# Assert that `numpy` calls the appropriate gin configured functions and
# properties.
self.assertTrue(numpy.isinf(numpy.inf))
self.assertEqual(onp.isinf, numpy.isinf)
self.assertEqual(onp.inf, numpy.inf)
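Examples 1 and 2 exercise the same mechanism from two generations of trax: a module-level numpy object that forwards attribute lookups (inf, isinf, ...) to whichever backend is currently configured. A minimal sketch of that delegation pattern, independent of trax and its gin configuration (all names here are illustrative, not trax's API):

import numpy as onp
import jax.numpy as jnp

_BACKENDS = {'jax': {'np': jnp}, 'numpy': {'np': onp}}
_current_backend = 'jax'

def backend():
    # Return the primitives of the active backend.
    return _BACKENDS[_current_backend]

class _NumpyProxy:
    # Forward every attribute access to the active backend's numpy module.
    def __getattr__(self, name):
        return getattr(backend()['np'], name)

numpy = _NumpyProxy()
assert numpy.isinf(numpy.inf)   # resolves to jnp.isinf(jnp.inf) under 'jax'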
Example 3: _build_basetree
def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right,
energy_current, max_delta_energy):
step_size = jnp.where(going_right, step_size, -step_size)
z_new, r_new, potential_energy_new, z_new_grad = vv_update(
step_size,
inverse_mass_matrix,
(z, r, energy_current, z_grad),
)
energy_new = potential_energy_new + kinetic_fn(inverse_mass_matrix, r_new)
delta_energy = energy_new - energy_current
# Handles the NaN case.
delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
tree_weight = -delta_energy
diverging = delta_energy > max_delta_energy
accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
return TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad,
z_new, potential_energy_new, z_new_grad, energy_new,
depth=0, weight=tree_weight, r_sum=r_new, turning=False,
diverging=diverging, sum_accept_probs=accept_prob, num_proposals=1)
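The jnp.where(jnp.isnan(...), jnp.inf, ...) line is the load-bearing use of inf in this NUTS helper: a NaN energy difference would silently propagate through the acceptance probability, whereas mapping it to +inf makes exp(-delta_energy) collapse to 0 and trips the diverging flag. The same guard in isolation, with toy numbers rather than NUTS internals:

import jax.numpy as jnp

max_delta_energy = 1000.0
delta_energy = jnp.nan                    # e.g. the leapfrog integrator blew up
delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
print(jnp.exp(-delta_energy))             # 0.0  -- the proposal is never accepted
print(delta_energy > max_delta_energy)    # True -- flagged as divergent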
Example 4: clip_eta
def clip_eta(eta, norm, eps):
"""
Helper function to clip the perturbation to epsilon norm ball.
:param eta: A tensor with the current perturbation.
    :param norm: Order of the norm (mimics NumPy).
Possible values: np.inf or 2.
:param eps: Epsilon, bound of the perturbation.
"""
# Clipping perturbation eta to self.norm norm ball
if norm not in [np.inf, 2]:
raise ValueError('norm must be np.inf or 2.')
axis = list(range(1, len(eta.shape)))
avoid_zero_div = 1e-12
if norm == np.inf:
eta = np.clip(eta, a_min=-eps, a_max=eps)
elif norm == 2:
# avoid_zero_div must go inside sqrt to avoid a divide by zero in the gradient through this operation
norm = np.sqrt(np.maximum(avoid_zero_div, np.sum(np.square(eta), axis=axis, keepdims=True)))
# We must *clip* to within the norm ball, not *normalize* onto the surface of the ball
factor = np.minimum(1., np.divide(eps, norm))
eta = eta * factor
return eta
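A usage sketch for clip_eta (np here is presumed to be jax.numpy, as in the CleverHans JAX utilities this example comes from; the shapes are placeholders):

import jax
import jax.numpy as np

eta = jax.random.normal(jax.random.PRNGKey(0), (4, 28, 28, 1))
eta_inf = clip_eta(eta, np.inf, eps=0.1)   # every component clipped into [-0.1, 0.1]
eta_l2 = clip_eta(eta, 2, eps=0.1)         # each example scaled into the L2 ball
assert np.all(np.abs(eta_inf) <= 0.1)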
Example 5: jax_max_pool
def jax_max_pool(x, pool_size, strides, padding):
return _pooling_general(x, lax.max, -jnp.inf, pool_size=pool_size,
strides=strides, padding=padding)
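_pooling_general is not shown in this snippet; the point being illustrated is that -jnp.inf is the correct init_value for a max reduction, since max(-inf, x) == x for every x. A self-contained equivalent written directly against lax.reduce_window (NHWC layout assumed):

import jax.numpy as jnp
from jax import lax

def max_pool_nhwc(x, pool_size=(2, 2), strides=(2, 2), padding='VALID'):
    dims = (1,) + pool_size + (1,)        # never pool over batch or channels
    strides = (1,) + strides + (1,)
    # -inf seeds the reduction so real values always win the max.
    return lax.reduce_window(x, -jnp.inf, lax.max, dims, strides, padding)

x = jnp.arange(16.0).reshape(1, 4, 4, 1)
print(max_pool_nhwc(x).shape)             # (1, 2, 2, 1)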
Example 6: pool2d
def pool2d(self, x, pool_size, strides=(1, 1),
           border_mode='valid', pool_mode='max'):
    if pool_mode != 'max':
        # The original snippet silently ignored pool_mode; fail loudly instead.
        raise NotImplementedError('only max pooling is implemented here')
    dims = (1,) + pool_size + (1,)      # NHWC: pool over spatial dims only
    strides = (1,) + strides + (1,)
    return lax.reduce_window(x, -np.inf, lax.max, dims, strides, border_mode)
Example 7: validate_sample
def validate_sample(log_prob_fn):
def wrapper(self, *args, **kwargs):
        log_prob = log_prob_fn(self, *args, **kwargs)
if self._validate_args:
value = kwargs['value'] if 'value' in kwargs else args[0]
mask = self._validate_sample(value)
log_prob = jnp.where(mask, log_prob, -jnp.inf)
return log_prob
return wrapper
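A sketch of how such a decorator is consumed. In NumPyro it wraps Distribution.log_prob so that, when argument validation is enabled, values outside the support score -inf instead of producing garbage; the toy class below is hypothetical, not NumPyro's implementation:

import jax.numpy as jnp

class UnitUniform:
    _validate_args = True

    def _validate_sample(self, value):
        return (value >= 0) & (value <= 1)   # mask of in-support entries

    @validate_sample
    def log_prob(self, value):
        return jnp.zeros_like(value)         # log density of Uniform(0, 1)

print(UnitUniform().log_prob(jnp.array([0.5, 2.0])))   # [  0. -inf]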
Example 8: mean
def mean(self):
return jnp.full(self.batch_shape, jnp.inf)
Example 9: variance
def variance(self):
return jnp.full(self.batch_shape, jnp.inf)
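Examples 8 and 9 are properties of a heavy-tailed distribution whose moments diverge (NumPyro's HalfCauchy reports both this way): rather than raising, the undefined mean and variance are returned as jnp.inf broadcast over the batch shape, so downstream array code keeps working. A minimal illustration of the pattern (the class is a sketch, not NumPyro's implementation):

import jax.numpy as jnp

class HalfCauchyLike:
    def __init__(self, scale):
        self.scale = jnp.asarray(scale)
        self.batch_shape = self.scale.shape

    @property
    def mean(self):
        # The half-Cauchy has no finite mean; report inf per batch element.
        return jnp.full(self.batch_shape, jnp.inf)

print(HalfCauchyLike(jnp.ones(3)).mean)   # [inf inf inf]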
Example 10: test_log_prob
def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
jit_fn = _identity if not jit else jax.jit
jax_dist = jax_dist(*params)
rng_key = random.PRNGKey(0)
samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape
if not sp_dist:
        if isinstance(jax_dist, (dist.TruncatedCauchy, dist.TruncatedNormal)):
low, loc, scale = params
high = jnp.inf
sp_dist = osp.cauchy if isinstance(jax_dist, dist.TruncatedCauchy) else osp.norm
sp_dist = sp_dist(loc, scale)
expected = sp_dist.logpdf(samples) - jnp.log(sp_dist.cdf(high) - sp_dist.cdf(low))
assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
return
pytest.skip('no corresponding scipy distn.')
if _is_batched_multivariate(jax_dist):
pytest.skip('batching not allowed in multivariate distns.')
if jax_dist.event_shape and prepend_shape:
# >>> d = sp.dirichlet([1.1, 1.1])
# >>> samples = d.rvs(size=(2,))
# >>> d.logpdf(samples)
# ValueError: The input vector 'x' must lie within the normal simplex ...
pytest.skip('batched samples cannot be scored by multivariate distributions.')
sp_dist = sp_dist(*params)
try:
expected = sp_dist.logpdf(samples)
except AttributeError:
expected = sp_dist.logpmf(samples)
except ValueError as e:
# precision issue: jnp.sum(x / jnp.sum(x)) = 0.99999994 != 1
if "The input vector 'x' must lie within the normal simplex." in str(e):
samples = samples.copy().astype('float64')
samples = samples / samples.sum(axis=-1, keepdims=True)
expected = sp_dist.logpdf(samples)
else:
raise e
assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
Example 11: fast_gradient_method
def fast_gradient_method(model_fn, x, eps, norm, clip_min=None, clip_max=None, y=None,
targeted=False):
"""
JAX implementation of the Fast Gradient Method.
:param model_fn: a callable that takes an input tensor and returns the model logits.
:param x: input tensor.
:param eps: epsilon (input variation parameter); see https://arxiv.org/abs/1412.6572.
:param norm: Order of the norm (mimics NumPy). Possible values: np.inf or 2.
:param clip_min: (optional) float. Minimum float value for adversarial example components.
:param clip_max: (optional) float. Maximum float value for adversarial example components.
:param y: (optional) Tensor with one-hot true labels. If targeted is true, then provide the
target one-hot label. Otherwise, only provide this parameter if you'd like to use true
labels when crafting adversarial samples. Otherwise, model predictions are used
as labels to avoid the "label leaking" effect (explained in this paper:
https://arxiv.org/abs/1611.01236). Default is None. This argument does not have
to be a binary one-hot label (e.g., [0, 1, 0, 0]), it can be floating points values
that sum up to 1 (e.g., [0.05, 0.85, 0.05, 0.05]).
:param targeted: (optional) bool. Is the attack targeted or untargeted?
Untargeted, the default, will try to make the label incorrect.
Targeted will instead try to move in the direction of being more like y.
:return: a tensor for the adversarial example
"""
if norm not in [np.inf, 2]:
raise ValueError("Norm order must be either np.inf or 2.")
if y is None:
# Using model predictions as ground truth to avoid label leaking
x_labels = np.argmax(model_fn(x), 1)
y = one_hot(x_labels, 10)
def loss_adv(image, label):
pred = model_fn(image[None])
        loss = -np.sum(logsoftmax(pred) * label)
if targeted:
loss = -loss
return loss
grads_fn = vmap(grad(loss_adv), in_axes=(0, 0), out_axes=0)
grads = grads_fn(x, y)
axis = list(range(1, len(grads.shape)))
avoid_zero_div = 1e-12
if norm == np.inf:
perturbation = eps * np.sign(grads)
elif norm == 1:
raise NotImplementedError("L_1 norm has not been implemented yet.")
elif norm == 2:
square = np.maximum(avoid_zero_div, np.sum(np.square(grads), axis=axis, keepdims=True))
perturbation = grads / np.sqrt(square)
adv_x = x + perturbation
# If clipping is needed, reset all values outside of [clip_min, clip_max]
if (clip_min is not None) or (clip_max is not None):
# We don't currently support one-sided clipping
assert clip_min is not None and clip_max is not None
adv_x = np.clip(adv_x, a_min=clip_min, a_max=clip_max)
return adv_x
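A usage sketch for fast_gradient_method. The helpers it relies on (one_hot, logsoftmax) are assumed to be in scope as in the CleverHans JAX source, and the linear model below is a placeholder, not a real classifier:

import jax.numpy as np

W = np.ones((784, 10))                    # hypothetical weights, 10 classes
model_fn = lambda x: x.reshape((x.shape[0], -1)) @ W

x = np.ones((8, 28, 28))                  # a batch of placeholder inputs
adv = fast_gradient_method(model_fn, x, eps=0.3, norm=np.inf,
                           clip_min=0.0, clip_max=1.0)
# Perturbation is bounded by eps in the chosen norm, then clipped to [0, 1].
assert adv.shape == x.shape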