本文整理汇总了Python中jax.numpy.reshape方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.reshape方法的具体用法？Python jax.numpy.reshape怎么用？Python jax.numpy.reshape使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。
在下文中一共展示了jax.numpy.reshape方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: random_tensors
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def random_tensors(request):
    """Build random (hamiltonian, density, isometry, disentangler) test tensors.

    Args:
        request: pytest ``request`` object; ``request.param`` is the bond
            dimension ``D``.

    Returns:
        Tuple ``(h, s, iso, dis)``, each cast to complex128:
        ``h``   -- Hermitian ``D**3 x D**3`` matrix reshaped to ``[D] * 6``;
        ``s``   -- positive semi-definite, unit-trace ``D**3 x D**3`` matrix
                   reshaped to ``[D] * 6``;
        ``iso`` -- a ``(D, D, D)`` slice of the orthogonal SVD factor ``vh``;
        ``dis`` -- the orthogonal SVD factor ``u`` reshaped to ``[D] * 4``.
    """
    D = request.param
    # One independent subkey per draw. Reusing a single PRNGKey for the two
    # same-shaped draws would make h and s come from the identical matrix.
    key_h, key_s, key_a = jax.random.split(jax.random.PRNGKey(0), 3)

    h = jax.random.normal(key_h, shape=(D**3, D**3))
    h = 0.5 * (h + np.conj(np.transpose(h)))  # symmetrize -> Hermitian
    h = np.reshape(h, [D] * 6)

    s = jax.random.normal(key_s, shape=(D**3, D**3))
    s = s @ np.conj(np.transpose(s))  # s s^H is positive semi-definite
    s = s / np.trace(s)  # normalize to unit trace
    s = np.reshape(s, [D] * 6)

    a = jax.random.normal(key_a, shape=(D**2, D**2))
    u, _, vh = np.linalg.svd(a)  # square real a -> orthogonal u and vh
    dis = np.reshape(u, [D] * 4)
    # Fix the last index of the reshaped vh to 0 to get a rank-3 tensor.
    iso = np.reshape(vh, [D] * 4)[:, :, :, 0]
    return tuple(x.astype(np.complex128) for x in (h, s, iso, dis))
示例2: shift_ham
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def shift_ham(hamiltonian, shift=None):
    """Applies a shift to a hamiltonian.

    The tensor is viewed as a ``d**3 x d**3`` matrix, and ``shift * identity``
    is subtracted from it. ``d`` is inferred from the number of entries
    (``d**6``), so ``d == 2`` reproduces the original qubit-only behaviour.

    Args:
      hamiltonian: The hamiltonian tensor (rank 6, every axis of size ``d``).
      shift: The amount by which to shift. If `None`, shifts by the largest
        eigenvalue so that the local term is negative semi-definite.
    Returns:
      The shifted Hamiltonian as a rank-6 tensor with axes of size ``d``.
    Raises:
      ValueError: if the number of entries is not ``d**6`` for an integer d.
    """
    size = int(np.size(hamiltonian))
    d = int(round(size ** (1.0 / 6.0)))
    if d**6 != size:
        raise ValueError(
            "hamiltonian must have d**6 entries for an integer d; "
            "got {} entries".format(size))
    dim = d**3
    hmat = np.reshape(hamiltonian, (dim, -1))
    if shift is None:
        # eigh returns (eigenvalues, eigenvectors). Shifting by the largest
        # eigenvalue leaves every eigenvalue <= 0.
        shift = np.amax(np.linalg.eigh(hmat)[0])
    # Out-of-place subtraction, so the caller's array is never mutated
    # (np.reshape may return a view when this runs on NumPy arrays).
    hmat = hmat - shift * np.eye(dim)
    return np.reshape(hmat, [d] * 6)
示例3: astensor
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def astensor(self, tensor_in, dtype='float'):
    """
    Convert to a JAX ndarray.

    Args:
        tensor_in (Number or Tensor): Tensor object
        dtype (str): Key into ``self.dtypemap`` selecting the target dtype
            (the error message names float, int, or bool).

    Returns:
        `jax.interpreters.xla.DeviceArray`: A multi-dimensional, fixed-size homogenous array.
        0-d inputs are promoted to shape ``(1,)``.

    Raises:
        KeyError: if ``dtype`` is not a key of ``self.dtypemap`` (logged first).
    """
    try:
        dtype = self.dtypemap[dtype]
    except KeyError:
        log.error('Invalid dtype: dtype must be float, int, or bool.')
        raise
    tensor = np.asarray(tensor_in, dtype=dtype)
    # Ensure non-empty tensor shape for consistency: scalars become shape (1,).
    if tensor.ndim == 0:
        tensor = np.reshape(tensor, [1])
    # `tensor` already has the requested dtype; no second conversion is needed.
    return tensor
示例4: _multinomial
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def _multinomial(key, p, n, n_max, shape=()):
    """Sample multinomial counts by drawing categories and scatter-adding them.

    Args:
        key: PRNG key forwarded to ``categorical``.
        p: event probabilities; the last axis indexes categories.
        n: number of trials; broadcast against ``p.shape[:-1]`` when shapes differ.
        n_max: upper bound on ``n``. Exactly ``n_max`` categorical draws are
            made per batch element. It is used inside a shape tuple, so it must
            be a concrete int. Entries of ``n`` are assumed to be <= n_max.
        shape: batch shape of the result; a falsy value (including the default
            ``()``) falls back to ``p.shape[:-1]``.

    Returns:
        Counts of shape ``shape + p.shape[-1:]``.
        NOTE(review): when ``n`` is non-scalar, ``excess`` is built from
        ``jnp.zeros`` (default float dtype), so the result is promoted to
        float. With scalar ``n`` it keeps ``indices.dtype``.
    """
    # Make n and the batch part of p agree in shape before sampling.
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # Draw n_max category indices per batch element; leading axis is the draw axis.
    indices = categorical(key, p, (n_max,) + shape)
    # With per-element trial counts, keep only the first n of the n_max draws.
    if jnp.ndim(n) > 0:
        # mask[k, ...] is 1 where draw k < n. promote_shapes presumably expands
        # the mask to the rank of shape + (n_max,) -- confirm against its definition.
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws become index 0 (indices * 0), so category 0 is over-counted
        # by n_max - n. `excess` holds that correction in column 0, zeros elsewhere.
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        # Scalar n: every draw counts, so nothing is masked or subtracted.
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front -> (batch, n_max)
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
    # Per batch row, scatter a 1 into a zero count vector at each drawn index.
    # _scatter_add_one presumably does this single-row scatter-add -- confirm.
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
                                                             dtype=indices.dtype),
                                                   jnp.expand_dims(indices_2D, axis=-1),
                                                   jnp.ones(indices_2D.shape, dtype=indices.dtype))
    # Restore the batch shape and remove the category-0 over-count.
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
示例5: serialize
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def serialize(self, data):
    """Encode a batch of box-space values as fixed-precision digit sequences.

    Each value is rescaled from ``[low, high]`` to ``[0, 1]`` (clipped). It is
    then written as ``self._precision`` digits in base ``self._vocab_size``,
    most significant digit first.

    Args:
        data: array whose leading axis is the batch axis.

    Returns:
        int32 array of shape ``(batch_size, -1)``: for every value, its
        ``self._precision`` digits laid out consecutively.
    """
    batch_size = data.shape[0]
    span = self._space.high - self._space.low
    residual = np.clip((data - self._space.low) / span, 0, 1)
    base = self._vocab_size
    digit_columns = []
    for position in range(1, self._precision + 1):
        place_value = base ** -position
        digit = np.array(residual / place_value).astype(np.int32)
        # A value exactly at ``high`` yields digit == base; fold it down to
        # the largest valid digit.
        digit = np.where(digit == base, digit - 1, digit)
        digit_columns.append(digit)
        residual -= digit * place_value
    return np.reshape(np.stack(digit_columns, axis=-1), (batch_size, -1))
示例6: deserialize
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def deserialize(self, representation):
    """Decode digit sequences back into box-space values.

    This is the layout ``serialize`` emits. Every value occupies
    ``self._precision`` consecutive base-``self._vocab_size`` digits, most
    significant first.

    Args:
        representation: array whose leading axis is the batch axis.

    Returns:
        Array of shape ``(batch_size,) + self._space.shape``, rescaled from
        ``[0, 1]`` to ``[low, high]``.
    """
    batch_size = representation.shape[0]
    grouped = np.reshape(representation, (batch_size, -1, self._precision))
    value = np.zeros(grouped.shape[:-1])
    for position in range(self._precision):
        # Digit `position` has place value vocab_size ** -(position + 1).
        value += self._vocab_size ** (-position - 1) * grouped[..., position]
    value = np.reshape(value, (batch_size,) + self._space.shape)
    span = self._space.high - self._space.low
    return value * span + self._space.low
示例7: significance_map
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def significance_map(self):
    """Return the digit position (0 = most significant) of every serialized entry.

    The layout matches a flattened ``self._space.shape + (self._precision,)``
    array: positions ``0 .. precision - 1`` repeat once per box-space value.
    """
    positions = np.arange(self._precision)
    tiled = np.broadcast_to(positions, self._space.shape + (self._precision,))
    return np.reshape(tiled, -1)
示例8: flatten
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def flatten(x):
    """Collapse all axes after the leading (batch) axis into a single axis.

    Args:
        x: array with at least one axis.

    Returns:
        Array of shape ``(x.shape[0], prod(x.shape[1:]))``. A 1-D input of
        length ``n`` becomes ``(n, 1)``, and an empty batch stays valid.
    """
    import math

    # Spell out the trailing size rather than passing -1. With a zero-size
    # batch the inferred dimension is ambiguous, and NumPy rejects
    # reshape(x, (0, -1)).
    return np.reshape(x, (x.shape[0], math.prod(x.shape[1:])))
示例9: mnist_images
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def mnist_images():
    """Load MNIST images as flat rows of 784 pixels scaled by 1/256.

    Returns:
        ``(train, test)``:
        ``train`` is one batch of up to 50000 images, taken after shuffling
        with a 50000-element buffer;
        ``test`` is the first batch of up to 10000 test images.
    """
    # Hide GPUs from TensorFlow before loading datasets; see
    # https://github.com/google/jax/blob/master/docs/gpu_memory_allocation.rst
    import tensorflow as tf
    tf.config.experimental.set_visible_devices([], "GPU")
    import tensorflow_datasets as tfds

    def to_flat_images(batched):
        first_batch = next(tfds.as_numpy(batched))
        pixels = np.float32(first_batch['image']) / 256
        return np.reshape(pixels, (-1, 784))

    splits = tfds.load("mnist:1.0.0")
    train = to_flat_images(splits['train'].shuffle(50000).batch(50000))
    test = to_flat_images(splits['test'].batch(10000))
    return (train, test)
示例10: image_grid
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def image_grid(nrow, ncol, imagevecs, imshape):
    """Reshape a stack of image vectors into an image grid for plotting.

    Images are consumed in order, ``ncol`` per grid row, for ``nrow`` rows.
    Each row's transposed tiles are joined side by side in reverse order, and
    the stacked rows are transposed at the end.
    """
    remaining = iter(imagevecs.reshape((-1,) + imshape))
    rows = []
    for _ in range(nrow):
        tiles = [next(remaining).T for _ in range(ncol)]
        tiles.reverse()
        rows.append(np.hstack(tiles))
    return np.vstack(rows).T
示例11: read_dataset
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def read_dataset():
    """Load the `sets` OCR dataset, one-hot its target, flatten data, and split it.

    The last two axes of every data entry are merged into one and cast to
    float. ``sets.Split(0.66)`` presumably yields a 66%/34% split -- confirm
    against the `sets` package.
    """
    import sets
    ocr = sets.Ocr()
    encode_target = sets.OneHot(ocr.target, depth=2)
    ocr = encode_target(ocr, columns=['target'])
    flat_shape = ocr.data.shape[:-2] + (-1,)
    ocr['data'] = ocr.data.reshape(flat_shape).astype(float)
    split = sets.Split(0.66)
    return split(ocr)
示例12: test_ocr_rnn
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def test_ocr_rnn():
    """Smoke-test a 3-layer GRU RNN with a shared dense softmax head.

    The test checks parameter structure and zero initialisation, then runs
    one plain and one jit-compiled RmsProp update on a cross-entropy loss.
    """
    length = 5
    carry_size = 3
    class_count = 4
    # A single all-zero sequence: batch 1, `length` steps, 4 features per step.
    inputs = jnp.zeros((1, length, 4))
    # Fresh GRU layer; `zeros` is presumably the weight initializer (the kernel
    # assertions below expect all-zero weights).
    def rnn(): return Rnn(*GRUCell(carry_size, zeros))
    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        # Merge batch and time axes so the Dense layer applies the same weights
        # at every time step.
        lambda x: jnp.reshape(x, (-1, carry_size)), # -> same weights for all time steps
        Dense(class_count, zeros, zeros),
        softmax,
        # Restore (batch, time, class) layout.
        lambda x: jnp.reshape(x, (-1, length, class_count)))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    # Presumably 3 Rnn layers + 1 Dense; the lambdas and softmax carry no params.
    assert len(params) == 4
    cell = params.rnn0.gru_cell
    assert len(cell) == 3
    # (7, 3): presumably (input features 4 + carry_size 3) x carry_size -- confirm.
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.update_kernel)
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.reset_kernel)
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.compute_kernel)
    # The network's own output doubles as the training target below.
    out = net.apply(params, inputs)
    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        # Sum over time and class axes (1, 2), then average over the batch.
        return jnp.mean(-jnp.sum(targets * jnp.log(prediction), (1, 2)))
    opt = optimizers.RmsProp(0.003)
    state = opt.init(cross_entropy.init_parameters(inputs, out, key=PRNGKey(0)))
    state = opt.update(cross_entropy.apply, state, inputs, out)
    # Also exercise the jit=True update path; its result is intentionally unused.
    opt.update(cross_entropy.apply, state, inputs, out, jit=True)
示例13: _extract_signal_patches
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def _extract_signal_patches(signal, window_length, hop=1, data_format="NCW"):
    """Extract sliding 1-D windows along the last axis of a batch of signals.

    Args:
        signal: rank-3 array in ``"NCW"`` layout (batch, channels, width).
        window_length: scalar length of each window.
        hop: step between the starts of consecutive windows.
        data_format: layout of ``signal``; only ``"NCW"`` is supported.

    Returns:
        Array of shape ``signal.shape[:2] + (n_windows, window_length)``, where
        ``n_windows = (signal.shape[2] - window_length) // hop + 1``.

    Raises:
        ValueError: if ``window_length`` is not a scalar, ``signal`` is not
            rank 3, ``data_format`` is not ``"NCW"``, or the window is longer
            than the signal.
    """
    if hasattr(window_length, "__len__"):
        raise ValueError(
            "window_length must be a scalar, got {!r}".format(window_length))
    if signal.ndim != 3:
        raise ValueError(
            "signal must be rank 3, got shape {}".format(signal.shape))
    if data_format != "NCW":
        # Previously this branch evaluated the undefined name `error` (NameError).
        raise ValueError(
            'unsupported data_format {!r}; only "NCW" is implemented'.format(data_format))
    n_windows = (signal.shape[2] - window_length) // hop + 1
    if n_windows < 1:
        raise ValueError(
            "window_length {} exceeds signal length {}".format(window_length, signal.shape[2]))
    # indices[i, j] = i * hop + j, the position of sample j of window i.
    indices = jnp.arange(window_length) + jnp.expand_dims(jnp.arange(n_windows) * hop, 1)
    # Flatten to (1, 1, n_windows * window_length) so the index array
    # broadcasts against the batch and channel axes.
    indices = jnp.reshape(indices, [1, 1, n_windows * window_length])
    patches = jnp.take_along_axis(signal, indices, 2)
    return jnp.reshape(patches, signal.shape[:2] + (n_windows, window_length))
示例14: _extract_image_patches
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def _extract_image_patches(
    image, window_shape, hop=1, data_format="NCHW", mode="valid"
):
    """Extract sliding 2-D patches from a batch of images.

    Args:
        image: rank-4 array in ``"NCHW"`` layout.
        window_shape: ``(a, b)`` patch height and width.
        hop: stride, either an int or a ``(hop_h, hop_w)`` pair.
        data_format: layout of ``image``; only ``"NCHW"`` is supported.
        mode: ``"same"`` zero-pads height by ``a - 1`` and width by ``b - 1``
            before extraction, splitting each as ``(p // 2, p - p // 2)``. Any
            other value extracts from the unpadded image.

    Returns:
        Array of shape ``image.shape[:2] + (H', W', a, b)``, where
        ``H' = (H - a) // hop_h + 1`` and ``W' = (W - b) // hop_w + 1``
        (H and W measured after padding).

    Raises:
        ValueError: if ``data_format`` is not ``"NCHW"``.
    """
    if data_format != "NCHW":
        # Previously this branch evaluated the undefined name `error` (NameError).
        raise ValueError(
            'unsupported data_format {!r}; only "NCHW" is implemented'.format(data_format))
    if mode == "same":
        p1 = window_shape[0] - 1
        p2 = window_shape[1] - 1
        image = jnp.pad(
            image, [(0, 0), (0, 0), (p1 // 2, p1 - p1 // 2), (p2 // 2, p2 - p2 // 2)]
        )
    if not hasattr(hop, "__len__"):
        hop = (hop, hop)
    width = image.shape[3]
    # number of windows along each spatial axis
    n_windows = (
        (image.shape[2] - window_shape[0]) // hop[0] + 1,
        (width - window_shape[1]) // hop[1] + 1,
    )
    # Row-major flat offsets of a patch anchored at (0, 0):
    # patch_indices[i, j] = i * width + j, shape (a, b).
    rows = jnp.expand_dims(jnp.arange(window_shape[0]), 1)
    cols = jnp.arange(window_shape[1])
    patch_indices = rows * width + cols
    # Anchor shifts: vertical (H', 1, 1, 1) and horizontal (W', 1, 1). These
    # broadcast with the (a, b) patch to give indices of shape (H', W', a, b).
    ver_shifts = jnp.reshape(jnp.arange(n_windows[0]) * hop[0] * width, (-1, 1, 1, 1))
    hor_shifts = jnp.reshape(jnp.arange(n_windows[1]) * hop[1], (-1, 1, 1))
    indices = patch_indices + ver_shifts + hor_shifts
    # shape (1, 1, H'W'ab), broadcast against batch and channel axes
    flat_indices = jnp.reshape(indices, [1, 1, -1])
    # shape (N, C, H*W)
    flat_image = jnp.reshape(image, (image.shape[0], image.shape[1], -1))
    # shape (N, C, H'W'ab)
    patches = jnp.take_along_axis(flat_image, flat_indices, 2)
    return jnp.reshape(patches, image.shape[:2] + n_windows + tuple(window_shape))
示例15: reshape
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def reshape(self, x, shape):
    """Reshape ``x``, accepting ``None`` entries as inferred (-1) dimensions.

    Args:
        x: array-like to reshape.
        shape: a single integer, or an iterable of integers / ``None``.
            ``None`` entries become ``-1``; every other entry is converted
            with ``int``.

    Returns:
        ``np.reshape(x, normalized_shape)``.
    """
    import numbers

    # A bare integer is a 1-D target shape (as np.reshape allows). Wrap it so
    # it can be iterated; previously this raised TypeError.
    if isinstance(shape, numbers.Integral):
        shape = (shape,)
    dims = tuple(-1 if s is None else int(s) for s in shape)
    return np.reshape(x, dims)