当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.reshape方法代码示例

本文整理汇总了 Python 中 jax.numpy.reshape 方法的典型用法代码示例。如果您想了解 jax.numpy.reshape 的具体用法、调用方式或使用示例，这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在模块 jax.numpy 的其他用法示例。


下文共展示了 jax.numpy.reshape 方法的 15 个代码示例，默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。

示例1: random_tensors

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def random_tensors(request):
  """Build random MERA test tensors whose bond dimension is ``request.param``.

  With ``D = request.param`` the returned tuple is ``(h, s, iso, dis)``:
    h:   shape (D,)*6, the Hermitian part of a random (D**3, D**3) matrix.
    s:   shape (D,)*6, ``M @ M^H`` for a random (D**3, D**3) matrix ``M``,
         divided by its trace.
    iso: shape (D,)*3, the right singular vectors of a random (D**2, D**2)
         matrix reshaped to rank 4 and sliced at index 0 of the last axis.
    dis: shape (D,)*4, the left singular vectors of that same matrix.
  Every tensor is passed through ``astype(np.complex128)`` (JAX may downcast
  this to complex64 when 64-bit mode is disabled).

  NOTE(review): the same PRNG key is used for every draw, so ``h`` and ``s``
  are built from the identical random matrix. Kept as-is so the fixture's
  data does not change.
  """
  dim = request.param
  rng = jax.random.PRNGKey(0)
  big_shape = [dim**3, dim**3]
  rank6 = [dim] * 6

  raw_h = jax.random.normal(rng, shape=big_shape)
  ham = np.reshape(0.5 * (raw_h + np.conj(np.transpose(raw_h))), rank6)

  raw_s = jax.random.normal(rng, shape=big_shape)
  gram = raw_s @ np.conj(np.transpose(raw_s))
  rho = np.reshape(gram / np.trace(gram), rank6)

  square = jax.random.normal(rng, shape=[dim**2, dim**2])
  left, _, right_h = np.linalg.svd(square)
  disentangler = np.reshape(left, [dim] * 4)
  isometry = np.reshape(right_h, [dim] * 4)[..., 0]

  outputs = (ham, rho, isometry, disentangler)
  return tuple(t.astype(np.complex128) for t in outputs)
开发者ID:google,项目名称:TensorNetwork,代码行数:21,代码来源:simple_mera_test.py

示例2: shift_ham

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def shift_ham(hamiltonian, shift=None):
  """Applies a shift to a hamiltonian.

  The rank-6 tensor is viewed as a (d**3, d**3) matrix, where ``d`` is the
  size of the tensor's first axis, and ``shift`` times the identity is
  subtracted from it.

  Args:
    hamiltonian: The hamiltonian tensor (rank 6), every axis of size d.
    shift: The amount by which to shift. If `None`, the largest eigenvalue of
      the matrix (as returned by ``eigh``) is used, so that the local term is
      negative semi-definite.

  Returns:
    The shifted Hamiltonian, with shape ``[d] * 6``.
  """
  d = np.shape(hamiltonian)[0]
  dim = d**3
  hmat = np.reshape(hamiltonian, (dim, -1))
  if shift is None:
    # eigh returns (eigenvalues, eigenvectors); only the eigenvalues matter.
    shift = np.amax(np.linalg.eigh(hmat)[0])
  # Out-of-place on purpose: with plain NumPy, ``-=`` on a reshape view would
  # silently modify the caller's tensor.
  hmat = hmat - shift * np.eye(dim)
  return np.reshape(hmat, [d] * 6)
开发者ID:google,项目名称:TensorNetwork,代码行数:18,代码来源:simple_mera.py

示例3: astensor

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def astensor(self, tensor_in, dtype='float'):
        """
        Convert to a JAX ndarray.

        Args:
            tensor_in (Number or Tensor): Tensor object
            dtype (str): Key into ``self.dtypemap`` selecting the target dtype.

        Returns:
            `jax.interpreters.xla.DeviceArray`: A multi-dimensional, fixed-size homogenous array.
            0-d (scalar) inputs are promoted to shape ``(1,)``.

        Raises:
            KeyError: If ``dtype`` is not a key of ``self.dtypemap`` (an error
                is logged before re-raising).
        """
        try:
            dtype = self.dtypemap[dtype]
        except KeyError:
            log.error('Invalid dtype: dtype must be float, int, or bool.')
            raise
        tensor = np.asarray(tensor_in, dtype=dtype)
        # Ensure non-empty tensor shape for consistency: a 0-d result has no
        # leading axis, so give it one. Checking ``ndim`` directly replaces the
        # former IndexError-based probe of ``tensor.shape[0]``.
        if tensor.ndim == 0:
            tensor = np.reshape(tensor, [1])
        # ``tensor`` already has ``dtype``; no second conversion is needed.
        return tensor
开发者ID:scikit-hep,项目名称:pyhf,代码行数:24,代码来源:jax_backend.py

示例4: _multinomial

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def _multinomial(key, p, n, n_max, shape=()):
    """Draw multinomial counts by sampling ``n_max`` category indices and tallying them.

    Args:
        key: JAX PRNG key, passed to ``categorical``.
        p: Category probabilities; the last axis indexes categories and the
            leading axes are batch axes.
        n: Number of trials, broadcast against ``p.shape[:-1]``.
        n_max: Number of category draws made per batch element. Presumably a
            static upper bound on ``n`` so shapes stay fixed — confirm with callers.
        shape: Batch shape of the sample; an empty tuple falls back to
            ``p.shape[:-1]``.

    Returns:
        Per-category counts of shape ``shape + p.shape[-1:]``.
    """
    # Bring ``n`` and the batch axes of ``p`` (all but the last, category axis)
    # to a common shape when they differ.
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result
    # (requested sample shape is (n_max,) + shape).
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        # mask is 1 for the first n draws of each batch element and 0 after;
        # moved to put the draw axis first, matching ``indices``.
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws become index 0, so category 0 over-counts by n_max - n;
        # ``excess`` subtracts exactly that from category 0 and 0 elsewhere.
        # NOTE(review): jnp.zeros defaults to a float dtype, so the returned
        # counts are float in this branch — confirm this is intended.
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        # NOTE(review): scalar ``n`` is not used below, so every draw counts;
        # this presumably relies on callers passing n_max == n — confirm.
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    # indices_2D has shape (flattened batch, n_max).
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
    # Per batch row: start from a zero vector of length num_categories and add
    # a one at each sampled index (``_scatter_add_one`` is defined elsewhere;
    # presumably a scatter-add — confirm).
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
                                                             dtype=indices.dtype),
                                                   jnp.expand_dims(indices_2D, axis=-1),
                                                   jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:25,代码来源:util.py

示例5: serialize

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def serialize(self, data):
  """Encode every value of ``data`` as ``self._precision`` base-``self._vocab_size`` digits.

  Values are rescaled from ``[space.low, space.high]`` to ``[0, 1]`` and
  clipped, then expanded digit by digit, most significant digit first.

  Args:
    data: Array whose leading axis is the batch.

  Returns:
    An int32 array of shape ``(batch_size, -1)`` in which each value's
    ``self._precision`` digits are laid out contiguously.
  """
  base = self._vocab_size
  scaled = (data - self._space.low) / (self._space.high - self._space.low)
  remainder = np.clip(scaled, 0, 1)
  columns = []
  for place in range(1, self._precision + 1):
    weight = base ** -place
    digit = np.array(remainder / weight).astype(np.int32)
    # A value exactly equal to ``high`` would otherwise yield digit == base.
    digit = np.where(digit == base, digit - 1, digit)
    columns.append(digit)
    remainder -= digit * weight
  stacked = np.stack(columns, axis=-1)
  return np.reshape(stacked, (data.shape[0], -1))
开发者ID:google,项目名称:trax,代码行数:17,代码来源:space_serializer.py

示例6: deserialize

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def deserialize(self, representation):
  """Rebuild values from their base-``self._vocab_size`` digit representation.

  Args:
    representation: Array with leading batch axis; the remaining entries,
      grouped into chunks of ``self._precision``, are each value's digits
      (most significant first).

  Returns:
    Values of shape ``(batch_size,) + self._space.shape``, mapped from
    ``[0, 1]`` back to ``[space.low, space.high]``.
  """
  batch = representation.shape[0]
  grouped = np.reshape(representation, (batch, -1, self._precision))
  value = np.zeros(grouped.shape[:-1])
  for pos in range(self._precision):
    # Digit ``pos`` carries weight base ** -(pos + 1).
    value += self._vocab_size ** (-pos - 1) * grouped[..., pos]
  value = np.reshape(value, (batch,) + self._space.shape)
  span = self._space.high - self._space.low
  return value * span + self._space.low
开发者ID:google,项目名称:trax,代码行数:12,代码来源:space_serializer.py

示例7: significance_map

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def significance_map(self):
  """Return each serialized position's digit index as a flat array.

  The pattern ``0, 1, ..., self._precision - 1`` is repeated once per element
  of ``self._space.shape``.
  """
  digit_ids = np.arange(self._precision)
  target_shape = self._space.shape + (self._precision,)
  tiled = np.broadcast_to(digit_ids, target_shape)
  return np.reshape(tiled, -1)
开发者ID:google,项目名称:trax,代码行数:5,代码来源:space_serializer.py

示例8: flatten

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def flatten(x):
    """Collapse every axis after the first, giving shape ``(x.shape[0], -1)``."""
    leading = x.shape[0]
    return np.reshape(x, (leading, -1))
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:4,代码来源:modules.py

示例9: mnist_images

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def mnist_images():
    """Load MNIST images as flattened float32 arrays scaled by 1/256.

    Returns:
        ``(train, test)``: the first batch (up to 50000 images) of the shuffled
        training split and the first batch (up to 10000 images) of the test
        split, each reshaped to ``(num_images, 784)``.
    """
    # Hide GPUs from TensorFlow; presumably so it does not grab GPU memory that
    # JAX needs (see the JAX GPU memory allocation notes):
    # https://github.com/google/jax/blob/master/docs/gpu_memory_allocation.rst
    import tensorflow as tf
    tf.config.experimental.set_visible_devices([], "GPU")

    import tensorflow_datasets as tfds

    def prep(split):
        first_batch = next(tfds.as_numpy(split))
        pixels = np.float32(first_batch['image']) / 256
        return np.reshape(pixels, (-1, 784))

    dataset = tfds.load("mnist:1.0.0")
    train = prep(dataset['train'].shuffle(50000).batch(50000))
    test = prep(dataset['test'].batch(10000))
    return train, test
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:12,代码来源:mnist_vae.py

示例10: image_grid

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def image_grid(nrow, ncol, imagevecs, imshape):
    """Reshape a stack of image vectors into an image grid for plotting.

    For a 2-D ``imshape == (h, w)`` the result has shape
    ``(ncol * h, nrow * w)``. Images are consumed in order in groups of
    ``ncol``. Each group forms one column of the grid, and the groups are
    placed left to right. Within a column, the images are stacked bottom to
    top. ``StopIteration`` propagates if fewer than ``nrow * ncol`` images are
    supplied.
    """
    images = iter(imagevecs.reshape((-1,) + imshape))
    grid_rows = []
    for _ in range(nrow):
        tiles = [next(images).T for _ in range(ncol)]
        tiles.reverse()
        grid_rows.append(np.hstack(tiles))
    return np.vstack(grid_rows).T
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:7,代码来源:mnist_vae.py

示例11: read_dataset

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def read_dataset():
    """Load the ``sets`` OCR dataset, one-hot encode its targets and split it.

    The ``target`` column is one-hot encoded with ``depth=2``. The last two
    axes of ``data`` are flattened into one and cast to float. The result is
    passed through ``sets.Split(0.66)``, presumably a 66%/34% split — confirm
    against the ``sets`` library.
    """
    import sets
    ocr = sets.Ocr()
    encoder = sets.OneHot(ocr.target, depth=2)
    ocr = encoder(ocr, columns=['target'])
    pixels = ocr.data
    ocr['data'] = pixels.reshape(pixels.shape[:-2] + (-1,)).astype(float)
    splitter = sets.Split(0.66)
    return splitter(ocr)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:8,代码来源:ocr_rnn.py

示例12: test_ocr_rnn

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def test_ocr_rnn():
    """Smoke-test a stack of three GRU ``Rnn`` layers with a dense softmax readout.

    The test checks the parameter structure returned by ``net.init_parameters``
    for an all-zeros input. It then checks that RmsProp updates on a
    cross-entropy loss run both without and with ``jit``.
    """
    length = 5
    carry_size = 3
    class_count = 4
    # Input of shape (1, length, 4), all zeros.
    inputs = jnp.zeros((1, length, 4))

    def rnn(): return Rnn(*GRUCell(carry_size, zeros))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: jnp.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(class_count, zeros, zeros),
        softmax,
        lambda x: jnp.reshape(x, (-1, length, class_count)))

    params = net.init_parameters(inputs, key=PRNGKey(0))

    # Presumably one entry per parameterized layer (three Rnn + one Dense).
    assert len(params) == 4
    cell = params.rnn0.gru_cell
    assert len(cell) == 3
    # All three kernels must be zero-initialized. The (7, 3) shape is presumably
    # (4 input features + 3 carry units, carry_size) — confirm against GRUCell.
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.update_kernel)
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.reset_kernel)
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.compute_kernel)

    out = net.apply(params, inputs)

    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        # Sum over time steps and classes (axes 1, 2), then average over batch.
        return jnp.mean(-jnp.sum(targets * jnp.log(prediction), (1, 2)))

    # The network's own output serves as the target. This only checks that the
    # optimizer updates execute (plain and jitted), not that the loss improves.
    opt = optimizers.RmsProp(0.003)
    state = opt.init(cross_entropy.init_parameters(inputs, out, key=PRNGKey(0)))
    state = opt.update(cross_entropy.apply, state, inputs, out)
    opt.update(cross_entropy.apply, state, inputs, out, jit=True)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:39,代码来源:test_examples.py

示例13: _extract_signal_patches

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def _extract_signal_patches(signal, window_length, hop=1, data_format="NCW"):
    assert not hasattr(window_length, "__len__")
    assert signal.ndim == 3
    if data_format == "NCW":
        N = (signal.shape[2] - window_length) // hop + 1
        indices = jnp.arange(window_length) + jnp.expand_dims(jnp.arange(N) * hop, 1)
        indices = jnp.reshape(indices, [1, 1, N * window_length])
        patches = jnp.take_along_axis(signal, indices, 2)
        return jnp.reshape(patches, signal.shape[:2] + (N, window_length))
    else:
        error 
开发者ID:SymJAX,项目名称:SymJAX,代码行数:13,代码来源:ops_special.py

示例14: _extract_image_patches

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def _extract_image_patches(
    image, window_shape, hop=1, data_format="NCHW", mode="valid"
):
    if mode == "same":
        p1 = window_shape[0] - 1
        p2 = window_shape[1] - 1
        image = jnp.pad(
            image, [(0, 0), (0, 0), (p1 // 2, p1 - p1 // 2), (p2 // 2, p2 - p2 // 2)]
        )
    if not hasattr(hop, "__len__"):
        hop = (hop, hop)
    if data_format == "NCHW":

        # compute the number of windows in both dimensions
        N = (
            (image.shape[2] - window_shape[0]) // hop[0] + 1,
            (image.shape[3] - window_shape[1]) // hop[1] + 1,
        )

        # compute the base indices of a 2d patch
        patch = jnp.arange(numpy.prod(window_shape)).reshape(window_shape)
        offset = jnp.expand_dims(jnp.arange(window_shape[0]), 1)
        patch_indices = patch + offset * (image.shape[3] - window_shape[1])

        # create all the shifted versions of it
        ver_shifts = jnp.reshape(
            jnp.arange(N[0]) * hop[0] * image.shape[3], (-1, 1, 1, 1)
        )
        hor_shifts = jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
        all_cols = patch_indices + jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
        indices = patch_indices + ver_shifts + hor_shifts

        # now extract shape (1, 1, H'W'a'b')
        flat_indices = jnp.reshape(indices, [1, 1, -1])
        # shape is now (N, C, W*H)
        flat_image = jnp.reshape(image, (image.shape[0], image.shape[1], -1))
        # shape is now (N, C)
        patches = jnp.take_along_axis(flat_image, flat_indices, 2)
        return jnp.reshape(patches, image.shape[:2] + N + tuple(window_shape))
    else:
        error 
开发者ID:SymJAX,项目名称:SymJAX,代码行数:43,代码来源:ops_special.py

示例15: reshape

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import reshape [as 别名]
def reshape(self, x, shape):
        """Reshape ``x`` to ``shape``, allowing ``None`` for an inferred axis.

        Args:
            x: Array-like to reshape.
            shape: A single int, or an iterable whose entries are ints or
                ``None``; each ``None`` becomes ``-1`` so that ``np.reshape``
                infers that dimension.

        Returns:
            ``np.reshape(x, shape)`` with every dimension coerced to ``int``.
        """
        # A bare int is a valid shape; wrap it so it can be iterated below.
        # The guard must test ``shape`` (not ``x``), otherwise int shapes crash.
        if isinstance(shape, int):
            shape = (shape,)
        shape = tuple(-1 if s is None else s for s in shape)
        return np.reshape(x, tuple(map(int, shape)))
开发者ID:sharadmv,项目名称:deepx,代码行数:7,代码来源:jax.py


注:本文中的jax.numpy.reshape方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。