This article collects typical usage examples of Python's jax.numpy.transpose method. If you are wondering what jax.numpy.transpose does and how to use it in practice, the curated code examples below may help; you can also explore the containing module, jax.numpy, for related functionality.
The following presents 11 code examples of numpy.transpose, drawn from open-source projects.
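Before the project examples, here is a minimal sketch of jax.numpy.transpose itself (an illustrative snippet of my own, not taken from any of the projects below): called without axes it reverses all dimensions; with an explicit permutation it reorders them.

import jax.numpy as jnp

x = jnp.ones((2, 3, 4))
print(jnp.transpose(x).shape)             # (4, 3, 2) -- all axes reversed
print(jnp.transpose(x, (1, 0, 2)).shape)  # (3, 2, 4) -- explicit permutation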
Example 1: random_tensors
# Required import: from jax import numpy [as np]
# Alternatively: from jax.numpy import transpose
def random_tensors(request):
    D = request.param
    key = jax.random.PRNGKey(0)

    h = jax.random.normal(key, shape=[D**3] * 2)
    h = 0.5 * (h + np.conj(np.transpose(h)))
    h = np.reshape(h, [D] * 6)

    s = jax.random.normal(key, shape=[D**3] * 2)
    s = s @ np.conj(np.transpose(s))
    s /= np.trace(s)
    s = np.reshape(s, [D] * 6)

    a = jax.random.normal(key, shape=[D**2] * 2)
    u, _, vh = np.linalg.svd(a)
    dis = np.reshape(u, [D] * 4)
    iso = np.reshape(vh, [D] * 4)[:, :, :, 0]

    return tuple(x.astype(np.complex128) for x in (h, s, iso, dis))
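The snippet above reads like a pytest fixture: request.param supplies the bond dimension D. The decorator is not shown in the excerpt; a plausible declaration (an assumption on my part, including the parameter values) would be:

import pytest

@pytest.fixture(params=[2, 3])   # hypothetical values for the bond dimension D
def random_tensors(request):
    ...  # body as in Example 1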
Example 2: _jax_scan
# Required import: from jax import numpy [as jnp]
# Alternatively: from jax.numpy import transpose
def _jax_scan(f, xs, init_value, axis=0, remat=False):
    """Scans the f over the given axis of xs.

    In pseudo-python, the scan function would look as follows:

    def scan(f, xs, init_value, axis):
        xs = [xs[..., i, ...] for i in range(xs.shape[axis])]
        cur_value = init_value
        ys = []
        for x in xs:
            y, cur_value = f(x, cur_value)
            ys.append(y)
        return np.stack(ys, axis), cur_value

    Args:
        f: function (x, carry) -> (y, new_carry)
        xs: tensor, x will be xs slices on axis
        init_value: tensor, initial value of the carry-over
        axis: int, the axis on which to slice xs
        remat: whether to re-materialize f

    Returns:
        A pair (ys, last_value) as described above.
    """
    def swapaxes(x):
        transposed_axes = list(range(len(x.shape)))
        transposed_axes[axis] = 0
        transposed_axes[0] = axis
        return jnp.transpose(x, axes=transposed_axes)

    if axis != 0:
        xs = nested_map(swapaxes, xs)

    def transposed_f(c, x):
        y, d = f(x, c)
        return d, y

    if remat:
        last_value, ys = lax.scan(jax.remat(transposed_f), init_value, xs)
    else:
        last_value, ys = lax.scan(transposed_f, init_value, xs)

    if axis != 0:
        ys = nested_map(swapaxes, ys)
    return ys, last_value
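A small usage sketch of _jax_scan (my own example, not from the source project): computing running sums along axis 1. It assumes jnp = jax.numpy and lax = jax.lax as in the snippet; a plain tree_map stands in for the nested_map helper.

import jax
import jax.numpy as jnp
from jax import lax

def nested_map(f, obj):  # simplified stand-in for the source's tree-mapping helper
    return jax.tree_util.tree_map(f, obj)

def running_sum(x, carry):
    new_carry = carry + x
    return new_carry, new_carry   # (y, new_carry)

xs = jnp.arange(12.0).reshape(3, 4)
ys, last = _jax_scan(running_sum, xs, jnp.zeros(3), axis=1)
print(ys.shape)  # (3, 4): cumulative sums along axis 1
print(last)      # per-row totals: [ 6. 22. 38.]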
Example 3: test_descend
# Required import: from jax import numpy [as np]
# Alternatively: from jax.numpy import transpose
def test_descend(random_tensors):
    h, s, iso, dis = random_tensors
    s = simple_mera.descend(h, s, iso, dis)
    assert len(s.shape) == 6
    D = s.shape[0]
    smat = np.reshape(s, [D**3] * 2)
    assert np.isclose(np.trace(smat), 1.0)
    assert np.isclose(np.linalg.norm(smat - np.conj(np.transpose(smat))), 0.0)
    spec, _ = np.linalg.eigh(smat)
    assert np.alltrue(spec >= 0.0)
Example 4: test_ascend
# Required import: from jax import numpy [as np]
# Alternatively: from jax.numpy import transpose
def test_ascend(random_tensors):
    h, s, iso, dis = random_tensors
    h = simple_mera.ascend(h, s, iso, dis)
    assert len(h.shape) == 6
    D = h.shape[0]
    hmat = np.reshape(h, [D**3] * 2)
    norm = np.linalg.norm(hmat - np.conj(np.transpose(hmat)))
    assert np.isclose(norm, 0.0)
Example 5: wavelet_tensors
# Required import: from jax import numpy [as np]
# Alternatively: from jax.numpy import transpose
def wavelet_tensors(request):
    """Returns the Hamiltonian and MERA tensors for the D=2 wavelet MERA.

    From Evenbly & White, Phys. Rev. Lett. 116, 140403 (2016).
    """
    D = 2
    h = simple_mera.ham_ising()

    E = np.array([[1, 0], [0, 1]])
    X = np.array([[0, 1], [1, 0]])
    Y = np.array([[0, -1j], [1j, 0]])
    Z = np.array([[1, 0], [0, -1]])

    wmat_un = np.real((np.sqrt(3) + np.sqrt(2)) / 4 * np.kron(E, E) +
                      (np.sqrt(3) - np.sqrt(2)) / 4 * np.kron(Z, Z) + 1.j *
                      (1 + np.sqrt(2)) / 4 * np.kron(X, Y) + 1.j *
                      (1 - np.sqrt(2)) / 4 * np.kron(Y, X))

    umat = np.real((np.sqrt(3) + 2) / 4 * np.kron(E, E) +
                   (np.sqrt(3) - 2) / 4 * np.kron(Z, Z) +
                   1.j / 4 * np.kron(X, Y) + 1.j / 4 * np.kron(Y, X))

    w = np.reshape(wmat_un, (D, D, D, D))[:, 0, :, :]
    u = np.reshape(umat, (D, D, D, D))

    w = np.transpose(w, [1, 2, 0])
    u = np.transpose(u, [2, 3, 0, 1])

    return tuple(x.astype(np.complex128) for x in (h, w, u))
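A side note on the two transpose calls above (an illustrative example of my own, not from the source): with a permutation list, entry k of the result's shape is the input's shape at position perm[k], so [1, 2, 0] cyclically permutes the indices of a rank-3 tensor.

import jax.numpy as np

w = np.zeros((2, 3, 4))
print(np.transpose(w, [1, 2, 0]).shape)  # (3, 4, 2)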
Example 6: compute_singleton_mean_variance
# Required import: from jax import numpy [as jnp]
# Alternatively: from jax.numpy import transpose
def compute_singleton_mean_variance(X, Y, dimension, msq, lam, eta1, xisq, c, var_obs):
    P, N = X.shape[1], X.shape[0]

    probe = jnp.zeros((2, P))
    # NB: jax.ops.index_update is from older JAX releases; current JAX spells this
    # probe.at[:, dimension].set(jnp.array([1.0, -1.0])).
    probe = jax.ops.index_update(probe, jax.ops.index[:, dimension], jnp.array([1.0, -1.0]))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + var_obs * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    vec = jnp.array([0.50, -0.50])
    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.dot(mu, vec)

    var = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
    var = jnp.matmul(var, vec)
    var = jnp.dot(var, vec)

    return mu, var
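The kernel algebra in this snippet is the usual Gaussian-process posterior, contracted against the probe weights; in my transcription (assuming standard GP conventions, with $v = (0.5, -0.5)^\top$ the probe-weight vector):

$$\mu = v^\top K_{pX} K_{XX}^{-1} y, \qquad \mathrm{var} = v^\top \bigl( K_{pp} - K_{pX} K_{XX}^{-1} K_{pX}^\top \bigr) v,$$

where $K_{XX}$, $K_{pX}$ and $K_{pp}$ are the train/train, probe/train and probe/probe kernel matrices evaluated on the rescaled inputs kX and kprobe.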
# Compute the mean and variance of coefficient theta_ij for a MCMC sample of the
# kernel hyperparameters (eta1, xisq, ...). Compare to theorem 5.1 in reference [1].
Example 7: compute_pairwise_mean_variance
# Required import: from jax import numpy [as jnp]
# Alternatively: from jax.numpy import transpose
def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, var_obs):
    P, N = X.shape[1], X.shape[0]

    probe = jnp.zeros((4, P))
    probe = jax.ops.index_update(probe, jax.ops.index[:, dim1], jnp.array([1.0, 1.0, -1.0, -1.0]))
    probe = jax.ops.index_update(probe, jax.ops.index[:, dim2], jnp.array([1.0, -1.0, 1.0, -1.0]))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + var_obs * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    vec = jnp.array([0.25, -0.25, -0.25, 0.25])
    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.dot(mu, vec)

    var = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
    var = jnp.matmul(var, vec)
    var = jnp.dot(var, vec)

    return mu, var
# Sample coefficients theta from the posterior for a given MCMC sample.
# The first P returned values are {theta_1, theta_2, ...., theta_P}, while
# the remaining values are {theta_ij} for i,j in the list `active_dims`,
# sorted so that i < j.
Example 8: predict
# Required import: from jax import numpy [as jnp]
# Alternatively: from jax.numpy import transpose
def predict(rng_key, X, Y, X_test, var, length, noise):
    # compute kernels between train and test data, etc.
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    K_xx_inv = jnp.linalg.inv(k_XX)
    K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.)) * jax.random.normal(rng_key, X_test.shape[:1])
    mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))
    # we return both the mean function and a sample from the posterior predictive for the
    # given set of hyperparameters
    return mean, mean + sigma_noise
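A usage sketch of predict (my own example, not from the source project): the kernel below is a squared-exponential stand-in with the same call signature as used above, so the kernel form, the jitter value, and the hyperparameter choices are assumptions.

import jax
import jax.numpy as jnp

def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    # squared-exponential kernel on 1-D inputs
    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k

rng_key = jax.random.PRNGKey(0)
X = jnp.linspace(-1.0, 1.0, 20)
Y = jnp.sin(3.0 * X)
X_test = jnp.linspace(-1.2, 1.2, 30)
mean, sample = predict(rng_key, X, Y, X_test, var=1.0, length=0.3, noise=0.05)
print(mean.shape, sample.shape)  # (30,) (30,)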
# create artificial regression dataset
Example 9: _triangular_solve
# Required import: from jax import numpy [as np]
# Alternatively: from jax.numpy import transpose
def _triangular_solve(x, y, upper=False, transpose=False):
    assert np.ndim(x) >= 2 and np.ndim(y) >= 2
    n, m = x.shape[-2:]
    assert y.shape[-2:] == (n, n)
    # NB: JAX requires x and y have the same batch_shape
    batch_shape = lax.broadcast_shapes(x.shape[:-2], y.shape[:-2])
    x = np.broadcast_to(x, batch_shape + (n, m))
    if y.shape[:-2] == batch_shape:
        return solve_triangular(y, x, trans=int(transpose), lower=not upper)

    # The following procedure handles the case: y.shape = (i, 1, n, n), x.shape = (..., i, j, n, m)
    # because we don't want to broadcast y to the shape (i, j, n, n).
    # We are going to make x have shape (..., 1, j, i, 1, n) to apply batched triangular_solve
    dx = x.ndim
    prepend_ndim = dx - y.ndim  # ndim of ... part
    # Reshape x with the shape (..., 1, i, j, 1, n, m)
    x_new_shape = batch_shape[:prepend_ndim]
    for (sy, sx) in zip(y.shape[:-2], batch_shape[prepend_ndim:]):
        x_new_shape += (sx // sy, sy)
    x_new_shape += (n, m,)
    x = np.reshape(x, x_new_shape)
    # Permute x to make it have shape (..., 1, j, m, i, 1, n)
    batch_ndim = x.ndim - 2
    permute_dims = (tuple(range(prepend_ndim))
                    + tuple(range(prepend_ndim, batch_ndim, 2))
                    + (batch_ndim + 1,)
                    + tuple(range(prepend_ndim + 1, batch_ndim, 2))
                    + (batch_ndim,))
    x = np.transpose(x, permute_dims)
    x_permute_shape = x.shape
    # reshape to (-1, i, 1, n)
    x = np.reshape(x, (-1,) + y.shape[:-1])
    # permute to (i, 1, n, -1)
    x = np.moveaxis(x, 0, -1)
    sol = solve_triangular(y, x, trans=int(transpose), lower=not upper)  # shape: (i, 1, n, -1)
    sol = np.moveaxis(sol, -1, 0)  # shape: (-1, i, 1, n)
    sol = np.reshape(sol, x_permute_shape)  # shape: (..., 1, j, m, i, 1, n)
    # now we permute back to x_new_shape = (..., 1, i, j, 1, n, m)
    permute_inv_dims = tuple(range(prepend_ndim))
    for i in range(y.ndim - 2):
        permute_inv_dims += (prepend_ndim + i, dx + i - 1)
    permute_inv_dims += (sol.ndim - 1, prepend_ndim + y.ndim - 2)
    sol = np.transpose(sol, permute_inv_dims)
    return sol.reshape(batch_shape + (n, m))
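A quick consistency check for _triangular_solve (my own sketch, assuming solve_triangular is imported from jax.scipy.linalg as in typical numpyro-style code): for a lower-triangular y with a broadcastable batch dimension of size 1, the batched result should match an explicit broadcast-and-solve.

import jax
import jax.numpy as np
from jax.scipy.linalg import solve_triangular

key = jax.random.PRNGKey(0)
y = np.tril(jax.random.normal(key, (3, 1, 4, 4))) + 4.0 * np.eye(4)  # (i=3, 1, n, n), well conditioned
x = jax.random.normal(key, (3, 2, 4, 5))                             # (i=3, j=2, n, m)

sol = _triangular_solve(x, y, upper=False)                           # shape (3, 2, 4, 5)
ref = solve_triangular(np.broadcast_to(y, (3, 2, 4, 4)), x, lower=True)
assert np.allclose(sol, ref, atol=1e-5)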
Example 10: sample_theta_space
# Required import: from jax import numpy [as jnp]
# Alternatively: from jax.numpy import transpose
def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, var_obs):
    P, N, M = X.shape[1], X.shape[0], len(active_dims)
    # the total number of coefficients we return
    num_coefficients = P + M * (M - 1) // 2

    probe = jnp.zeros((2 * P + 2 * M * (M - 1), P))
    vec = jnp.zeros((num_coefficients, 2 * P + 2 * M * (M - 1)))
    start1 = 0
    start2 = 0

    for dim in range(P):
        probe = jax.ops.index_update(probe, jax.ops.index[start1:start1 + 2, dim], jnp.array([1.0, -1.0]))
        vec = jax.ops.index_update(vec, jax.ops.index[start2, start1:start1 + 2], jnp.array([0.5, -0.5]))
        start1 += 2
        start2 += 1

    for dim1 in active_dims:
        for dim2 in active_dims:
            if dim1 >= dim2:
                continue
            probe = jax.ops.index_update(probe, jax.ops.index[start1:start1 + 4, dim1],
                                         jnp.array([1.0, 1.0, -1.0, -1.0]))
            probe = jax.ops.index_update(probe, jax.ops.index[start1:start1 + 4, dim2],
                                         jnp.array([1.0, -1.0, 1.0, -1.0]))
            vec = jax.ops.index_update(vec, jax.ops.index[start2, start1:start1 + 4],
                                       jnp.array([0.25, -0.25, -0.25, 0.25]))
            start1 += 4
            start2 += 1

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + var_obs * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.sum(mu * vec, axis=-1)

    covar = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
    covar = jnp.matmul(vec, jnp.matmul(covar, jnp.transpose(vec)))
    L = jnp.linalg.cholesky(covar)

    # sample from N(mu, covar)
    sample = mu + jnp.matmul(L, np.random.randn(num_coefficients))

    return sample
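Note that np.random.randn in the last line is host-side NumPy rather than jax.numpy; a JAX-native equivalent of just that final reparameterized draw (my own rewrite, assuming a PRNG key is available) would be:

import jax
import jax.numpy as jnp

def draw_sample(rng_key, mu, L):
    eps = jax.random.normal(rng_key, mu.shape)
    return mu + jnp.matmul(L, eps)   # a draw from N(mu, L L^T)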
# Helper function for doing HMC inference
Example 11: _batch_mahalanobis
# Required import: from jax import numpy [as jnp]
# Alternatively: from jax.numpy import transpose
def _batch_mahalanobis(bL, bx):
    if bL.shape[:-1] == bx.shape:
        # no need to use the below optimization procedure
        solve_bL_bx = solve_triangular(bL, bx[..., None], lower=True).squeeze(-1)
        return jnp.sum(jnp.square(solve_bL_bx), -1)

    # NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
    # because we don't want to broadcast bL to the shape (i, j, n, n).

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tril_solve
    sample_ndim = bx.ndim - bL.ndim + 1  # size of sample_shape
    out_shape = jnp.shape(bx)[:-1]  # shape of output
    # Reshape bx with the shape (..., 1, i, j, 1, n)
    bx_new_shape = out_shape[:sample_ndim]
    for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (-1,)
    bx = jnp.reshape(bx, bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    permute_dims = (tuple(range(sample_ndim))
                    + tuple(range(sample_ndim, bx.ndim - 1, 2))
                    + tuple(range(sample_ndim + 1, bx.ndim - 1, 2))
                    + (bx.ndim - 1,))
    bx = jnp.transpose(bx, permute_dims)

    # reshape to (-1, i, 1, n)
    xt = jnp.reshape(bx, (-1,) + bL.shape[:-1])
    # permute to (i, 1, n, -1)
    xt = jnp.moveaxis(xt, 0, -1)
    solve_bL_bx = solve_triangular(bL, xt, lower=True)  # shape: (i, 1, n, -1)
    M = jnp.sum(solve_bL_bx ** 2, axis=-2)  # shape: (i, 1, -1)
    # permute back to (-1, i, 1)
    M = jnp.moveaxis(M, -1, 0)
    # reshape back to (..., 1, j, i, 1)
    M = jnp.reshape(M, bx.shape[:-1])
    # permute back to (..., 1, i, j, 1)
    permute_inv_dims = tuple(range(sample_ndim))
    for i in range(bL.ndim - 2):
        permute_inv_dims += (sample_ndim + i, len(out_shape) + i)
    M = jnp.transpose(M, permute_inv_dims)
    return jnp.reshape(M, out_shape)
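A quick consistency check for _batch_mahalanobis (my own sketch, assuming solve_triangular is imported from jax.scipy.linalg): the batched result should agree with an explicit broadcast of bL followed by a direct triangular solve.

import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

key = jax.random.PRNGKey(0)
bL = jnp.tril(jax.random.normal(key, (3, 1, 4, 4))) + 4.0 * jnp.eye(4)  # (i=3, 1, n, n)
bx = jax.random.normal(key, (5, 3, 2, 4))                               # (..., i=3, j=2, n)

m = _batch_mahalanobis(bL, bx)                                          # shape (5, 3, 2)
bL_full = jnp.broadcast_to(bL, (5, 3, 2, 4, 4))
ref = jnp.sum(jnp.square(solve_triangular(bL_full, bx[..., None], lower=True).squeeze(-1)), -1)
assert jnp.allclose(m, ref, atol=1e-4)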