当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.dot方法代码示例

本文整理汇总了Python中jax.numpy.dot方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.dot方法的具体用法？Python jax.numpy.dot怎么用？Python jax.numpy.dot使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。


在下文中一共展示了jax.numpy.dot方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: GRUCell

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def GRUCell(carry_size, param_init):
    """Build a gated recurrent unit (GRU) cell.

    Args:
        carry_size: width of the hidden state carried between steps.
        param_init: initializer handed to ``parameter`` for every kernel.

    Returns:
        A pair ``(gru_cell, carry_init)``. ``gru_cell(carry, x)`` is the
        parametrized step function and returns ``(out, out)``.
        ``carry_init(batch_size)`` builds an all-zeros carry of shape
        ``(batch_size, carry_size)``.
    """

    @parametrized
    def gru_cell(carry, x):
        # Every kernel maps the concatenation [x, carry] (joined on axis 1)
        # to carry_size units, so all three share one shape.
        kernel_shape = (x.shape[1] + carry_size, carry_size)

        def kernel(name):
            return parameter(kernel_shape, param_init, name)

        # Kernels are requested in the order update -> reset -> compute.
        x_and_carry = np.concatenate((x, carry), axis=1)
        update_gate = sigmoid(np.dot(x_and_carry, kernel('update_kernel')))
        reset_gate = sigmoid(np.dot(x_and_carry, kernel('reset_kernel')))

        # The candidate state only sees the part of the carry the reset gate lets through.
        gated_input = np.concatenate((x, reset_gate * carry), axis=1)
        candidate = np.tanh(np.dot(gated_input, kernel('compute_kernel')))

        # Interpolate between the candidate and the previous carry.
        new_carry = update_gate * candidate + (1 - update_gate) * carry
        return new_carry, new_carry

    def carry_init(batch_size):
        return np.zeros((batch_size, carry_size))

    return gru_cell, carry_init
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:modules.py

示例2: test_Dense_equivalent

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_Dense_equivalent():
    """Run the shared ``test_Dense_shape`` check against a hand-written class-based dense layer."""

    class DenseEquivalent:
        def __init__(self, out_dim, kernel_init=glorot_normal(), bias_init=normal()):
            self.out_dim = out_dim
            self.kernel_init = kernel_init
            self.bias_init = bias_init

        def apply(self, params, inputs):
            # params unpacks positionally as (kernel, bias).
            weights, offset = params
            return jnp.dot(inputs, weights) + offset

        def init_parameters(self, example_inputs, key):
            kernel_key, bias_key = random.split(key, 2)
            in_dim = example_inputs.shape[-1]
            DenseParams = namedtuple('dense', ['kernel', 'bias'])
            return DenseParams(
                kernel=self.kernel_init(kernel_key, (in_dim, self.out_dim)),
                bias=self.bias_init(bias_key, (self.out_dim,)),
            )

        def shaped(self, example_inputs):
            return ShapedParametrized(self, example_inputs)

    test_Dense_shape(DenseEquivalent)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:22,代码来源:test_examples.py

示例3: test_Parameter_dense

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_Parameter_dense():
    """Check parameter shapes and the jitted output shape of a minimal Dense layer."""

    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            # The assertions below expect the kernel as parameter0 and the
            # bias as parameter1, so keep this creation order.
            kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias

        return dense

    layer = Dense(2)
    example = jnp.zeros((1, 3))

    params = layer.init_parameters(example, key=PRNGKey(0))
    assert params.parameter0.shape == (3, 2)
    assert params.parameter1.shape == (2,)

    result = layer.apply(params, example, jit=True)
    assert result.shape == (1, 2)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_examples.py

示例4: test_mixed_up_execution_order

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_mixed_up_execution_order():
    """Creating the bias before the kernel must still give correctly named params and matching eager/jit output."""

    @parametrized
    def dense(inputs):
        # Deliberately create the bias before the kernel; this ordering is
        # what the test exercises.
        bias = parameter((2,), zeros, 'bias')
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel) + bias

    example = jnp.zeros((1, 3))

    params = dense.init_parameters(example, key=PRNGKey(0))
    assert params.bias.shape == (2,)
    assert params.kernel.shape == (3, 2)

    eager_out = dense.apply(params, example)
    # The `zeros` initializer is expected to yield an all-zero output here.
    assert jnp.array_equal(jnp.zeros((1, 2)), eager_out)

    jitted_out = dense.apply(params, example, jit=True)
    assert jnp.array_equal(eager_out, jitted_out)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_core.py

示例5: test_Batched

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_Batched():
    """Check that ``Batched`` matches ``vmap`` over the unbatched layer, and the unbatched layer itself."""
    out_dim = 1

    @parametrized
    def unbatched_dense(inputs):
        # The kernel is laid out (out_dim, in_dim) and applied on the left.
        kernel = parameter((out_dim, inputs.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        return jnp.dot(kernel, inputs) + bias

    batch_size = 4

    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    # `ones`-initialized kernel and bias give 1*1 + 1*1 + 1 = 3.
    # `array_equal` checks the shape as well as the value; a bare `==` inside
    # `assert` would also accept a scalar 3.
    assert jnp.array_equal(jnp.array([3.]), out)

    # Reference result: vmap the unbatched apply over axis 0 of the inputs,
    # broadcasting (not mapping) the parameters.
    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    out_batched = dense.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(out_batched_, out_batched)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:26,代码来源:test_modules.py

示例6: get_data

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
    """Generate a seeded 1-D toy regression dataset with polynomial features.

    Inputs lie on an evenly spaced grid in [-1, 1] and are expanded into the
    powers ``x**0 .. x**(D_X - 1)``. Targets are a random linear function of
    those features, plus a smooth nonlinearity in ``x`` and standard-normal
    noise scaled by ``sigma_obs``. They are then standardized to zero mean
    and unit standard deviation.

    Note: this reseeds NumPy's global random generator with ``0``.

    Args:
        N: number of training points.
        D_X: number of polynomial features. Must be at least 2, because the
            nonlinearity reads feature column 1.
        sigma_obs: scale of the additive observation noise.
        N_test: number of test inputs, spread over the wider range [-1.3, 1.3].

    Returns:
        ``(X, Y, X_test)`` with shapes ``(N, D_X)``, ``(N, 1)`` and
        ``(N_test, D_X)``.
    """
    D_Y = 1  # targets are one-dimensional
    np.random.seed(0)

    def _features(grid):
        # Column j holds grid ** j, so column 0 is constant ones.
        return jnp.power(grid[:, np.newaxis], jnp.arange(D_X))

    X = _features(jnp.linspace(-1, 1, N))
    W = 0.5 * np.random.randn(D_X)

    x1 = X[:, 1]
    signal = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + x1, 2.0) * jnp.sin(4.0 * x1)
    noise = sigma_obs * np.random.randn(N)

    Y = (signal + noise)[:, np.newaxis]
    Y = Y - jnp.mean(Y)
    Y = Y / jnp.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    X_test = _features(jnp.linspace(-1.3, 1.3, N_test))
    return X, Y, X_test
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:bnn.py

示例7: glmm

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def glmm(dept, male, applications, admit=None):
    """Binomial GLMM with correlated per-department intercept and ``male`` slope.

    The per-department effects use a non-centered parameterization. Standard
    normals are correlated through a scaled LKJ-Cholesky factor.

    Args:
        dept: integer department index for each row (used to index the
            per-department effects).
        male: covariate multiplying the slope term.
        applications: number of Binomial trials per row.
        admit: observed admission counts, or ``None`` to run predictively.
            In that case the admission probabilities are also recorded as a
            ``probs`` Delta site.
    """
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))

    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    # Row i of the correlation factor is scaled by sigma[i].
    scale_tril = sigma[..., jnp.newaxis] * L_Rho

    # Non-centered draw: standard normals, then correlated by scale_tril.
    # NOTE(review): assumes dept labels are 0..K-1, so the number of unique
    # labels is also the number of rows that z needs.
    num_dept = len(jnp.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    v = jnp.dot(scale_tril, z.T).T

    intercept = v_mu[0] + v[dept, 0]
    slope = v_mu[1] + v[dept, 1]
    logits = intercept + slope * male

    if admit is None:
        # Expose the probabilities as a deterministic site for prediction.
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:19,代码来源:ucbadmit.py

示例8: _is_turning

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def _is_turning(inverse_mass_matrix, r_left, r_right, r_sum):
    r_left, _ = ravel_pytree(r_left)
    r_right, _ = ravel_pytree(r_right)
    r_sum, _ = ravel_pytree(r_sum)

    if inverse_mass_matrix.ndim == 2:
        v_left = jnp.matmul(inverse_mass_matrix, r_left)
        v_right = jnp.matmul(inverse_mass_matrix, r_right)
    elif inverse_mass_matrix.ndim == 1:
        v_left = jnp.multiply(inverse_mass_matrix, r_left)
        v_right = jnp.multiply(inverse_mass_matrix, r_right)

    # This implements dynamic termination criterion (ref [2], section A.4.2).
    r_sum = r_sum - (r_left + r_right) / 2
    turning_at_left = jnp.dot(v_left, r_sum) <= 0
    turning_at_right = jnp.dot(v_right, r_sum) <= 0
    return turning_at_left | turning_at_right 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:19,代码来源:hmc_util.py

示例9: test_correlated_mvn

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_correlated_mvn():
    """NUTS with dense mass-matrix adaptation should recover a correlated 5-D Gaussian."""
    # The target is strongly correlated, which is why dense mass matrix
    # estimation is requested below.
    D = 5
    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    noise = random.normal(random.PRNGKey(0), shape=(D, D))
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) + 0.1 * jnp.exp(noise))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # Quadratic form 0.5 * z^T P z with P the precision matrix.
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=jnp.zeros(D))
    samples = mcmc.get_samples()

    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    mean_abs_cov_error = np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2
    assert mean_abs_cov_error < 0.02
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:23,代码来源:test_mcmc.py

示例10: testHessianVectorProduct

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def testHessianVectorProduct(self):
    """The Hessian-vector-product function must agree with an explicit full-Hessian product."""
    onp.random.seed(100)
    key = random.PRNGKey(0)
    input_size, output_size = 4, 2
    width = 10
    batch_size = 5

    # The accuracy of the approximation degrades at lower numerical
    # precision on TPU, so the tolerance is looser there.
    error_tolerance = 1e-4 if FLAGS.jax_test_dut == 'tpu' else 1e-6

    predict, params, key = prepare_single_layer_model(
        input_size, output_size, width, key)
    b, key = get_batch(input_size, output_size, batch_size, key)

    def batches():
        yield b

    def loss_fn(params, batch):
        inputs, targets = batch[0], batch[1]
        return loss(predict(params, inputs), targets)

    # Isolate the function v -> Hv over the single-batch iterator.
    hvp, _, num_params = hessian_computation.get_hvp_fn(loss_fn, params, batches)

    # Reference: the full Hessian of the loss on that same batch.
    loss_on_b = functools.partial(loss_fn, batch=b)
    hessian = hessian_computation.full_hessian(loss_on_b, params)

    probe = np.ones(num_params)
    hvp_result = hvp(params, probe)
    reference = np.dot(hessian, probe)

    self.assertArraysAllClose(hvp_result, reference, True, atol=error_tolerance)
开发者ID:google,项目名称:spectral-density,代码行数:42,代码来源:spectral_density_test.py

示例11: testTridiagEigenvalues

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def testTridiagEigenvalues(self, shape):
    """Lanczos tridiagonalization should reproduce the spectrum of a random Gram matrix."""
    onp.random.seed(100)
    sigma_squared = 1e-2
    atol = 1e-5

    # If the order exceeds the matrix dimension, Lanczos may fail due to
    # linear dependence, so cap it.
    order = min(70, shape[0])

    key = random.PRNGKey(0)
    factor = random.normal(key, shape)
    # factor @ factor.T is symmetric (and positive semi-definite) by construction.
    matrix = np.dot(factor, factor.T)
    mvp = jit(lambda v: np.dot(matrix, v))

    eigs_true, _ = onp.linalg.eigh(matrix)

    @jit
    def get_tridiag(key):
        return lanczos.lanczos_alg(mvp, matrix.shape[0], order, rng_key=key)[0]

    tridiag_matrix = get_tridiag(key)
    eigs_tridiag, _ = onp.linalg.eigh(tridiag_matrix)

    density, grids = density_lib.eigv_to_density(
        onp.expand_dims(eigs_tridiag, 0), sigma_squared=sigma_squared)
    density_true, _ = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)

    # Compare the extreme eigenvalues, then the smoothed densities on a shared grid.
    self.assertAlmostEqual(np.max(eigs_tridiag), np.max(eigs_true), delta=atol)
    self.assertAlmostEqual(np.min(eigs_tridiag), np.min(eigs_true), delta=atol)
    self.assertArraysAllClose(density, density_true, True, atol=atol)
开发者ID:google,项目名称:spectral-density,代码行数:32,代码来源:lanczos_test.py

示例12: Dense

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer.

    Args:
        out_dim: number of output features.
        kernel_init: initializer for the ``(in_dim, out_dim)`` kernel, named ``'kernel'``.
        bias_init: initializer for the ``(out_dim,)`` bias, named ``'bias'``.

    Returns:
        A parametrized function computing ``np.dot(inputs, kernel) + bias``.
        ``in_dim`` is read from the last axis of ``inputs``.
    """

    @parametrized
    def dense(inputs):
        in_dim = inputs.shape[-1]
        kernel = parameter((in_dim, out_dim), kernel_init, name='kernel')
        bias = parameter((out_dim,), bias_init, name='bias')
        return np.dot(inputs, kernel) + bias

    return dense
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:12,代码来源:modules.py

示例13: test_parameter_Dense_equivalent

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def test_parameter_Dense_equivalent():
    """A Dense layer built directly from ``Parameter`` objects must pass the shared Dense shape test."""

    def DenseEquivalent(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            in_dim = inputs.shape[-1]
            # Each Parameter wraps an init function of a PRNG key. Calling
            # the Parameter with no arguments yields its value.
            kernel = Parameter(lambda key: kernel_init(key, (in_dim, out_dim)))()
            bias = Parameter(lambda key: bias_init(key, (out_dim,)))()
            return jnp.dot(inputs, kernel) + bias

        return dense

    test_Dense_shape(DenseEquivalent)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:13,代码来源:test_examples.py

示例14: feed

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def feed(self, x):
    """Apply the stored weight matrix on the left of ``x``: ``jnp.dot(self.W, x)``."""
    weights = self.W
    return jnp.dot(weights, x)
开发者ID:SymJAX,项目名称:SymJAX,代码行数:4,代码来源:wrap_class.py

示例15: dot

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import dot [as 别名]
def dot(self, x, y):
    """Return the dot product of ``x`` and ``y`` via ``np.dot``.

    ``self`` is unused; the call is forwarded unchanged to the array library.
    That library is presumably ``jax.numpy`` imported as ``np`` in this
    backend module; confirm.
    """
    product = np.dot(x, y)
    return product
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py


注:本文中的jax.numpy.dot方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。