本文整理汇总了Python中jax.numpy方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy方法的具体用法？Python jax.numpy怎么用？Python jax.numpy使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax的用法示例。
在下文中一共展示了jax.numpy方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: nested_stack
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def nested_stack(objs, axis=0, np_module=np):
    """Stack corresponding numpy-array leaves across several nested structures.

    Args:
        objs: list of nested structures (dicts/lists/tuples) to stack.
        axis: axis along which the leaves are stacked.
        np_module: module providing ``stack`` -- typically numpy or jax.numpy.

    Returns:
        An object with the same nested structure as each element of ``objs``,
        whose leaves are stacked together into arrays. Nones are propagated,
        i.e. if each element of a stacked sequence is None, the output there
        will be None.
    """
    def stack_leaves(leaves):
        return np_module.stack(leaves, axis=axis)

    # nested_zip groups corresponding leaves into tuples; level=1 makes
    # nested_map stop at those tuples instead of descending into them.
    zipped = nested_zip(objs)
    return nested_map(stack_leaves, zipped, level=1)
示例2: fourier_complex_morlet
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def fourier_complex_morlet(bandwidths, centers, N):
    """Complex Morlet wavelet in the Fourier domain.

    Parameters
    ----------
    bandwidths: array
        the bandwidth of the wavelet
    centers: array
        the centers of the wavelet
    N: int
        number of frequency samples; the frequency grid is
        ``linspace(0, 2 * pi, N)`` (in radians)

    Returns
    -------
    array
        The Gaussian envelope
        ``exp(-0.25 * (freqs - centers) ** 2 * bandwidths ** 2)`` multiplied
        by a float32 mask that is 1 for ``freqs <= pi`` and 0 above, so only
        the [0, pi] half of the grid is kept. ``bandwidths``/``centers`` are
        presumably scalars or arrays broadcastable against the length-N grid
        -- confirm with callers.
    """
    freqs = T.linspace(0, 2 * numpy.pi, N)
    envelope = T.exp(-0.25 * (freqs - centers) ** 2 * bandwidths ** 2)
    # Keep only the frequencies in [0, pi]; everything above pi is zeroed.
    low_half = (freqs <= numpy.pi).astype("float32")
    return envelope * low_half
示例3: update
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def update(self, update_value):
    """Assign a new value to the variable, coercing shape and dtype if needed.

    ``update_value`` is resolved through ``symjax.current_graph().get``. If the
    resolved value's shape differs from ``self.shape``, a warning is issued
    and the value is reshaped to ``self.shape``. If its dtype (or, for objects
    without a ``dtype`` attribute, its Python type) differs from
    ``self.dtype``, a warning is issued and the value is cast. The result is
    stored in ``self._value``.
    """
    new_value = symjax.current_graph().get(update_value)

    if self.shape != jax.numpy.shape(new_value):
        warnings.warn(
            "Variable and update {} {} ".format(self, new_value)
            + "are not the same shape... attempting to reshape"
        )
        new_value = jax.numpy.reshape(new_value, self.shape)

    # Plain Python scalars carry no ``dtype`` attribute; fall back to type().
    ntype = new_value.dtype if hasattr(new_value, "dtype") else type(new_value)
    if self.dtype != ntype:
        warnings.warn(
            "Variable and update {} {} ".format(self, new_value)
            + "are not the same dtype... attempting to cast"
        )
        # asarray(..., dtype=...) casts arrays and Python scalars alike and,
        # unlike jax.numpy.astype, is available in older JAX releases too.
        new_value = jax.numpy.asarray(new_value, dtype=self.dtype)
    self._value = new_value
示例4: __init__
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def __init__(self, dtype: Optional[np.dtype] = None) -> None:
    """Initialise the JAX backend, importing JAX lazily.

    On success the ``jax``, ``jax.numpy`` and ``jax.scipy`` modules are bound
    to the module-level globals ``libjax``, ``jnp`` and ``jsp`` so that the
    other backend methods can use them.

    Args:
      dtype: optional default dtype; stored as ``np.dtype(dtype)`` in
        ``self._dtype``, or ``None`` when not given.

    Raises:
      ImportError: if JAX is not installed (the original import error is
        chained as the cause).
    """
    # pylint: disable=global-variable-undefined
    global libjax  # Jax module
    global jnp  # jax.numpy module
    global jsp  # jax.scipy module
    super(JaxBackend, self).__init__()
    try:
        #pylint: disable=import-outside-toplevel
        import jax
    except ImportError as err:
        raise ImportError("Jax not installed, please switch to a different "
                          "backend or install Jax.") from err
    libjax = jax
    jnp = libjax.numpy
    jsp = libjax.scipy
    self.name = "jax"
    self._dtype = np.dtype(dtype) if dtype is not None else None
示例5: dataset_as_numpy
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def dataset_as_numpy(*args, **kwargs):
    """Forward to the current backend's ``dataset_as_numpy`` implementation.

    The backend is looked up via ``backend()`` on every call rather than
    once at import time; all arguments are passed through unchanged.
    """
    implementation = backend()["dataset_as_numpy"]
    return implementation(*args, **kwargs)
# For numpy and random modules, we need to call "backend()" lazily, only when
# the function is called -- so that it can be set by gin configs.
# (Otherwise, backend() is called on import before gin-config is parsed.)
# To do that, we make objects that encapsulate these modules.
示例6: backend
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def backend(name="jax"):
    """Return the backend object selected by ``name`` or by the global override.

    A truthy ``override_backend_name`` takes precedence over ``name``. The
    numpy backend is returned only when the selected name is "numpy"; any
    other name yields the JAX backend.
    """
    selected = override_backend_name or name
    return _NUMPY_BACKEND if selected == "numpy" else _JAX_BACKEND
示例7: _to_numpy
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def _to_numpy(x):
"""Converts non-NumPy tensors to NumPy arrays."""
return x if isinstance(x, np.ndarray) else x.numpy()
示例8: _extract_image_patches
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def _extract_image_patches(
image, window_shape, hop=1, data_format="NCHW", mode="valid"
):
if mode == "same":
p1 = window_shape[0] - 1
p2 = window_shape[1] - 1
image = jnp.pad(
image, [(0, 0), (0, 0), (p1 // 2, p1 - p1 // 2), (p2 // 2, p2 - p2 // 2)]
)
if not hasattr(hop, "__len__"):
hop = (hop, hop)
if data_format == "NCHW":
# compute the number of windows in both dimensions
N = (
(image.shape[2] - window_shape[0]) // hop[0] + 1,
(image.shape[3] - window_shape[1]) // hop[1] + 1,
)
# compute the base indices of a 2d patch
patch = jnp.arange(numpy.prod(window_shape)).reshape(window_shape)
offset = jnp.expand_dims(jnp.arange(window_shape[0]), 1)
patch_indices = patch + offset * (image.shape[3] - window_shape[1])
# create all the shifted versions of it
ver_shifts = jnp.reshape(
jnp.arange(N[0]) * hop[0] * image.shape[3], (-1, 1, 1, 1)
)
hor_shifts = jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
all_cols = patch_indices + jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
indices = patch_indices + ver_shifts + hor_shifts
# now extract shape (1, 1, H'W'a'b')
flat_indices = jnp.reshape(indices, [1, 1, -1])
# shape is now (N, C, W*H)
flat_image = jnp.reshape(image, (image.shape[0], image.shape[1], -1))
# shape is now (N, C)
patches = jnp.take_along_axis(flat_image, flat_indices, 2)
return jnp.reshape(patches, image.shape[:2] + N + tuple(window_shape))
else:
error
示例9: littewood_paley_normalization
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def littewood_paley_normalization(filter_bank, down=None, up=None):
    """Divide a filter bank by its Littlewood-Paley sum inside a frequency band.

    The sum ``lp = |filter_bank|.sum(0)`` is taken over axis 0. Its samples are
    placed on a uniform grid ``linspace(0, 2 * pi, lp.shape[0])`` (presumably
    filters along axis 0 and frequency along axis 1 -- confirm with callers).
    Inside ``[down, up]`` the bank is divided by ``lp``; outside it is divided
    by 1, i.e. left unchanged.

    Parameters
    ----------
    filter_bank : array
        The filters to normalize.
    down : float, optional
        Lower edge of the band in radians; defaults to 0.
    up : float, optional
        Upper edge of the band in radians; defaults to pi.

    Returns
    -------
    array
        ``filter_bank / lp`` with ``lp`` replaced by 1 outside the band.
    """
    lp = T.abs(filter_bank).sum(0)
    freq = T.linspace(0, 2 * numpy.pi, lp.shape[0])
    down = 0 if down is None else down
    # Only ``None`` selects the default; an explicit ``up`` (even 0) is used.
    up = numpy.pi if up is None else up
    lp = T.where(T.logical_and(freq >= down, freq <= up), lp, 1)
    return filter_bank / lp
示例10: tukey
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def tukey(M, alpha=0.5):
    r"""Return a Tukey window, also known as a tapered cosine window.

    Parameters
    ----------
    M : int
        Number of points in the output window. If zero or less, an empty
        array is returned; if one, ``[1.]`` is returned.
    alpha : float, optional
        Shape parameter of the Tukey window, representing the fraction of the
        window inside the cosine tapered region.
        If zero (or negative), the window is rectangular (all ones).
        If one (or greater), the window is a Hann window.

    Returns
    -------
    w : ndarray
        The symmetric window of length ``max(M, 0)``, with the maximum value
        normalized to 1 (for a Hann window with even `M` the value 1 itself
        does not appear).

    References
    ----------
    .. [1] Harris, Fredric J. (Jan 1978). "On the use of Windows for Harmonic
           Analysis with the Discrete Fourier Transform". Proceedings of the
           IEEE 66 (1): 51-83. :doi:`10.1109/PROC.1978.10837`
    .. [2] Wikipedia, "Window function",
           https://en.wikipedia.org/wiki/Window_function#Tukey_window
    """
    if M <= 1:
        # Empty for M <= 0, a single 1 for M == 1; the general formula would
        # divide by M - 1 == 0 here.
        return T.ones(max(M, 0))
    n = T.arange(0, M)
    if alpha <= 0:
        # Rectangular window; the taper formula would divide by alpha.
        return T.ones(M)
    if alpha >= 1:
        # Hann window. The three-segment split below assumes alpha < 1: for
        # alpha >= 1 the segments overlap and the output would exceed M.
        return 0.5 * (1 - T.cos(2.0 * numpy.pi * n / (M - 1)))
    width = int(numpy.floor(alpha * (M - 1) / 2.0))
    n1 = n[0 : width + 1]  # rising cosine taper
    n2 = n[width + 1 : M - width - 1]  # flat top
    n3 = n[M - width - 1 :]  # falling cosine taper
    w1 = 0.5 * (1 + T.cos(numpy.pi * (-1 + 2.0 * n1 / alpha / (M - 1))))
    w2 = T.ones(n2.shape)
    w3 = 0.5 * (1 + T.cos(numpy.pi * (-2.0 / alpha + 1 + 2.0 * n3 / alpha / (M - 1))))
    return T.concatenate((w1, w2, w3))
示例11: freq_to_mel
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def freq_to_mel(f, option="linear"):
    """Convert frequencies in Hz to mels.

    Args:
        f: frequency value(s) in Hz.
        option: "linear" selects a piecewise scale, linear (``f / (200 / 3)``)
            below 1000 Hz and logarithmic (step ``log(6.4) / 27``) from
            1000 Hz upward. Any other value uses ``2595 * log10(1 + f / 700)``.
    """
    if option != "linear":
        return 2595 * T.log10(1 + f / 700)

    hz_per_mel = 200.0 / 3  # slope of the linear region
    log_start_hz = 1000.0  # beginning of log region (Hz)
    log_start_mel = log_start_hz / hz_per_mel  # same boundary in mels
    log_step = numpy.log(6.4) / 27.0  # step size for log region
    log_part = log_start_mel + T.log(f / log_start_hz) / log_step
    linear_part = f / hz_per_mel
    return T.where(f >= log_start_hz, log_part, linear_part)
示例12: isvar
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def isvar(item):
    """Check whether ``item`` (possibly a nested list/tuple) contains a variable.

    A variable here is an instance of ``Tensor``, or an object whose exact
    type is ``Constant`` or ``OpTuple``, that is not callable. Slices are
    never variables. For a list/tuple the result is ``numpy.sum`` of the
    recursive checks: a count that is non-zero (truthy) iff at least one
    element is or contains a variable (0.0 for an empty sequence).
    """
    if isinstance(item, slice):
        return False
    if isinstance(item, (list, tuple)):
        return numpy.sum([isvar(element) for element in item])
    is_symbolic = isinstance(item, Tensor) or type(item) in [Constant, OpTuple]
    return is_symbolic and not callable(item)
示例13: update_numpydoc
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def update_numpydoc(docstr, fun, op):
    """Drop numpy-docstring parameters that the JAX implementation does not accept.

    Parameters
    ----------
    docstr : str or None
        NumPy-style docstring of the numpy function being wrapped.
    fun : callable
        Unused by this function; kept for interface compatibility.
    op : callable
        The JAX implementation. Entries of the "Parameters" section whose
        name is not in ``op.__code__.co_varnames`` are removed.

    Returns
    -------
    str or None
        The filtered docstring. ``docstr`` is returned unchanged when it is
        empty/None, when ``op`` has no ``__code__`` (e.g. builtins), or when
        no "Parameters" header, underline and "Returns" section can be found.
    """
    if not docstr or not hasattr(op, "__code__"):
        return docstr
    indent = "    "
    # Some numpy functions have an extra indent level at the beginning of
    # each line; strip one level from every line that has it.
    if docstr.startswith(indent):
        docstr = "\n".join(
            line[len(indent):] if line.startswith(indent) else line
            for line in docstr.split("\n")
        )

    header_idx = docstr.find("Parameters")
    if header_idx == -1:
        return docstr
    underline_idx = docstr.find("--\n", header_idx)
    if underline_idx == -1:
        return docstr
    # begin_idx points at the newline that ends the "-----" underline.
    begin_idx = underline_idx + 2
    end_idx = docstr.find("Returns", begin_idx)
    if end_idx == -1:
        return docstr

    parameters = docstr[begin_idx:end_idx]
    # Fold each entry's indented description lines onto its header line so
    # that one list element == one parameter ("name : type@@description").
    entries = parameters.replace("\n" + indent, "@@").split("\n")
    varnames = op.__code__.co_varnames
    kept = []
    for entry in entries:
        # "x1, x2 : array_like" -> "x1"
        name = entry[: entry.find(" : ")].split(", ")[0]
        if name in varnames:
            kept.append(entry)
    parameters = "\n".join(kept).replace("@@", "\n" + indent)
    # Keep everything through the underline's newline, then the filtered
    # entries, then the two characters before "Returns" (the blank line) on.
    return docstr[: begin_idx + 1] + parameters + docstr[end_idx - 2 :]
示例14: inv
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def inv(self, matrix: Tensor) -> Tensor:
    """Return the inverse of ``matrix`` computed with ``jnp.linalg.inv``.

    Args:
      matrix: a square 2-D tensor.

    Raises:
      ValueError: if ``matrix`` is not 2-D or is not square (same checks as
        ``expm``).
    """
    # NOTE(review): the messages say "numpy backend" although this method
    # calls jnp; wording kept as-is in case callers match on the text.
    if len(matrix.shape) != 2:
        raise ValueError("input to numpy backend method `inv` has shape {}."
                         " Only matrices are supported.".format(matrix.shape))
    if matrix.shape[0] != matrix.shape[1]:
        raise ValueError("input to numpy backend method `inv` only supports"
                         " N*N matrix, {x}*{y} matrix is given".format(
                             x=matrix.shape[0], y=matrix.shape[1]))
    return jnp.linalg.inv(matrix)
示例15: expm
# 需要导入模块: import jax [as 别名]
# 或者: from jax import numpy [as 别名]
def expm(self, matrix: Tensor) -> Tensor:
    """Return the matrix exponential of ``matrix`` via ``jsp.linalg.expm``.

    Args:
      matrix: a square 2-D tensor.

    Raises:
      ValueError: if ``matrix`` is not 2-D, or is 2-D but not square.
    """
    shape = matrix.shape
    if len(shape) != 2:
        raise ValueError("input to numpy backend method `expm` has shape {}."
                         " Only matrices are supported.".format(shape))
    rows, cols = shape
    if rows != cols:
        raise ValueError("input to numpy backend method `expm` only supports"
                         " N*N matrix, {x}*{y} matrix is given".format(
                             x=rows, y=cols))
    # pylint: disable=no-member
    return jsp.linalg.expm(matrix)