当前位置: 首页>>代码示例>>Python>>正文


Python numpy.clip方法代码示例

本文整理汇总了Python中jax.numpy.clip方法的典型用法代码示例。如果您正苦于以下问题:Python numpy.clip方法的具体用法?Python numpy.clip怎么用?Python numpy.clip使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.numpy的用法示例。


在下文中一共展示了numpy.clip方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def __init__(self, space, vocab_size, precision=2, max_range=(-100.0, 100.0)):
    self._precision = precision

    # Some gym envs (e.g. CartPole) have unreasonably high bounds for
    # observations. We clip so we can represent them.
    bounded_space = copy.copy(space)
    (min_low, max_high) = max_range
    bounded_space.low = np.maximum(space.low, min_low)
    bounded_space.high = np.minimum(space.high, max_high)
    if (not np.allclose(bounded_space.low, space.low) or
        not np.allclose(bounded_space.high, space.high)):
      logging.warning(
          'Space limits %s, %s out of bounds %s. Clipping to %s, %s.',
          str(space.low), str(space.high), str(max_range),
          str(bounded_space.low), str(bounded_space.high)
      )

    super(BoxSpaceSerializer, self).__init__(bounded_space, vocab_size) 
开发者ID:google,项目名称:trax,代码行数:20,代码来源:space_serializer.py

示例2: clip

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def clip(self, tensor_in, min_value, max_value):
        """
        Clips (limits) the tensor values to be within a specified min and max.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
            >>> pyhf.tensorlib.clip(a, -1, 1)
            DeviceArray([-1., -1.,  0.,  1.,  1.], dtype=float64)

        Args:
            tensor_in (`tensor`): The input tensor object
            min_value (`scalar` or `tensor` or `None`): The minimum value to be cliped to
            max_value (`scalar` or `tensor` or `None`): The maximum value to be cliped to

        Returns:
            JAX ndarray: A clipped `tensor`
        """
        return np.clip(tensor_in, min_value, max_value) 
开发者ID:scikit-hep,项目名称:pyhf,代码行数:23,代码来源:jax_backend.py

示例3: _build_basetree

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right,
                    energy_current, max_delta_energy):
    step_size = jnp.where(going_right, step_size, -step_size)
    z_new, r_new, potential_energy_new, z_new_grad = vv_update(
        step_size,
        inverse_mass_matrix,
        (z, r, energy_current, z_grad),
    )

    energy_new = potential_energy_new + kinetic_fn(inverse_mass_matrix, r_new)
    delta_energy = energy_new - energy_current
    # Handles the NaN case.
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    tree_weight = -delta_energy

    diverging = delta_energy > max_delta_energy
    accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
    return TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad,
                    z_new, potential_energy_new, z_new_grad, energy_new,
                    depth=0, weight=tree_weight, r_sum=r_new, turning=False,
                    diverging=diverging, sum_accept_probs=accept_prob, num_proposals=1) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:23,代码来源:hmc_util.py

示例4: signed_stick_breaking_tril

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def signed_stick_breaking_tril(t):
    # make sure that t in (-1, 1)
    eps = jnp.finfo(t.dtype).eps
    t = jnp.clip(t, a_min=(-1 + eps), a_max=(1 - eps))
    # transform t to tril matrix with identity diagonal
    r = vec_to_tril_matrix(t, diagonal=-1)

    # apply stick-breaking on the squared values;
    # we omit the step of computing s = z * z_cumprod by using the fact:
    #     y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)
    z = r ** 2
    z1m_cumprod = jnp.cumprod(1 - z, axis=-1)
    z1m_cumprod_sqrt = jnp.sqrt(z1m_cumprod)

    pad_width = [(0, 0)] * z.ndim
    pad_width[-1] = (1, 0)
    z1m_cumprod_sqrt_shifted = jnp.pad(z1m_cumprod_sqrt[..., :-1], pad_width,
                                       mode="constant", constant_values=1.)
    y = (r + jnp.identity(r.shape[-1])) * z1m_cumprod_sqrt_shifted
    return y 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:22,代码来源:util.py

示例5: clip_eta

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def clip_eta(eta, norm, eps):
  """
  Helper function to clip the perturbation to epsilon norm ball.
  :param eta: A tensor with the current perturbation.
  :param norm: Order of the norm (mimics Numpy).
              Possible values: np.inf or 2.
  :param eps: Epsilon, bound of the perturbation.
  """

  # Clipping perturbation eta to self.norm norm ball
  if norm not in [np.inf, 2]:
    raise ValueError('norm must be np.inf or 2.')

  axis = list(range(1, len(eta.shape)))
  avoid_zero_div = 1e-12
  if norm == np.inf:
    eta = np.clip(eta, a_min=-eps, a_max=eps)
  elif norm == 2:
    # avoid_zero_div must go inside sqrt to avoid a divide by zero in the gradient through this operation
    norm = np.sqrt(np.maximum(avoid_zero_div, np.sum(np.square(eta), axis=axis, keepdims=True)))
    # We must *clip* to within the norm ball, not *normalize* onto the surface of the ball
    factor = np.minimum(1., np.divide(eps, norm))
    eta = eta * factor
  return eta 
开发者ID:tensorflow,项目名称:cleverhans,代码行数:26,代码来源:utils.py

示例6: clipped_probab_ratios

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def clipped_probab_ratios(probab_ratios, epsilon=0.2):
  return np.clip(probab_ratios, 1 - epsilon, 1 + epsilon) 
开发者ID:yyht,项目名称:BERT,代码行数:4,代码来源:ppo.py

示例7: _max

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _max(x, y):
    return np.clip(y, a_min=x, a_max=None) 
开发者ID:pyro-ppl,项目名称:funsor,代码行数:4,代码来源:ops.py

示例8: _min

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _min(x, y):
    return np.clip(y, a_min=None, a_max=x) 
开发者ID:pyro-ppl,项目名称:funsor,代码行数:4,代码来源:ops.py

示例9: _reciprocal

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _reciprocal(x):
    result = np.clip(np.reciprocal(x), a_max=np.finfo(x.dtype).max)
    return result 
开发者ID:pyro-ppl,项目名称:funsor,代码行数:5,代码来源:ops.py

示例10: _safesub

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _safesub(x, y):
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        finfo = np.iinfo(y.dtype)
    return x + np.clip(-y, a_min=None, a_max=finfo.max) 
开发者ID:pyro-ppl,项目名称:funsor,代码行数:8,代码来源:ops.py

示例11: serialize

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def serialize(self, data):
    array = data
    batch_size = array.shape[0]
    array = (array - self._space.low) / (self._space.high - self._space.low)
    array = np.clip(array, 0, 1)
    digits = []
    for digit_index in range(-1, -self._precision - 1, -1):
      threshold = self._vocab_size ** digit_index
      digit = np.array(array / threshold).astype(np.int32)
      # For the corner case of x == high.
      digit = np.where(digit == self._vocab_size, digit - 1, digit)
      digits.append(digit)
      array -= digit * threshold
    digits = np.stack(digits, axis=-1)
    return np.reshape(digits, (batch_size, -1)) 
开发者ID:google,项目名称:trax,代码行数:17,代码来源:space_serializer.py

示例12: _get_num_steps

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _get_num_steps(step_size, trajectory_length):
    num_steps = jnp.clip(trajectory_length / step_size, a_min=1)
    # NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
    # if jax_enable_x64 is False
    return num_steps.astype(canonicalize_dtype(jnp.int64)) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:7,代码来源:mcmc.py

示例13: _biased_transition_kernel

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _biased_transition_kernel(current_tree, new_tree):
    # This function computes transition prob for main trees (ref [2], section A.3.2).
    transition_prob = jnp.exp(new_tree.weight - current_tree.weight)
    # If new tree is turning or diverging, we won't move the proposal
    # to the new tree.
    transition_prob = jnp.where(new_tree.turning | new_tree.diverging,
                                0.0, jnp.clip(transition_prob, a_max=1.0))
    return transition_prob 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:10,代码来源:hmc_util.py

示例14: update

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def update(self, g, state):
        i, opt_state = state
        # clip norm
        g = tree_map(lambda g_: jnp.clip(g_, a_min=-self.clip_norm, a_max=self.clip_norm), g)
        opt_state = self.update_fn(i, g, opt_state)
        return i + 1, opt_state 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:8,代码来源:optim.py

示例15: _clipped_expit

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import clip [as 别名]
def _clipped_expit(x):
    finfo = jnp.finfo(get_dtype(x))
    return jnp.clip(expit(x), a_min=finfo.tiny, a_max=1. - finfo.eps) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:5,代码来源:transforms.py


注:本文中的jax.numpy.clip方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。