当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.transpose方法代码示例

本文整理汇总了Python中jax.numpy.transpose方法的典型用法代码示例。如果您正苦于以下问题:Python jax.numpy.transpose方法的具体用法?Python jax.numpy.transpose怎么用?Python jax.numpy.transpose使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。


在下文中一共展示了jax.numpy.transpose方法的11个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: random_tensors

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import transpose [as 别名]
def random_tensors(request):
  """Builds random Hamiltonian, state, isometry and disentangler tensors.

  Args:
    request: fixture request object; `request.param` supplies the bond
      dimension `D`.

  Returns:
    A tuple `(h, s, iso, dis)`, each cast with `astype(np.complex128)`:
      h: a Hermitian (D**3, D**3) matrix reshaped to rank 6.
      s: a positive semi-definite, unit-trace (D**3, D**3) matrix reshaped
        to rank 6.
      iso: rank-3 tensor sliced from the reshaped `vh` factor of the SVD of
        a random (D**2, D**2) matrix.
      dis: rank-4 reshape of the `u` factor of the same SVD.
  """
  D = request.param
  # A JAX PRNG key is deterministic: reusing one key for several draws
  # returns the same samples, which would make `h` and `s` functions of one
  # identical matrix. Split the root key so each draw is independent.
  key_h, key_s, key_a = jax.random.split(jax.random.PRNGKey(0), 3)

  # Hermitian part of a random square matrix.
  h = jax.random.normal(key_h, shape=[D**3] * 2)
  h = 0.5 * (h + np.conj(np.transpose(h)))
  h = np.reshape(h, [D] * 6)

  # s @ s^dagger is positive semi-definite; normalise it to unit trace.
  s = jax.random.normal(key_s, shape=[D**3] * 2)
  s = s @ np.conj(np.transpose(s))
  s /= np.trace(s)
  s = np.reshape(s, [D] * 6)

  # Singular vectors of a random matrix provide the unitary factors.
  a = jax.random.normal(key_a, shape=[D**2] * 2)
  u, _, vh = np.linalg.svd(a)
  dis = np.reshape(u, [D] * 4)
  iso = np.reshape(vh, [D] * 4)[:, :, :, 0]

  return tuple(x.astype(np.complex128) for x in (h, s, iso, dis))
开发者ID:google,项目名称:TensorNetwork,代码行数:21,代码来源:simple_mera_test.py

示例2: _jax_scan

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import transpose [as 别名]
def _jax_scan(f, xs, init_value, axis=0, remat=False):
  """Runs `f` sequentially over the slices of `xs` taken along `axis`.

  The behaviour corresponds to the following pseudo-python:

  def scan(f, xs, init_value, axis):
    xs  = [xs[..., i, ...] for i in range(xs.shape[axis])]
    carry = init_value
    ys = []
    for x in xs:
      y, carry = f(x, carry)
      ys.append(y)
    return np.stack(ys, axis), carry

  Args:
    f: function (x, carry) -> (y, new_carry)
    xs: tensor (or nested structure of tensors); each x is a slice of xs
      along `axis`
    init_value: tensor, the carry passed to the first call of f
    axis: int, the axis of xs along which slices are taken
    remat: if true, f is wrapped in `jax.remat` before scanning

  Returns:
    A pair (ys, last_value): the stacked per-step outputs and the final carry.
  """
  def move_scan_axis(t):
    # Exchange dimension 0 with `axis`; applying it twice restores the layout.
    perm = list(range(len(t.shape)))
    perm[axis], perm[0] = 0, axis
    return jnp.transpose(t, axes=perm)

  def step(carry, x):
    # lax.scan wants (carry, x) -> (carry, y); f uses the opposite ordering.
    y, new_carry = f(x, carry)
    return new_carry, y

  needs_swap = axis != 0
  if needs_swap:
    # lax.scan always iterates over the leading dimension.
    xs = nested_map(move_scan_axis, xs)

  scan_body = jax.remat(step) if remat else step
  last_value, ys = lax.scan(scan_body, init_value, xs)

  if needs_swap:
    ys = nested_map(move_scan_axis, ys)
  return ys, last_value
开发者ID:google,项目名称:trax,代码行数:43,代码来源:jax.py

示例3: test_descend

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import transpose [as 别名]
def test_descend(random_tensors):
  """Checks that `simple_mera.descend` returns a valid rank-6 density matrix.

  Viewed as a (D**3, D**3) matrix, the descended state must have unit trace,
  be Hermitian, and have a non-negative spectrum.

  Args:
    random_tensors: tuple `(h, s, iso, dis)` of input tensors.
  """
  h, s, iso, dis = random_tensors
  s = simple_mera.descend(h, s, iso, dis)
  assert len(s.shape) == 6
  D = s.shape[0]
  smat = np.reshape(s, [D**3] * 2)
  assert np.isclose(np.trace(smat), 1.0)
  assert np.isclose(np.linalg.norm(smat - np.conj(np.transpose(smat))), 0.0)
  spec, _ = np.linalg.eigh(smat)
  # `alltrue` is a deprecated alias that has been removed from NumPy 2.0 and
  # from jax.numpy; `all` is the supported spelling with identical semantics.
  assert np.all(spec >= 0.0)
开发者ID:google,项目名称:TensorNetwork,代码行数:12,代码来源:simple_mera_test.py

示例4: test_ascend

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import transpose [as 别名]
def test_ascend(random_tensors):
  """Checks that `simple_mera.ascend` yields a Hermitian rank-6 operator.

  Args:
    random_tensors: tuple `(h, s, iso, dis)` of input tensors.
  """
  ham, state, isometry, disentangler = random_tensors
  ascended = simple_mera.ascend(ham, state, isometry, disentangler)
  assert len(ascended.shape) == 6

  # Flatten to a square matrix and compare it with its conjugate transpose.
  dim = ascended.shape[0]
  as_matrix = np.reshape(ascended, [dim**3] * 2)
  hermiticity_defect = as_matrix - np.conj(np.transpose(as_matrix))
  assert np.isclose(np.linalg.norm(hermiticity_defect), 0.0)
开发者ID:google,项目名称:TensorNetwork,代码行数:10,代码来源:simple_mera_test.py

示例5: wavelet_tensors

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import transpose [as 别名]
def wavelet_tensors(request):
  """Constructs the Hamiltonian and the MERA tensors of the D=2 wavelet MERA.

  Reference: Evenbly & White, Phys. Rev. Lett. 116, 140403 (2016).

  Args:
    request: fixture request object (not used by the body).

  Returns:
    Tuple `(h, w, u)`, each cast with `astype(np.complex128)`.
  """
  D = 2
  h = simple_mera.ham_ising()

  # Identity and Pauli matrices.
  E = np.array([[1, 0], [0, 1]])
  X = np.array([[0, 1], [1, 0]])
  Y = np.array([[0, -1j], [1j, 0]])
  Z = np.array([[1, 0], [0, -1]])

  # Two-site operator products shared by both matrices below.
  ee = np.kron(E, E)
  zz = np.kron(Z, Z)
  xy = np.kron(X, Y)
  yx = np.kron(Y, X)
  sqrt2 = np.sqrt(2)
  sqrt3 = np.sqrt(3)

  wmat_un = np.real((sqrt3 + sqrt2) / 4 * ee +
                    (sqrt3 - sqrt2) / 4 * zz +
                    1.j * (1 + sqrt2) / 4 * xy +
                    1.j * (1 - sqrt2) / 4 * yx)

  umat = np.real((sqrt3 + 2) / 4 * ee +
                 (sqrt3 - 2) / 4 * zz +
                 1.j / 4 * xy +
                 1.j / 4 * yx)

  # Fix the second index of the reshaped matrix to 0 to get a rank-3 tensor,
  # then reorder the axes of both tensors.
  w = np.transpose(np.reshape(wmat_un, (D, D, D, D))[:, 0, :, :], [1, 2, 0])
  u = np.transpose(np.reshape(umat, (D, D, D, D)), [2, 3, 0, 1])

  return tuple(x.astype(np.complex128) for x in (h, w, u))
开发者ID:google,项目名称:TensorNetwork,代码行数:31,代码来源:simple_mera_test.py

示例6: compute_singleton_mean_variance

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import transpose [as 别名]
def compute_singleton_mean_variance(X, Y, dimension, msq, lam, eta1, xisq, c, var_obs):
    """Compute a mean and variance from two probe points along one input dimension.

    Two probe points are built that are zero everywhere except for +1 / -1 in
    column `dimension`. The Gaussian conditional mean and covariance at those
    probes, given the data `(X, Y)` under `kernel`, are contracted with the
    weights `[0.5, -0.5]`.

    Args:
        X: 2-D array; `X.shape` is read as `(N, P)`.
        Y: array multiplied on the left by the (N, N) inverse kernel matrix.
        dimension: column index of `X` that is probed.
        msq, lam, eta1, xisq, c: kernel hyperparameters; `eta1`, the derived
            `eta2` and `c` are forwarded to `kernel`, while `msq` and `lam`
            enter through the input scaling `kappa`.
        var_obs: value added to the diagonal of the kernel matrix of `X`.

    Returns:
        Tuple `(mu, var)` of the contracted mean and variance.
    """
    P, N = X.shape[1], X.shape[0]

    # `jax.ops.index_update` has been removed from JAX; the `.at[...].set(...)`
    # syntax is the supported functional (out-of-place) update.
    probe = jnp.zeros((2, P)).at[:, dimension].set(jnp.array([1.0, -1.0]))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    # Inputs and probes are rescaled by kappa before entering the kernel.
    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + var_obs * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    # Half the difference between the +1 and -1 probe values.
    vec = jnp.array([0.50, -0.50])
    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.dot(mu, vec)

    # Conditional covariance of the probes, then the quadratic form vec^T C vec.
    var = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
    var = jnp.matmul(var, vec)
    var = jnp.dot(var, vec)

    return mu, var


# Compute the mean and variance of coefficient theta_ij for a MCMC sample of the
# kernel hyperparameters (eta1, xisq, ...). Compare to theorem 5.1 in reference [1]. 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:32,代码来源:sparse_regression.py

示例7: compute_pairwise_mean_variance

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import transpose [as 别名]
def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, var_obs):
    """Compute a mean and variance from four probe points spanning two input dimensions.

    Four probe points are built that are zero everywhere except in columns
    `dim1` and `dim2`, where they take the sign combinations (+,+), (+,-),
    (-,+), (-,-). The Gaussian conditional mean and covariance at those probes,
    given the data `(X, Y)` under `kernel`, are contracted with the weights
    `[0.25, -0.25, -0.25, 0.25]`.

    Args:
        X: 2-D array; `X.shape` is read as `(N, P)`.
        Y: array multiplied on the left by the (N, N) inverse kernel matrix.
        dim1, dim2: column indices of `X` that are probed jointly.
        msq, lam, eta1, xisq, c: kernel hyperparameters; `eta1`, the derived
            `eta2` and `c` are forwarded to `kernel`, while `msq` and `lam`
            enter through the input scaling `kappa`.
        var_obs: value added to the diagonal of the kernel matrix of `X`.

    Returns:
        Tuple `(mu, var)` of the contracted mean and variance.
    """
    P, N = X.shape[1], X.shape[0]

    # `jax.ops.index_update` has been removed from JAX; the `.at[...].set(...)`
    # syntax is the supported functional (out-of-place) update.
    probe = jnp.zeros((4, P))
    probe = probe.at[:, dim1].set(jnp.array([1.0, 1.0, -1.0, -1.0]))
    probe = probe.at[:, dim2].set(jnp.array([1.0, -1.0, 1.0, -1.0]))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    # Inputs and probes are rescaled by kappa before entering the kernel.
    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + var_obs * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    # Mixed second difference over the four sign combinations.
    vec = jnp.array([0.25, -0.25, -0.25, 0.25])
    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.dot(mu, vec)

    # Conditional covariance of the probes, then the quadratic form vec^T C vec.
    var = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
    var = jnp.matmul(var, vec)
    var = jnp.dot(var, vec)

    return mu, var


# Sample coefficients theta from the posterior for a given MCMC sample.
# The first P returned values are {theta_1, theta_2, ...., theta_P}, while
# the remaining values are {theta_ij} for i,j in the list `active_dims`,
# sorted so that i < j. 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:35,代码来源:sparse_regression.py

示例8: predict

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import transpose [as 别名]
def predict(rng_key, X, Y, X_test, var, length, noise):
    """Compute the predictive mean and one noisy sample at `X_test`.

    Args:
        rng_key: JAX PRNG key used to draw the standard-normal noise.
        X: training inputs passed to `kernel`.
        Y: training targets, multiplied by the inverse train kernel matrix.
        X_test: test inputs; `X_test.shape[:1]` sets the number of noise draws.
        var, length, noise: kernel hyperparameters forwarded to `kernel`.

    Returns:
        Tuple `(mean, mean + sigma_noise)`: the predictive mean, and the mean
        perturbed by independent per-point Gaussian noise scaled by the square
        root of the diagonal of the predictive covariance.
    """
    # compute kernels between train and test data, etc.
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    K_xx_inv = jnp.linalg.inv(k_XX)
    K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
    # Clamp the diagonal at zero so round-off cannot produce a negative
    # variance under the square root. The lower bound is passed positionally
    # because the `a_min=` keyword of `jnp.clip` is deprecated in favour of
    # `min=`; the positional form is accepted by both signatures.
    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), 0.)) * jax.random.normal(rng_key, X_test.shape[:1])
    mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))
    # we return both the mean function and a sample from the posterior predictive for the
    # given set of hyperparameters
    return mean, mean + sigma_noise


# create artificial regression dataset 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:17,代码来源:gp.py

示例9: _triangular_solve

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import transpose [as 别名]
def _triangular_solve(x, y, upper=False, transpose=False):
    """Batched triangular solve that avoids broadcasting `y` over `x`'s batch.

    Calls `solve_triangular(y, x, trans=int(transpose), lower=not upper)`, so
    `y` is the triangular coefficient matrix and `x` the right-hand side.

    Args:
        x: array with at least 2 dims; trailing shape `(n, m)`.
        y: array with at least 2 dims; trailing shape must be `(n, n)`.
        upper: if true `y` is treated as upper triangular (passed to
            `solve_triangular` as `lower=not upper`).
        transpose: passed to `solve_triangular` as `trans=int(transpose)`.

    Returns:
        The solution with shape `batch_shape + (n, m)`, where `batch_shape` is
        the broadcast of the two batch shapes.
    """
    assert np.ndim(x) >= 2 and np.ndim(y) >= 2
    n, m = x.shape[-2:]
    assert y.shape[-2:] == (n, n)
    # NB: JAX requires x and y have the same batch_shape
    batch_shape = lax.broadcast_shapes(x.shape[:-2], y.shape[:-2])
    x = np.broadcast_to(x, batch_shape + (n, m))
    # Fast path: y already has the full batch shape, no rearrangement needed.
    if y.shape[:-2] == batch_shape:
        return solve_triangular(y, x, trans=int(transpose), lower=not upper)

    # The following procedure handles the case: y.shape = (i, 1, n, n), x.shape = (..., i, j, n, m)
    # because we don't want to broadcast y to the shape (i, j, n, n).
    # We are going to make x have shape (..., 1, j, m, i, 1, n) to apply batched triangular_solve
    dx = x.ndim
    prepend_ndim = dx - y.ndim  # ndim of ... part
    # Reshape x with the shape (..., 1, i, j, 1, n, m)
    x_new_shape = batch_shape[:prepend_ndim]
    # Split every batch dim of x into (quotient, size of the matching y batch dim).
    for (sy, sx) in zip(y.shape[:-2], batch_shape[prepend_ndim:]):
        x_new_shape += (sx // sy, sy)
    x_new_shape += (n, m,)
    x = np.reshape(x, x_new_shape)
    # Permute x to make it have shape (..., 1, j, m, i, 1, n):
    # leading dims, then the quotient dims, then m, then the y-sized dims, then n.
    batch_ndim = x.ndim - 2
    permute_dims = (tuple(range(prepend_ndim))
                    + tuple(range(prepend_ndim, batch_ndim, 2))
                    + (batch_ndim + 1,)
                    + tuple(range(prepend_ndim + 1, batch_ndim, 2))
                    + (batch_ndim,))
    x = np.transpose(x, permute_dims)
    # Remembered so the flattened solution can be unflattened afterwards.
    x_permute_shape = x.shape

    # reshape to (-1, i, 1, n)
    x = np.reshape(x, (-1,) + y.shape[:-1])
    # permute to (i, 1, n, -1)
    x = np.moveaxis(x, 0, -1)

    sol = solve_triangular(y, x, trans=int(transpose), lower=not upper)  # shape: (i, 1, n, -1)
    sol = np.moveaxis(sol, -1, 0)  # shape: (-1, i, 1, n)
    sol = np.reshape(sol, x_permute_shape)  # shape: (..., 1, j, m, i, 1, n)

    # now we permute back to x_new_shape = (..., 1, i, j, 1, n, m)
    permute_inv_dims = tuple(range(prepend_ndim))
    # Interleave each quotient dim with its y-sized partner again.
    for i in range(y.ndim - 2):
        permute_inv_dims += (prepend_ndim + i, dx + i - 1)
    # Finally put n (currently last) before m.
    permute_inv_dims += (sol.ndim - 1, prepend_ndim + y.ndim - 2)
    sol = np.transpose(sol, permute_inv_dims)
    return sol.reshape(batch_shape + (n, m)) 
开发者ID:pyro-ppl,项目名称:funsor,代码行数:49,代码来源:ops.py

示例10: sample_theta_space

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import transpose [as 别名]
def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, var_obs):
    """Draw one joint Gaussian sample of singleton and pairwise coefficients.

    The first `P` entries of the result correspond to the individual input
    dimensions `0..P-1`; the remaining entries correspond to the pairs
    `(dim1, dim2)` from `active_dims` with `dim1 < dim2`, in iteration order.

    Args:
        X: 2-D array; `X.shape` is read as `(N, P)`.
        Y: array multiplied on the left by the (N, N) inverse kernel matrix.
        active_dims: sequence of column indices whose pairs are included.
        msq, lam, eta1, xisq, c: kernel hyperparameters; `eta1`, the derived
            `eta2` and `c` are forwarded to `kernel`, while `msq` and `lam`
            enter through the input scaling `kappa`.
        var_obs: value added to the diagonal of the kernel matrix of `X`.

    Returns:
        Array of `P + M * (M - 1) // 2` sampled coefficients, with
        `M = len(active_dims)`.
    """
    P, N, M = X.shape[1], X.shape[0], len(active_dims)
    # the total number of coefficients we return
    num_coefficients = P + M * (M - 1) // 2

    # Two probe rows per singleton and four per pair; `vec` holds the weights
    # that contract the probe rows into one coefficient each.
    probe = jnp.zeros((2 * P + 2 * M * (M - 1), P))
    vec = jnp.zeros((num_coefficients, 2 * P + 2 * M * (M - 1)))
    start1 = 0  # next free probe row
    start2 = 0  # next free coefficient row

    # `jax.ops.index_update` has been removed from JAX; the `.at[...].set(...)`
    # syntax is the supported functional (out-of-place) update.
    for dim in range(P):
        probe = probe.at[start1:start1 + 2, dim].set(jnp.array([1.0, -1.0]))
        vec = vec.at[start2, start1:start1 + 2].set(jnp.array([0.5, -0.5]))
        start1 += 2
        start2 += 1

    for dim1 in active_dims:
        for dim2 in active_dims:
            # Visit each unordered pair once, with dim1 < dim2.
            if dim1 >= dim2:
                continue
            probe = probe.at[start1:start1 + 4, dim1].set(jnp.array([1.0, 1.0, -1.0, -1.0]))
            probe = probe.at[start1:start1 + 4, dim2].set(jnp.array([1.0, -1.0, 1.0, -1.0]))
            vec = vec.at[start2, start1:start1 + 4].set(jnp.array([0.25, -0.25, -0.25, 0.25]))
            start1 += 4
            start2 += 1

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    # Inputs and probes are rescaled by kappa before entering the kernel.
    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + var_obs * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.sum(mu * vec, axis=-1)

    covar = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
    covar = jnp.matmul(vec, jnp.matmul(covar, jnp.transpose(vec)))
    L = jnp.linalg.cholesky(covar)

    # sample from N(mu, covar)
    # NOTE(review): `np` is presumably NumPy here, so the draw uses NumPy's
    # global random state rather than a JAX PRNG key - confirm against imports.
    sample = mu + jnp.matmul(L, np.random.randn(num_coefficients))

    return sample


# Helper function for doing HMC inference 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:56,代码来源:sparse_regression.py

示例11: _batch_mahalanobis

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import transpose [as 别名]
def _batch_mahalanobis(bL, bx):
    """Sum of squares of the solution `z` of `bL z = bx`, batched.

    `bL` is used as a lower-triangular factor. When it is a Cholesky factor of
    a covariance matrix the result is the squared Mahalanobis distance of `bx`.
    The general path avoids broadcasting `bL` up to the batch shape of `bx`.

    Args:
        bL: array of lower-triangular matrices, e.g. shape (i, 1, n, n).
        bx: array of vectors, e.g. shape (..., i, j, n).

    Returns:
        Array with the shape of `bx` minus its last dimension.
    """
    # Shapes already line up: solve directly on a trailing column vector.
    if bL.shape[:-1] == bx.shape:
        whitened = solve_triangular(bL, bx[..., None], lower=True).squeeze(-1)
        return jnp.sum(jnp.square(whitened), -1)

    # General case, e.g. bL.shape = (i, 1, n, n) and bx.shape = (..., i, j, n):
    # rearrange bx to (..., 1, j, i, 1, n) so one batched solve covers it.
    n_sample_dims = bx.ndim - bL.ndim + 1  # number of leading "..." dims
    result_shape = jnp.shape(bx)[:-1]

    # Split each batch dim of bx into (quotient, bL's size): (..., 1, i, j, 1, n)
    split_shape = result_shape[:n_sample_dims]
    for size_L, size_x in zip(bL.shape[:-2], result_shape[n_sample_dims:]):
        split_shape += (size_x // size_L, size_L)
    split_shape += (-1,)
    bx_split = jnp.reshape(bx, split_shape)

    # Gather the quotient dims first, then the bL-sized dims: (..., 1, j, i, 1, n)
    last_axis = bx_split.ndim - 1
    sample_axes = tuple(range(n_sample_dims))
    quotient_axes = tuple(range(n_sample_dims, last_axis, 2))
    matched_axes = tuple(range(n_sample_dims + 1, last_axis, 2))
    bx_perm = jnp.transpose(
        bx_split, sample_axes + quotient_axes + matched_axes + (last_axis,))

    # Flatten everything in front of bL's batch shape and move it to the end,
    # giving a right-hand side of shape (i, 1, n, -1).
    rhs = jnp.moveaxis(jnp.reshape(bx_perm, (-1,) + bL.shape[:-1]), 0, -1)
    solved = solve_triangular(bL, rhs, lower=True)  # shape: (i, 1, n, -1)
    dist = jnp.sum(solved ** 2, axis=-2)  # shape: (i, 1, -1)

    # Undo the rearrangement: (-1, i, 1) -> (..., 1, j, i, 1) -> (..., 1, i, j, 1)
    dist = jnp.reshape(jnp.moveaxis(dist, -1, 0), bx_perm.shape[:-1])
    inverse_axes = sample_axes + tuple(
        axis
        for k in range(bL.ndim - 2)
        for axis in (n_sample_dims + k, len(result_shape) + k))
    dist = jnp.transpose(dist, inverse_axes)
    return jnp.reshape(dist, result_shape)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:44,代码来源:continuous.py


注:本文中的jax.numpy.transpose方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。