本文整理汇总了Python中jax.numpy.dot方法的典型用法代码示例。如果您正苦于以下问题:Python numpy.dot方法的具体用法?Python numpy.dot怎么用?Python numpy.dot使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy
的用法示例。
在下文中一共展示了numpy.dot方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: GRUCell
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def GRUCell(carry_size, param_init):
    """Build a GRU step function plus an initializer for its carry state.

    Returns a pair ``(gru_cell, carry_init)`` where ``gru_cell(carry, x)``
    performs one gated-recurrent-unit update and ``carry_init(batch_size)``
    produces the initial zero carry.
    """
    @parametrized
    def gru_cell(carry, x):
        # All three kernels share the same (input+carry, carry) shape.
        def param(name):
            return parameter((x.shape[1] + carry_size, carry_size), param_init, name)

        stacked = np.concatenate((x, carry), axis=1)
        update_gate = sigmoid(np.dot(stacked, param('update_kernel')))
        reset_gate = sigmoid(np.dot(stacked, param('reset_kernel')))
        stacked_reset = np.concatenate((x, reset_gate * carry), axis=1)
        candidate = np.tanh(np.dot(stacked_reset, param('compute_kernel')))
        # Interpolate between the candidate state and the previous carry.
        new_carry = update_gate * candidate + (1 - update_gate) * carry
        return new_carry, new_carry

    def carry_init(batch_size):
        return np.zeros((batch_size, carry_size))

    return gru_cell, carry_init
示例2: test_Dense_equivalent
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_Dense_equivalent():
    """A hand-rolled dense layer class must satisfy the same shape test as Dense."""
    class DenseEquivalent:
        def __init__(self, out_dim, kernel_init=glorot_normal(), bias_init=normal()):
            self.out_dim = out_dim
            self.kernel_init = kernel_init
            self.bias_init = bias_init

        def apply(self, params, inputs):
            kernel, bias = params
            return jnp.dot(inputs, kernel) + bias

        def init_parameters(self, example_inputs, key):
            kernel_key, bias_key = random.split(key, 2)
            kernel_shape = (example_inputs.shape[-1], self.out_dim)
            kernel = self.kernel_init(kernel_key, kernel_shape)
            bias = self.bias_init(bias_key, (self.out_dim,))
            return namedtuple('dense', ['kernel', 'bias'])(kernel=kernel, bias=bias)

        def shaped(self, example_inputs):
            return ShapedParametrized(self, example_inputs)

    test_Dense_shape(DenseEquivalent)
示例3: test_Parameter_dense
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_Parameter_dense():
    """Unnamed parameters are auto-named parameter0, parameter1, ..."""
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias
        return dense

    layer = Dense(2)
    x = jnp.zeros((1, 3))
    params = layer.init_parameters(x, key=PRNGKey(0))
    assert params.parameter0.shape == (3, 2)
    assert params.parameter1.shape == (2,)
    y = layer.apply(params, x, jit=True)
    assert y.shape == (1, 2)
示例4: test_mixed_up_execution_order
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_mixed_up_execution_order():
    """Parameters keep their declared names even when defined out of usage order."""
    @parametrized
    def dense(inputs):
        # bias is declared before the kernel it is added to
        bias = parameter((2,), zeros, 'bias')
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel) + bias

    x = jnp.zeros((1, 3))
    params = dense.init_parameters(x, key=PRNGKey(0))
    assert params.bias.shape == (2,)
    assert params.kernel.shape == (3, 2)

    y = dense.apply(params, x)
    assert jnp.array_equal(y, jnp.zeros((1, 2)))
    y_jit = dense.apply(params, x, jit=True)
    assert jnp.array_equal(y_jit, y)
示例5: test_Batched
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_Batched():
    """Batched(...) must agree with manually vmap-ing the unbatched apply."""
    out_dim = 1

    @parametrized
    def unbatched_dense(x):
        kernel = parameter((out_dim, x.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        return jnp.dot(kernel, x) + bias

    batch_size = 4
    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    single_out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    # ones kernel and bias on a ones input of length 2 -> 1 + 1 + 1 = 3
    assert jnp.array([3.]) == single_out

    manual_apply = vmap(unbatched_dense.apply, (None, 0))
    manual_out = manual_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([single_out] * batch_size), manual_out)

    batched = Batched(unbatched_dense)
    params = batched.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    batched_out = batched.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(manual_out, batched_out)
示例6: get_data
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
    """Generate a deterministic 1-d regression dataset with polynomial features.

    Returns ``(X, Y, X_test)`` where ``X`` has shape ``(N, D_X)`` with columns
    ``[x**0, x**1, ...]``, ``Y`` has shape ``(N, 1)`` and is standardized to
    zero mean and unit variance, and ``X_test`` spans a slightly wider range.
    """
    D_Y = 1  # create 1d outputs
    np.random.seed(0)  # fixed seed: the dataset is fully deterministic
    grid = jnp.linspace(-1, 1, N)
    X = jnp.power(grid[:, np.newaxis], jnp.arange(D_X))
    W = 0.5 * np.random.randn(D_X)
    # linear trend plus a smooth nonlinearity in the degree-1 feature
    Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + X[:, 1], 2.0) * jnp.sin(4.0 * X[:, 1])
    Y += sigma_obs * np.random.randn(N)
    Y = Y[:, np.newaxis]
    # standardize the targets
    Y -= jnp.mean(Y)
    Y /= jnp.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    test_grid = jnp.linspace(-1.3, 1.3, N_test)
    X_test = jnp.power(test_grid[:, np.newaxis], jnp.arange(D_X))
    return X, Y, X_test
示例7: glmm
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def glmm(dept, male, applications, admit=None):
    """NumPyro model: binomial GLMM with per-department varying effects.

    Each department gets a random intercept and a random slope on `male`,
    drawn via a non-centered parameterization: standard-normal z scaled by
    a Cholesky factor built from `sigma` and an LKJ correlation factor.
    When `admit` is None the model runs in predictive mode and records
    admission probabilities at a Delta site instead of conditioning on data.
    """
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))
    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    # scale_tril = diag(sigma) @ L_Rho: Cholesky factor of the effect covariance
    scale_tril = sigma[..., jnp.newaxis] * L_Rho
    # non-centered parameterization
    num_dept = len(jnp.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    v = jnp.dot(scale_tril, z.T).T
    # column 0: intercept offset, column 1: slope offset on `male`
    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    if admit is None:
        # we use a Delta site to record probs for predictive distribution
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
示例8: _is_turning
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def _is_turning(inverse_mass_matrix, r_left, r_right, r_sum):
r_left, _ = ravel_pytree(r_left)
r_right, _ = ravel_pytree(r_right)
r_sum, _ = ravel_pytree(r_sum)
if inverse_mass_matrix.ndim == 2:
v_left = jnp.matmul(inverse_mass_matrix, r_left)
v_right = jnp.matmul(inverse_mass_matrix, r_right)
elif inverse_mass_matrix.ndim == 1:
v_left = jnp.multiply(inverse_mass_matrix, r_left)
v_right = jnp.multiply(inverse_mass_matrix, r_right)
# This implements dynamic termination criterion (ref [2], section A.4.2).
r_sum = r_sum - (r_left + r_right) / 2
turning_at_left = jnp.dot(v_left, r_sum) <= 0
turning_at_right = jnp.dot(v_right, r_sum) <= 0
return turning_at_left | turning_at_right
示例9: test_correlated_mvn
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_correlated_mvn():
    """NUTS with a dense mass matrix should recover a correlated MVN target."""
    # This requires dense mass matrix estimation.
    D = 5
    warmup_steps, num_samples = 5000, 8000
    true_mean = 0.

    # random SPD covariance with strong off-diagonal structure
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) + 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # negative log density of N(0, true_cov), up to an additive constant
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=jnp.zeros(D))
    samples = mcmc.get_samples()

    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02
示例10: testHessianVectorProduct
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def testHessianVectorProduct(self):
    """Check get_hvp_fn's implicit Hessian-vector product against the
    explicitly materialized Hessian of a small single-layer model's loss.
    """
    onp.random.seed(100)
    key = random.PRNGKey(0)
    # small model so the full Hessian is cheap to materialize
    input_size = 4
    output_size = 2
    width = 10
    batch_size = 5
    # The accuracy of the approximation will be degraded when using lower
    # numerical precision (tpu is float16).
    if FLAGS.jax_test_dut == 'tpu':
        error_tolerance = 1e-4
    else:
        error_tolerance = 1e-6
    # NOTE(review): `key` is threaded through and replaced by each helper call
    predict, params, key = prepare_single_layer_model(input_size,
                                                      output_size, width, key)
    b, key = get_batch(input_size, output_size, batch_size, key)

    def batches():
        # single-batch generator expected by get_hvp_fn
        yield b

    def loss_fn(params, batch):
        return loss(predict(params, batch[0]), batch[1])

    # isolate the function v -> Hv
    hvp, _, num_params = hessian_computation.get_hvp_fn(loss_fn, params,
                                                        batches)
    # compute the full hessian
    loss_cl = functools.partial(loss_fn, batch=b)
    hessian = hessian_computation.full_hessian(loss_cl, params)
    # test hvp
    v = np.ones((num_params))
    v_hvp = hvp(params, v)
    v_full = np.dot(hessian, v)
    self.assertArraysAllClose(v_hvp, v_full, True, atol=error_tolerance)
示例11: testTridiagEigenvalues
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def testTridiagEigenvalues(self, shape):
    """Check that the Lanczos tridiagonal matrix reproduces the extreme
    eigenvalues and the smoothed spectral density of a random symmetric matrix.
    """
    onp.random.seed(100)
    sigma_squared = 1e-2  # smoothing bandwidth for the density estimate
    # if order > matrix shape, lanczos may fail due to linear dependence.
    order = min(70, shape[0])
    atol = 1e-5
    key = random.PRNGKey(0)
    matrix = random.normal(key, shape)
    matrix = np.dot(matrix, matrix.T)  # symmetrize the matrix
    # matrix-vector product oracle consumed by the Lanczos algorithm
    mvp = jit(lambda v: np.dot(matrix, v))
    eigs_true, _ = onp.linalg.eigh(matrix)

    @jit
    def get_tridiag(key):
        return lanczos.lanczos_alg(mvp, matrix.shape[0], order, rng_key=key)[0]

    tridiag_matrix = get_tridiag(key)
    eigs_tridiag, _ = onp.linalg.eigh(tridiag_matrix)
    # compare smoothed spectral densities on the same grid
    density, grids = density_lib.eigv_to_density(
        onp.expand_dims(eigs_tridiag, 0), sigma_squared=sigma_squared)
    density_true, _ = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)
    self.assertAlmostEqual(np.max(eigs_tridiag), np.max(eigs_true), delta=atol)
    self.assertAlmostEqual(np.min(eigs_tridiag), np.min(eigs_true), delta=atol)
    self.assertArraysAllClose(density, density_true, True, atol=atol)
示例12: Dense
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer."""
    @parametrized
    def dense(inputs):
        # kernel shape is inferred from the last input dimension
        w = parameter((inputs.shape[-1], out_dim), kernel_init, name='kernel')
        b = parameter((out_dim,), bias_init, name='bias')
        return np.dot(inputs, w) + b

    return dense
示例13: test_parameter_Dense_equivalent
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_parameter_Dense_equivalent():
    """A Dense built from raw Parameter objects must pass the Dense shape test."""
    def DenseEquivalent(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel_shape = (inputs.shape[-1], out_dim)
            kernel = Parameter(lambda key: kernel_init(key, kernel_shape))()
            bias = Parameter(lambda key: bias_init(key, (out_dim,)))()
            return jnp.dot(inputs, kernel) + bias
        return dense

    test_Dense_shape(DenseEquivalent)
示例14: feed
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def feed(self, x):
    """Forward pass: product of the stored weight matrix ``self.W`` with x."""
    weights = self.W
    return jnp.dot(weights, x)
示例15: dot
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def dot(self, x, y):
    """Return the dot product of x and y (delegates to np.dot)."""
    result = np.dot(x, y)
    return result