当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy方法代码示例

本文整理汇总了Python中jax.numpy方法的典型用法代码示例。如果您正苦于以下问题:Python jax.numpy方法的具体用法?Python jax.numpy怎么用?Python jax.numpy使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax的用法示例。


在下文中一共展示了jax.numpy方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: nested_stack

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def nested_stack(objs, axis=0, np_module=np):
  """Stack corresponding array leaves across a list of nested structures.

  Args:
    objs: List of nested structures (dicts/lists/tuples of arrays) that all
      share the same structure.
    axis: Axis along which the leaves are stacked.
    np_module: Module providing `stack` - typically numpy or jax.numpy.

  Returns:
    A single object with the common nested structure of the elements of
    `objs`, whose leaves are the stacked arrays. Nones are propagated, i.e. if
    each element of the stacked sequence is None, the output will be None.
  """
  def _stack_leaves(leaves):
    return np_module.stack(leaves, axis=axis)

  # nested_zip gathers the per-object leaves into tuples; level=1 tells
  # nested_map to stop at those tuples rather than descend into them.
  zipped = nested_zip(objs)
  return nested_map(_stack_leaves, zipped, level=1)
开发者ID:google,项目名称:trax,代码行数:22,代码来源:jax.py

示例2: fourier_complex_morlet

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def fourier_complex_morlet(bandwidths, centers, N):
    """Complex Morlet wavelet(s) in the Fourier domain.

    Parameters
    ----------

    bandwidths: array
        the bandwidth(s) of the wavelet(s)

    centers: array
        the center frequency (in radians) of the wavelet(s); must broadcast
        against the length-``N`` frequency grid

    N: int
        number of frequency samples, taken uniformly on [0, 2*pi]
        (endpoints included)

    Returns
    -------

    array
        the Gaussian envelope ``exp(-0.25 * (freqs - centers)**2 * bandwidths**2)``
        multiplied by a float32 mask that is 1 for ``freqs <= pi`` and 0 above
    """
    freqs = T.linspace(0, 2 * numpy.pi, N)
    # keep only the [0, pi] half of the frequency axis
    mask = (freqs <= numpy.pi).astype("float32")
    gaussian = T.exp(-0.25 * (freqs - centers) ** 2 * bandwidths ** 2)
    return gaussian * mask
开发者ID:SymJAX,项目名称:SymJAX,代码行数:24,代码来源:signal.py

示例3: update

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def update(self, update_value):
    """Assign a new value to the variable.

    ``update_value`` is first resolved through
    ``symjax.current_graph().get`` (presumably evaluating it when it is a
    symbolic expression -- confirm against the graph implementation).

    If the resolved value's shape differs from ``self.shape`` a warning is
    emitted and the value is reshaped to ``self.shape``. If its dtype (or its
    Python type, when it has no ``dtype`` attribute) differs from
    ``self.dtype`` a warning is emitted and the value is cast to
    ``self.dtype``. The result is stored in ``self._value``.
    """
    new_value = symjax.current_graph().get(update_value)

    if self.shape != jax.numpy.shape(new_value):
        warnings.warn(
            "Variable and update {} {} ".format(self, new_value)
            + "are not the same shape... attempting to reshape"
        )
        new_value = jax.numpy.reshape(new_value, self.shape)

    # Python scalars have no ``dtype``; fall back to their Python type.
    ntype = new_value.dtype if hasattr(new_value, "dtype") else type(new_value)
    if self.dtype != ntype:
        warnings.warn(
            "Variable and update {} {} ".format(self, new_value)
            + "are not the same dtype... attempting to cast"
        )
        # ``jax.numpy.asarray(x, dtype=...)`` casts on every JAX release,
        # whereas a module-level ``jax.numpy.astype`` only exists in recent
        # ones (older releases raise AttributeError).
        new_value = jax.numpy.asarray(new_value, dtype=self.dtype)

    self._value = new_value
开发者ID:SymJAX,项目名称:SymJAX,代码行数:26,代码来源:base.py

示例4: __init__

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def __init__(self, dtype: Optional[np.dtype] = None) -> None:
    """Initialize the JAX backend.

    Imports jax lazily and publishes it, jax.numpy and jax.scipy as the
    module globals `libjax`, `jnp` and `jsp` used by the other methods.

    Args:
      dtype: Optional default dtype; stored as `np.dtype(dtype)`, or None.

    Raises:
      ImportError: If jax is not installed.
    """
    # pylint: disable=global-variable-undefined
    global libjax, jnp, jsp
    super(JaxBackend, self).__init__()
    try:
      #pylint: disable=import-outside-toplevel
      import jax
    except ImportError:
      raise ImportError("Jax not installed, please switch to a different "
                        "backend or install Jax.")
    libjax, jnp, jsp = jax, jax.numpy, jax.scipy
    self.name = "jax"
    self._dtype = None if dtype is None else np.dtype(dtype)
开发者ID:google,项目名称:TensorNetwork,代码行数:19,代码来源:jax_backend.py

示例5: dataset_as_numpy

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def dataset_as_numpy(*args, **kwargs):
  """Dispatch to the active backend's `dataset_as_numpy` entry.

  The backend is looked up on every call (not at import time) and all
  positional and keyword arguments are forwarded unchanged.
  """
  backend_impl = backend()["dataset_as_numpy"]
  return backend_impl(*args, **kwargs)


# For numpy and random modules, we need to call "backend()" lazily, only when
# the function is called -- so that it can be set by gin configs.
# (Otherwise, backend() is called on import before gin-config is parsed.)
# To do that, we wrap these modules in objects that encapsulate them.
开发者ID:yyht,项目名称:BERT,代码行数:10,代码来源:backend.py

示例6: backend

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def backend(name="jax"):
  """Return the backend selected by `name`, unless an override is set.

  A truthy module-level `override_backend_name` takes precedence over
  `name`. The name "numpy" selects `_NUMPY_BACKEND`; any other name selects
  `_JAX_BACKEND`.
  """
  chosen = override_backend_name or name
  return _NUMPY_BACKEND if chosen == "numpy" else _JAX_BACKEND
开发者ID:yyht,项目名称:BERT,代码行数:7,代码来源:backend.py

示例7: _to_numpy

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def _to_numpy(x):
  """Return `x` if it already is a NumPy array, otherwise `x.numpy()`.

  Intended for non-NumPy tensor types that expose a `.numpy()` converter.
  """
  if isinstance(x, np.ndarray):
    return x
  return x.numpy()
开发者ID:google,项目名称:trax,代码行数:5,代码来源:jax.py

示例8: _extract_image_patches

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def _extract_image_patches(
    image, window_shape, hop=1, data_format="NCHW", mode="valid"
):
    """Extract 2-D sliding-window patches from a batch of images.

    Parameters
    ----------
    image: array
        4-D batch laid out as (batch, channels, height, width); only
        ``data_format="NCHW"`` is supported.
    window_shape: pair of int
        (height, width) of each patch.
    hop: int or pair of int
        stride between consecutive windows along (height, width); a scalar is
        used for both axes.
    data_format: str
        layout of ``image``; anything other than ``"NCHW"`` raises.
    mode: str
        ``"same"`` zero-pads height and width by ``window - 1`` (split as
        evenly as possible, the extra sample going to the bottom/right) before
        extracting; any other value extracts only windows fully inside the
        image ("valid").

    Returns
    -------
    array
        shape ``(batch, channels, n_h, n_w, window_shape[0], window_shape[1])``
        where ``n_h``/``n_w`` are the number of window positions along
        height/width.

    Raises
    ------
    NotImplementedError
        if ``data_format`` is not ``"NCHW"``.
    """
    # Validate first: the padding below assumes NCHW axes.
    if data_format != "NCHW":
        raise NotImplementedError(
            "data_format {!r} is not supported, only 'NCHW'".format(data_format)
        )
    if mode == "same":
        p1 = window_shape[0] - 1
        p2 = window_shape[1] - 1
        image = jnp.pad(
            image, [(0, 0), (0, 0), (p1 // 2, p1 - p1 // 2), (p2 // 2, p2 - p2 // 2)]
        )
    if not hasattr(hop, "__len__"):
        hop = (hop, hop)

    width = image.shape[3]
    # number of window positions along height and width
    N = (
        (image.shape[2] - window_shape[0]) // hop[0] + 1,
        (width - window_shape[1]) // hop[1] + 1,
    )

    # flat (row-major over H*W) indices of the top-left patch:
    # patch row i, column j -> i * width + j
    patch = jnp.arange(numpy.prod(window_shape)).reshape(window_shape)
    offset = jnp.expand_dims(jnp.arange(window_shape[0]), 1)
    patch_indices = patch + offset * (width - window_shape[1])

    # shift the base patch to every window position -> (n_h, n_w, a, b)
    ver_shifts = jnp.reshape(jnp.arange(N[0]) * hop[0] * width, (-1, 1, 1, 1))
    hor_shifts = jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
    indices = patch_indices + ver_shifts + hor_shifts

    # indices: (1, 1, n_h * n_w * a * b), broadcast over batch and channels
    flat_indices = jnp.reshape(indices, [1, 1, -1])
    # image: (batch, channels, H * W)
    flat_image = jnp.reshape(image, (image.shape[0], image.shape[1], -1))
    # patches: (batch, channels, n_h * n_w * a * b)
    patches = jnp.take_along_axis(flat_image, flat_indices, 2)
    return jnp.reshape(patches, image.shape[:2] + N + tuple(window_shape))
开发者ID:SymJAX,项目名称:SymJAX,代码行数:43,代码来源:ops_special.py

示例9: littewood_paley_normalization

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def littewood_paley_normalization(filter_bank, down=None, up=None):
    """Normalize a filter bank by its Littlewood-Paley sum inside a band.

    Parameters
    ----------
    filter_bank: array
        filters stacked along axis 0 (for a 2-D input: (n_filters, n_freqs));
        the frequency axis is assumed to sample [0, 2*pi] uniformly.
    down: float, optional
        lower edge (radians) of the band to normalize; defaults to 0.
    up: float, optional
        upper edge (radians) of the band to normalize; defaults to pi.

    Returns
    -------
    array
        ``filter_bank`` divided, at every frequency inside [down, up], by
        ``sum(abs(filter_bank), axis=0)``; frequencies outside the band are
        left unchanged.
    """
    lp = T.abs(filter_bank).sum(0)
    freq = T.linspace(0, 2 * numpy.pi, lp.shape[0])
    down = 0 if down is None else down
    # Default the upper edge only when it is not given (``numpy.pi or up``
    # would always evaluate to pi and ignore the caller's value).
    up = numpy.pi if up is None else up
    lp = T.where(T.logical_and(freq >= down, freq <= up), lp, 1)
    return filter_bank / lp
开发者ID:SymJAX,项目名称:SymJAX,代码行数:9,代码来源:signal.py

示例10: tukey

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def tukey(M, alpha=0.5):
    r"""Return a Tukey window, also known as a tapered cosine window.

    Parameters
    ----------
    M : int
        Number of points in the output window. If zero or less, an empty
        array is returned; if one, ``[1.]`` is returned.
    alpha : float, optional
        Shape parameter of the Tukey window, representing the fraction of the
        window inside the cosine tapered region.
        If zero (or less), the Tukey window is equivalent to a rectangular
        window. If one (or more), the Tukey window is equivalent to a Hann
        window.

    Returns
    -------
    w : ndarray
        The symmetric window of length ``max(M, 0)``, with values in [0, 1].

    References
    ----------
    .. [1] Harris, Fredric J. (Jan 1978). "On the use of Windows for Harmonic
           Analysis with the Discrete Fourier Transform". Proceedings of the
           IEEE 66 (1): 51-83. :doi:`10.1109/PROC.1978.10837`
    .. [2] Wikipedia, "Window function",
           https://en.wikipedia.org/wiki/Window_function#Tukey_window
    """
    # Degenerate lengths: the taper formulas below divide by (M - 1).
    if M <= 1:
        return T.ones(max(M, 0))
    # alpha appears in a denominator below; alpha -> 0 is the rectangular window.
    if alpha <= 0:
        return T.ones(M)
    n = T.arange(0, M)
    if alpha >= 1.0:
        # Hann window (the alpha -> 1 limit). The piecewise formula below can
        # yield overlapping segments, i.e. more than M samples, for alpha >= 1.
        return 0.5 - 0.5 * T.cos(2.0 * numpy.pi * n / (M - 1))

    width = int(numpy.floor(alpha * (M - 1) / 2.0))
    n1 = n[0 : width + 1]  # rising cosine taper
    n2 = n[width + 1 : M - width - 1]  # flat top
    n3 = n[M - width - 1 :]  # falling cosine taper

    w1 = 0.5 * (1 + T.cos(numpy.pi * (-1 + 2.0 * n1 / alpha / (M - 1))))
    w2 = T.ones(n2.shape)
    w3 = 0.5 * (1 + T.cos(numpy.pi * (-2.0 / alpha + 1 + 2.0 * n3 / alpha / (M - 1))))

    return T.concatenate((w1, w2, w3))
开发者ID:SymJAX,项目名称:SymJAX,代码行数:40,代码来源:signal.py

示例11: freq_to_mel

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def freq_to_mel(f, option="linear"):
    """Convert frequencies in Hz to mels.

    Parameters
    ----------
    f: array
        frequencies in Hz
    option: str
        ``"linear"`` (default): linear below 1000 Hz (200/3 Hz per mel) and
        logarithmic above, with step ``log(6.4) / 27`` per mel.
        Any other value: ``2595 * log10(1 + f / 700)``.

    Returns
    -------
    array
        mel values, same shape as ``f``
    """
    if option == "linear":
        # linear part slope (Hz per mel)
        f_sp = 200.0 / 3

        # log-scale part
        min_log_hz = 1000.0  # beginning of log region (Hz)
        min_log_mel = min_log_hz / f_sp  # same (Mels)
        logstep = numpy.log(6.4) / 27.0  # step size for log region
        # Clamp before the log so entries below min_log_hz (e.g. 0 Hz) never
        # reach log(<= 0). Those entries are discarded by the ``where`` below,
        # but an inf/nan in the unselected branch still raises FP warnings and
        # yields NaN gradients under jax.
        mel = min_log_mel + T.log(T.maximum(f, min_log_hz) / min_log_hz) / logstep
        return T.where(f >= min_log_hz, mel, f / f_sp)
    return 2595 * T.log10(1 + f / 700)
开发者ID:SymJAX,项目名称:SymJAX,代码行数:17,代码来源:signal.py

示例12: isvar

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def isvar(item):
    """Check whether ``item`` is, or (recursively) contains, a variable.

    Slices are never variables. Lists and tuples are searched recursively and
    yield ``numpy.sum`` of their elements' results, i.e. the number of
    variable leaves (truthy iff there is at least one; ``0.0`` when empty).
    Any other item is a variable when it is a ``Tensor`` instance or exactly
    a ``Constant``/``OpTuple``, and is not callable.
    """
    if isinstance(item, slice):
        return False
    if isinstance(item, (list, tuple)):
        return numpy.sum([isvar(element) for element in item])
    is_tensor_like = isinstance(item, Tensor) or type(item) in [Constant, OpTuple]
    return is_tensor_like and not callable(item)
开发者ID:SymJAX,项目名称:SymJAX,代码行数:17,代码来源:base.py

示例13: update_numpydoc

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def update_numpydoc(docstr, fun, op):
    """Transforms the numpy docstring to remove references of
       parameters that are supported by the numpy version but not the JAX version

    Parameters
    ----------
    docstr: str or None
        numpy-style docstring to rewrite
    fun:
        unused; kept for interface compatibility
    op: callable
        the JAX implementation; an entry of the "Parameters" section is kept
        only if its (first) name is in ``op.__code__.co_varnames``

    Returns
    -------
    str or None
        ``docstr`` unchanged if it is empty/None or ``op`` has no ``__code__``;
        the de-indented docstring if it lacks a "Parameters" header, its
        dashed underline, or a following "Returns" section; otherwise the
        de-indented docstring with the unsupported parameter entries removed.
    """
    if not docstr or not hasattr(op, "__code__"):
        return docstr

    # Some numpy functions have an extra indentation level on every line
    # (detected from the first line); strip one leading level from each line.
    if docstr.startswith("    "):
        docstr = "\n".join(
            line[4:] if line.startswith("    ") else line
            for line in docstr.split("\n")
        )

    params_header = docstr.find("Parameters")
    if params_header == -1:
        return docstr
    underline_end = docstr.find("--\n", params_header)
    if underline_end == -1:
        return docstr
    begin_idx = underline_end + 2  # the newline closing the dashed underline
    end_idx = docstr.find("Returns", begin_idx)
    if end_idx == -1:
        return docstr

    parameters = docstr[begin_idx:end_idx]
    # Fold each entry's indented description lines onto its "name : type"
    # line (marked with "@@") so that one list element == one entry.
    param_list = parameters.replace("\n    ", "@@").split("\n")
    kept = []
    for entry in param_list:
        # "name : type" or "name1, name2 : type" -> first name
        name = entry[: entry.find(" : ")].split(", ")[0]
        if name in op.__code__.co_varnames:
            kept.append(entry)
    parameters = "\n".join(kept).replace("@@", "\n    ")
    return docstr[: begin_idx + 1] + parameters + docstr[end_idx - 2 :]
开发者ID:SymJAX,项目名称:SymJAX,代码行数:29,代码来源:base.py

示例14: inv

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def inv(self, matrix: Tensor) -> Tensor:
    """Compute the inverse of a single matrix with `jnp.linalg.inv`.

    Args:
      matrix: A rank-2 tensor.

    Returns:
      The matrix inverse of `matrix`.

    Raises:
      ValueError: If `matrix` is not rank 2.
    """
    # Same rank check as `expm`: exactly two dimensions.
    if len(matrix.shape) != 2:
      raise ValueError("input to jax backend method `inv` has shape {}."
                       " Only matrices are supported.".format(matrix.shape))
    return jnp.linalg.inv(matrix)
开发者ID:google,项目名称:TensorNetwork,代码行数:7,代码来源:jax_backend.py

示例15: expm

# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def expm(self, matrix: Tensor) -> Tensor:
    """Compute the matrix exponential with `jax.scipy.linalg.expm`.

    Args:
      matrix: A square rank-2 tensor.

    Returns:
      The matrix exponential of `matrix`.

    Raises:
      ValueError: If `matrix` is not rank 2 or is not square.
    """
    if len(matrix.shape) != 2:
      raise ValueError("input to jax backend method `expm` has shape {}."
                       " Only matrices are supported.".format(matrix.shape))
    if matrix.shape[0] != matrix.shape[1]:
      raise ValueError("input to jax backend method `expm` only supports"
                       " N*N matrix, {x}*{y} matrix is given".format(
                           x=matrix.shape[0], y=matrix.shape[1]))
    # pylint: disable=no-member
    return jsp.linalg.expm(matrix)
开发者ID:google,项目名称:TensorNetwork,代码行数:12,代码来源:jax_backend.py


注:本文中的jax.numpy方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。