本文整理汇总了Python中jax.numpy.clip函数的典型用法代码示例。如果您正苦于以下问题：jax.numpy.clip的具体用法是什么？jax.numpy.clip怎么用？有没有jax.numpy.clip的使用示例？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该函数所在模块jax.numpy的其他用法示例。
下文共展示了jax.numpy.clip函数的15个代码示例，这些例子默认按受欢迎程度排序。各示例开头的注释表明所需模块为jax.numpy（代码中通常以`np`或`jnp`作为别名，而非指代NumPy）。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def __init__(self, space, vocab_size, precision=2, max_range=(-100.0, 100.0)):
    """Set up a serializer over a copy of `space` whose bounds are capped to `max_range`.

    Args:
      space: Box-like space with `low` / `high` attributes. It is shallow-copied;
        the caller's object is not modified (only the copy's attributes are
        reassigned).
      vocab_size: forwarded unchanged to the parent class constructor.
      precision: stored as `self._precision` (presumably the number of digits
        emitted per value — confirm against the serializer's encoder).
      max_range: `(min_low, max_high)` pair; `low` is raised to at least
        `min_low` and `high` is lowered to at most `max_high`.
    """
    self._precision = precision
    # Some gym envs (e.g. CartPole) advertise unreasonably large observation
    # bounds; cap them so the values remain representable.
    lower_cap, upper_cap = max_range
    capped_space = copy.copy(space)
    capped_space.low = np.maximum(space.low, lower_cap)
    capped_space.high = np.minimum(space.high, upper_cap)
    bounds_unchanged = (np.allclose(capped_space.low, space.low) and
                        np.allclose(capped_space.high, space.high))
    if not bounds_unchanged:
        # Warn only when capping actually altered one of the bounds.
        logging.warning(
            'Space limits %s, %s out of bounds %s. Clipping to %s, %s.',
            str(space.low), str(space.high), str(max_range),
            str(capped_space.low), str(capped_space.high)
        )
    super(BoxSpaceSerializer, self).__init__(capped_space, vocab_size)
示例2: clip
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def clip(self, tensor_in, min_value, max_value):
    """
    Clip (limit) the values of a tensor to lie within [min_value, max_value].

    Example:
        >>> import pyhf
        >>> pyhf.set_backend("jax")
        >>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
        >>> pyhf.tensorlib.clip(a, -1, 1)
        DeviceArray([-1., -1., 0., 1., 1.], dtype=float64)

    Args:
        tensor_in (`tensor`): The input tensor object
        min_value (`scalar` or `tensor` or `None`): The lower bound to clip to;
            ``None`` leaves the lower side unbounded
        max_value (`scalar` or `tensor` or `None`): The upper bound to clip to;
            ``None`` leaves the upper side unbounded

    Returns:
        JAX ndarray: A clipped `tensor`
    """
    lower, upper = min_value, max_value
    clipped = np.clip(tensor_in, lower, upper)
    return clipped
示例3: _build_basetree
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right,
                    energy_current, max_delta_energy):
    """Take one integrator step and package the result as a depth-0 tree.

    The step runs forward when `going_right` is true and with a negated step
    size otherwise.

    Returns:
        A `TreeInfo` whose left leaf, right leaf and proposal are all the new
        state, with weight `-delta_energy` and a `diverging` flag set when
        `delta_energy > max_delta_energy`. Here `delta_energy` is the change
        in total energy (potential + kinetic); NaN is mapped to +inf. The
        accept probability is `min(1, exp(-delta_energy))` and
        `num_proposals` is 1.
    """
    signed_step = jnp.where(going_right, step_size, -step_size)
    # NOTE(review): `energy_current` fills the third slot of the state tuple
    # handed to vv_update; confirm what vv_update expects in that slot.
    z_next, r_next, potential_next, z_next_grad = vv_update(
        signed_step, inverse_mass_matrix, (z, r, energy_current, z_grad))
    energy_next = potential_next + kinetic_fn(inverse_mass_matrix, r_next)
    energy_change = energy_next - energy_current
    # Treat a NaN energy change as +inf so it counts as diverging and gets
    # zero accept probability / -inf weight.
    energy_change = jnp.where(jnp.isnan(energy_change), jnp.inf, energy_change)
    return TreeInfo(z_next, r_next, z_next_grad, z_next, r_next, z_next_grad,
                    z_next, potential_next, z_next_grad, energy_next,
                    depth=0, weight=-energy_change, r_sum=r_next, turning=False,
                    diverging=energy_change > max_delta_energy,
                    sum_accept_probs=jnp.clip(jnp.exp(-energy_change), None, 1.0),
                    num_proposals=1)
示例4: signed_stick_breaking_tril
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def signed_stick_breaking_tril(t):
    """Build a lower-triangular matrix from `t` with a signed stick-breaking transform.

    Entries of `t` are first clipped into the open interval (-1, 1). They are
    laid out below the diagonal by `vec_to_tril_matrix`, and a unit diagonal
    is added. Each column is then scaled by sqrt of the cumulative product of
    `1 - r**2` over the preceding columns of the same row; the first column
    is scaled by 1.
    """
    # Keep every entry of t strictly inside (-1, 1).
    margin = jnp.finfo(t.dtype).eps
    t = jnp.clip(t, -1 + margin, 1 - margin)
    # Strictly-lower-triangular matrix from t; the unit diagonal is added at the end.
    strict_lower = vec_to_tril_matrix(t, diagonal=-1)
    # Stick-breaking on the squared values. We skip forming s = z * z_cumprod,
    # because y = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod).
    remaining_sqrt = jnp.sqrt(jnp.cumprod(1 - strict_lower ** 2, axis=-1))
    # Shift right by one along the last axis, filling the first column with 1.
    pad_spec = [(0, 0)] * strict_lower.ndim
    pad_spec[-1] = (1, 0)
    shifted = jnp.pad(remaining_sqrt[..., :-1], pad_spec,
                      mode="constant", constant_values=1.)
    return (strict_lower + jnp.identity(strict_lower.shape[-1])) * shifted
示例5: clip_eta
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def clip_eta(eta, norm, eps):
    """
    Helper function to clip the perturbation to epsilon norm ball.

    :param eta: A tensor with the current perturbation. Axis 0 is treated as
        the batch axis; for norm 2 the norm is taken over all remaining axes.
    :param norm: Order of the norm (mimics Numpy).
        Possible values: np.inf or 2.
    :param eps: Epsilon, bound of the perturbation.
    :return: `eta` clipped so that each batch entry lies inside the
        eps-ball of the requested norm.
    :raises ValueError: If `norm` is neither np.inf nor 2.
    """
    if norm not in [np.inf, 2]:
        raise ValueError('norm must be np.inf or 2.')
    if norm == np.inf:
        # Clip every coordinate independently into [-eps, eps].
        return np.clip(eta, -eps, eps)
    # norm == 2: reduce over every non-batch axis. A tuple is used because
    # jnp reductions treat `axis` as a static (hashable) argument.
    reduce_axes = tuple(range(1, eta.ndim))
    avoid_zero_div = 1e-12
    # avoid_zero_div must go inside sqrt to avoid a divide by zero in the
    # gradient through this operation.
    l2_norm = np.sqrt(np.maximum(
        avoid_zero_div, np.sum(np.square(eta), axis=reduce_axes, keepdims=True)))
    # We must *clip* to within the norm ball, not *normalize* onto the surface
    # of the ball: entries already inside the ball keep factor 1.
    factor = np.minimum(1., np.divide(eps, l2_norm))
    return eta * factor
示例6: clipped_probab_ratios
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def clipped_probab_ratios(probab_ratios, epsilon=0.2):
    """Clip probability ratios into [1 - epsilon, 1 + epsilon].

    Presumably used for PPO-style ratio clipping; the bound width defaults
    to 0.2 on each side of 1.
    """
    lower_bound = 1 - epsilon
    upper_bound = 1 + epsilon
    return np.clip(probab_ratios, lower_bound, upper_bound)
示例7: _max
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _max(x, y):
return np.clip(y, a_min=x, a_max=None)
示例8: _min
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _min(x, y):
return np.clip(y, a_min=None, a_max=x)
示例9: _reciprocal
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _reciprocal(x):
result = np.clip(np.reciprocal(x), a_max=np.finfo(x.dtype).max)
return result
示例10: _safesub
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _safesub(x, y):
try:
finfo = np.finfo(y.dtype)
except ValueError:
finfo = np.iinfo(y.dtype)
return x + np.clip(-y, a_min=None, a_max=finfo.max)
示例11: serialize
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def serialize(self, data):
    """Encode a batch of Box-space values as base-`self._vocab_size` digits.

    Each value is rescaled to [0, 1] using `self._space.low` / `self._space.high`
    and clipped into that range. It is then expanded into `self._precision`
    int32 digits, most significant first.

    Args:
      data: array whose leading axis is the batch axis.

    Returns:
      int32 array reshaped to `(batch_size, -1)`. For each value, its
      `self._precision` digits are laid out consecutively.
    """
    num_rows = data.shape[0]
    span = self._space.high - self._space.low
    scaled = np.clip((data - self._space.low) / span, 0, 1)
    base = self._vocab_size
    digit_list = []
    for position in range(1, self._precision + 1):
        step = base ** -position
        digit = np.array(scaled / step).astype(np.int32)
        # A value equal to the upper bound yields digit == base; fold it to base - 1.
        digit = np.where(digit == base, digit - 1, digit)
        digit_list.append(digit)
        scaled -= digit * step
    return np.reshape(np.stack(digit_list, axis=-1), (num_rows, -1))
示例12: _get_num_steps
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _get_num_steps(step_size, trajectory_length):
num_steps = jnp.clip(trajectory_length / step_size, a_min=1)
# NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
# if jax_enable_x64 is False
return num_steps.astype(canonicalize_dtype(jnp.int64))
示例13: _biased_transition_kernel
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _biased_transition_kernel(current_tree, new_tree):
# This function computes transition prob for main trees (ref [2], section A.3.2).
transition_prob = jnp.exp(new_tree.weight - current_tree.weight)
# If new tree is turning or diverging, we won't move the proposal
# to the new tree.
transition_prob = jnp.where(new_tree.turning | new_tree.diverging,
0.0, jnp.clip(transition_prob, a_max=1.0))
return transition_prob
示例14: update
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def update(self, g, state):
    """Apply one optimizer step after clipping every gradient entry elementwise.

    Despite the attribute name, `self.clip_norm` is used as an elementwise
    value bound: each entry of each leaf of `g` is clipped to
    `[-self.clip_norm, self.clip_norm]`. The gradient's overall norm is
    neither computed nor rescaled.

    Args:
        g: pytree of gradients.
        state: `(i, opt_state)` pair; `i` is the step counter passed to
            `self.update_fn`.

    Returns:
        `(i + 1, new_opt_state)`, where `new_opt_state` is the result of
        `self.update_fn(i, clipped_g, opt_state)`.
    """
    step, opt_state = state
    bound = self.clip_norm
    # Elementwise value clipping (not norm clipping) on every leaf.
    clipped_g = tree_map(lambda leaf: jnp.clip(leaf, -bound, bound), g)
    opt_state = self.update_fn(step, clipped_g, opt_state)
    return step + 1, opt_state
示例15: _clipped_expit
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _clipped_expit(x):
    """Logistic sigmoid of `x`, clipped into `[finfo.tiny, 1 - finfo.eps]`.

    `finfo` is taken for the dtype returned by `get_dtype(x)`. The clip keeps
    the output away from exactly 0 and exactly 1 at that precision.
    """
    info = jnp.finfo(get_dtype(x))
    lower, upper = info.tiny, 1. - info.eps
    return jnp.clip(expit(x), lower, upper)