当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.sqrt方法代码示例

本文整理汇总了Python中jax.numpy.sqrt方法的典型用法代码示例。如果您正苦于以下问题:Python jax.numpy.sqrt方法的具体用法?Python jax.numpy.sqrt怎么用?Python jax.numpy.sqrt使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。


在下文中一共展示了numpy.sqrt方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: BatchNorm

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
    """Build a batch normalization layer.

    :param axis: Axis or axes over which mean and variance are computed.
                 A scalar is wrapped into a 1-tuple.
    :param epsilon: Constant added to the variance before the square root.
    :param center: If true, add a learned ``'beta'`` offset.
    :param scale: If true, multiply by a learned ``'gamma'`` factor.
    :param beta_init: Initializer passed to ``parameter`` for ``'beta'``.
    :param gamma_init: Initializer passed to ``parameter`` for ``'gamma'``.
    :return: The ``parametrized`` layer function.
    """
    reduce_axes = (axis,) if np.isscalar(axis) else axis

    @parametrized
    def batch_norm(x):
        # Parameters only span the axes that are NOT reduced over.
        param_shape = tuple(size for dim, size in enumerate(x.shape)
                            if dim not in reduce_axes)
        # Index that re-inserts a length-1 dimension at every reduced axis so
        # the parameters broadcast against the normalized input.
        expand = tuple(None if dim in reduce_axes else slice(None)
                       for dim in range(np.ndim(x)))

        batch_mean = np.mean(x, reduce_axes, keepdims=True)
        batch_var = fastvar(x, reduce_axes, keepdims=True)
        out = (x - batch_mean) / np.sqrt(batch_var + epsilon)

        if scale:
            out = out * parameter(param_shape, gamma_init, 'gamma')[expand]
        if center:
            out = out + parameter(param_shape, beta_init, 'beta')[expand]
        return out

    return batch_norm 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:19,代码来源:modules.py

示例2: ConvOrConvTranspose

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def ConvOrConvTranspose(out_chan, filter_shape=(3, 3), strides=None, padding='SAME', init_scale=1.,
                        transpose=False):
    """Build a (transposed) convolution layer whose kernel is L2-normalized.

    The kernel ``V`` is divided by its L2 norm over axes ``(0, 1, 2)`` and
    multiplied by a scale ``g``; a bias ``b`` is subtracted from the
    convolution output. ``g`` and ``b`` are initialized from the variance and
    mean of an example output computed with ``g = 1`` and ``b = 0``.

    :param out_chan: Number of output channels (last kernel dimension).
    :param filter_shape: Leading (spatial) dimensions of the kernel.
    :param strides: Convolution strides; defaults to 1 per filter dimension.
    :param padding: Padding mode forwarded to the convolution.
    :param init_scale: Numerator used when initializing ``g``.
    :param transpose: Use ``lax.conv_transpose`` instead of ``_conv``.
    :return: The ``parametrized`` layer function.
    """
    strides = strides or (1,) * len(filter_shape)

    def apply(inputs, V, g, b):
        V = g * _l2_normalize(V, (0, 1, 2))
        return (lax.conv_transpose if transpose else _conv)(inputs, V, strides, padding) - b

    @parametrized
    def conv_or_conv_transpose(inputs):
        V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')

        # Output with unit scale and zero bias; its statistics seed g and b.
        example_out = apply(inputs, V=V, g=jnp.ones(out_chan), b=jnp.zeros(out_chan))

        # TODO remove need for `.aval.val` when capturing variables in initializer function:
        g = Parameter(lambda key: init_scale /
                                  jnp.sqrt(jnp.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
        b = Parameter(lambda key: jnp.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val, 'b')()

        # Pass by keyword: `apply` takes (inputs, V, g, b), and handing the
        # scale and bias over positionally in the wrong order silently swaps
        # them because both have shape (out_chan,).
        return apply(inputs, V=V, g=g, b=b)

    return conv_or_conv_transpose 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:24,代码来源:pixelcnn.py

示例3: model

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def model(X, Y, D_H):
    """Two-hidden-layer Bayesian neural network with unit-normal weight priors.

    :param X: Input matrix; its second dimension gives the input width.
    :param Y: Observed targets passed as ``obs`` to the ``"Y"`` sample site.
    :param D_H: Width of both hidden layers.
    """
    D_X, D_Y = X.shape[1], 1

    def unit_normal(n_in, n_out):
        # Independent Normal(0, 1) prior over an (n_in, n_out) weight matrix.
        return dist.Normal(jnp.zeros((n_in, n_out)), jnp.ones((n_in, n_out)))

    # First layer: (D_X, D_H) weights, then the nonlinearity -> (N, D_H).
    w1 = numpyro.sample("w1", unit_normal(D_X, D_H))
    hidden1 = nonlin(jnp.matmul(X, w1))

    # Second layer: (D_H, D_H) weights, then the nonlinearity -> (N, D_H).
    w2 = numpyro.sample("w2", unit_normal(D_H, D_H))
    hidden2 = nonlin(jnp.matmul(hidden1, w2))

    # Output layer: (D_H, D_Y) weights, no nonlinearity -> (N, D_Y).
    w3 = numpyro.sample("w3", unit_normal(D_H, D_Y))
    network_out = jnp.matmul(hidden2, w3)

    # Gamma prior on the observation precision, converted to a std deviation.
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)

    # Likelihood of the observed data.
    numpyro.sample("Y", dist.Normal(network_out, sigma_obs), obs=Y)


# helper function for HMC inference 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:27,代码来源:bnn.py

示例4: _get_proposal_loc_and_scale

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _get_proposal_loc_and_scale(samples, loc, scale, new_sample):
    """Return the loc/scale of the leave-one-out proposals q_{-n}, n = 1..N.

    Corresponds to Algorithm 1, line 5 of ref [1], using the numerical
    stability procedure from Appendix 6 of [1]. The results are stacked along
    the first dimension, so ``proposal_loc.shape[0] == proposal_scale.shape[0]
    == N`` where ``N = samples.shape[0]``.

    When ``scale`` has more dimensions than ``loc`` it is updated with
    ``cholesky_update``; otherwise it is treated as an element-wise scale.
    """
    weight = 1 / samples.shape[0]
    swap_delta = new_sample - samples
    proposal_loc = loc + weight * swap_delta

    if not scale.ndim > loc.ndim:
        # Element-wise scale: update the variance, then take the square root.
        proposal_var = jnp.square(scale) + weight * jnp.square(new_sample - loc)
        proposal_var = proposal_var - weight * jnp.square(samples - loc)
        proposal_var = proposal_var - weight ** 2 * jnp.square(swap_delta)
        return proposal_loc, jnp.sqrt(proposal_var)

    # Matrix scale: apply the same three updates through cholesky_update.
    proposal_scale = cholesky_update(scale, new_sample - loc, weight)
    proposal_scale = cholesky_update(proposal_scale, samples - loc, -weight)
    proposal_scale = cholesky_update(proposal_scale, swap_delta, - (weight ** 2))
    return proposal_loc, proposal_scale 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:20,代码来源:mcmc.py

示例5: _get_tr_params

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _get_tr_params(n, p):
    """Pre-compute the constants packed into ``_tr_params`` for given (n, p).

    See Table 1 of the reference the original comment points to (not visible
    in this excerpt). In addition to the table entries, log(p), log1p(-p) and
    the constant terms of log(f(k)) that depend only on (n, p, m) (bottom of
    page 5) are computed once here.
    """
    mean = n * p
    stddev = jnp.sqrt(mean * (1 - p))

    # Table constants; b must be known before a, alpha and v_r.
    b = 1.15 + 2.53 * stddev
    a = -0.0873 + 0.0248 * b + 0.01 * p
    c = mean + 0.5
    alpha = (2.83 + 5.1 / b) * stddev
    u_r = 0.43
    v_r = 0.92 - 4.2 / b

    # m keeps the dtype of n.
    m = jnp.floor((n + 1) * p).astype(n.dtype)

    log_p = jnp.log(p)
    log1_p = jnp.log1p(-p)
    log_odds_term = jnp.log((m + 1.) / (n - m + 1.)) + log1_p - log_p
    tail_term = stirling_approx_tail(m) + stirling_approx_tail(n - m)
    log_h = (m + 0.5) * log_odds_term + tail_term

    return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:19,代码来源:util.py

示例6: _cvine

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _cvine(self, key, size):
        """Sample a Cholesky factor with the C-vine method.

        Partial correlations are drawn from ``self._beta`` and mapped to a
        Cholesky factor by signed stick breaking.
        """
        # Sketch of why signed stick breaking reproduces the C-vine method of
        # [1], shown for the entry r_32.
        #
        # Notation (following [1]):
        #   p: partial correlation matrix,
        #   c: cholesky factor,
        #   r: correlation matrix.
        # Recursive formula (2) in [1] gives
        #   r_32 = p_32 * sqrt{(1 - p_21^2)*(1 - p_31^2)} + p_21 * p_31 =: I
        # while the signed stick breaking process gives
        #   l_21 = p_21, l_31 = p_31, l_22 = sqrt(1 - p_21^2), l_32 = p_32 * sqrt(1 - p_31^2)
        #   r_32 = l_21 * l_31 + l_22 * l_32
        #        = p_21 * p_31 + p_32 * sqrt{(1 - p_21^2)*(1 - p_31^2)} = I
        # Rescale the beta draws to the domain (-1, 1).
        partial_correlation = 2 * self._beta.sample(key, size) - 1
        return signed_stick_breaking_tril(partial_correlation) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:22,代码来源:continuous.py

示例7: _onion

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _onion(self, key, size):
        """Sample a Cholesky factor with the onion method.

        :param key: PRNG key; split into one key for the beta draw and one for
                    the normal draw.
        :param size: Leading sample shape (a tuple), prepended to
                     ``self.batch_shape``.
        :return: Array of shape ``size + self.batch_shape + self.event_shape``.
        """
        key_beta, key_normal = random.split(key)
        # Now we generate w term in Algorithm 3.2 of [1].
        beta_sample = self._beta.sample(key_beta, size)
        # The following Normal distribution is used to create a uniform distribution on
        # a hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
        normal_sample = random.normal(
            key_normal,
            shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,)
        )
        normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
        u_hypersphere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
        w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypersphere

        # put w into the off-diagonal triangular part
        # NB: `x.at[idx].add(y)` replaces `jax.ops.index_add`, which has been
        # removed from JAX.
        cholesky = jnp.zeros(size + self.batch_shape + self.event_shape).at[..., 1:, :-1].add(w)
        # correct the diagonal
        # NB: we clip due to numerical precision; the lower bound is passed
        # positionally because the `a_min` keyword is deprecated in JAX.
        diag = jnp.sqrt(jnp.clip(1 - jnp.sum(cholesky ** 2, axis=-1), 0.))
        cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
        return cholesky 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:continuous.py

示例8: clip_eta

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def clip_eta(eta, norm, eps):
  """
  Helper function to clip the perturbation to epsilon norm ball.
  :param eta: A tensor with the current perturbation. The norm is taken over
              every axis except the first.
  :param norm: Order of the norm (mimics Numpy).
              Possible values: np.inf or 2.
  :param eps: Epsilon, bound of the perturbation.
  :return: The clipped perturbation, same shape as `eta`.
  :raises ValueError: if `norm` is neither np.inf nor 2.
  """

  # Clipping perturbation eta to self.norm norm ball
  if norm not in [np.inf, 2]:
    raise ValueError('norm must be np.inf or 2.')

  if norm == np.inf:
    # Bounds are positional: the `a_min`/`a_max` keywords are deprecated in
    # jax.numpy.clip, while the positional form works across versions.
    return np.clip(eta, -eps, eps)

  # A tuple (not a list) of axes is accepted by every array backend.
  reduce_axes = tuple(range(1, len(eta.shape)))
  avoid_zero_div = 1e-12
  # avoid_zero_div must go inside sqrt to avoid a divide by zero in the gradient through this operation
  eta_norm = np.sqrt(np.maximum(avoid_zero_div,
                                np.sum(np.square(eta), axis=reduce_axes, keepdims=True)))
  # We must *clip* to within the norm ball, not *normalize* onto the surface of the ball
  factor = np.minimum(1., np.divide(eps, eta_norm))
  return eta * factor 
开发者ID:tensorflow,项目名称:cleverhans,代码行数:26,代码来源:utils.py

示例9: _cholesky

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _cholesky(x):
    """
    Cholesky factor of ``x``, like :func:`numpy.linalg.cholesky`.

    When the last dimension of ``x`` has size 1 the factor is computed
    with a plain square root instead.
    """
    last_dim_is_one = x.shape[-1] == 1
    return np.sqrt(x) if last_dim_is_one else np.linalg.cholesky(x) 
开发者ID:pyro-ppl,项目名称:funsor,代码行数:9,代码来源:ops.py

示例10: gaussian_sample

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def gaussian_sample(key, mu, sigmasq):
    """Draw one sample from a diagonal Gaussian.

    :param key: PRNG key for the standard-normal noise.
    :param mu: Mean; its shape determines the shape of the noise.
    :param sigmasq: Variance (the square root is taken to get the scale).
    """
    stddev = np.sqrt(sigmasq)
    noise = random.normal(key, mu.shape)
    return mu + stddev * noise 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:5,代码来源:mnist_vae.py

示例11: _l2_normalize

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _l2_normalize(arr, axis):
    """Divide ``arr`` by its L2 norm taken over ``axis`` (dimensions kept)."""
    sum_of_squares = jnp.sum(arr ** 2, axis=axis, keepdims=True)
    l2_norm = jnp.sqrt(sum_of_squares)
    return arr / l2_norm 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:4,代码来源:pixelcnn.py

示例12: sqrt

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def sqrt(self, x):
        """Return the square root of ``x`` as computed by ``np.sqrt``."""
        root = np.sqrt(x)
        return root 
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例13: wavelet_tensors

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def wavelet_tensors(request):
  """Build the Hamiltonian and MERA tensors for the D=2 wavelet MERA.

  From Evenbly & White, Phys. Rev. Lett. 116, 140403 (2016).

  Returns a tuple ``(h, w, u)`` with every tensor cast to complex128.
  """
  D = 2
  h = simple_mera.ham_ising()

  # 2x2 identity and the three Pauli matrices.
  E = np.array([[1, 0], [0, 1]])
  X = np.array([[0, 1], [1, 0]])
  Y = np.array([[0, -1j], [1j, 0]])
  Z = np.array([[1, 0], [0, -1]])

  sqrt2 = np.sqrt(2)
  sqrt3 = np.sqrt(3)
  kron_ee = np.kron(E, E)
  kron_zz = np.kron(Z, Z)
  kron_xy = np.kron(X, Y)
  kron_yx = np.kron(Y, X)

  wmat_un = np.real((sqrt3 + sqrt2) / 4 * kron_ee +
                    (sqrt3 - sqrt2) / 4 * kron_zz +
                    1.j * (1 + sqrt2) / 4 * kron_xy +
                    1.j * (1 - sqrt2) / 4 * kron_yx)

  umat = np.real((sqrt3 + 2) / 4 * kron_ee +
                 (sqrt3 - 2) / 4 * kron_zz +
                 1.j / 4 * kron_xy +
                 1.j / 4 * kron_yx)

  rank4 = (D, D, D, D)
  # Keep only the index-0 slice of the second axis of w, then reorder axes.
  w = np.transpose(np.reshape(wmat_un, rank4)[:, 0, :, :], [1, 2, 0])
  u = np.transpose(np.reshape(umat, rank4), [2, 3, 0, 1])

  return tuple(tensor.astype(np.complex128) for tensor in (h, w, u)) 
开发者ID:google,项目名称:TensorNetwork,代码行数:31,代码来源:simple_mera_test.py

示例14: sqrt

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def sqrt(self, tensor_in):
        """Return the square root of ``tensor_in`` as computed by ``np.sqrt``."""
        result = np.sqrt(tensor_in)
        return result 
开发者ID:scikit-hep,项目名称:pyhf,代码行数:4,代码来源:jax_backend.py

示例15: normal_logpdf

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def normal_logpdf(self, x, mu, sigma):
        """Log-density of a normal distribution with mean ``mu`` and scale ``sigma`` at ``x``.

        Hand-written because it is much faster than
        ``norm.logpdf(x, loc=mu, scale=sigma)``; see
        https://codereview.stackexchange.com/questions/69718/fastest-computation-of-n-likelihoods-on-normal-distributions
        """
        sqrt_two = np.sqrt(2)
        sqrt_two_pi = np.sqrt(2 * np.pi)
        # -log(sigma * sqrt(2*pi)): the normalization term.
        log_norm_term = -np.log(sigma * sqrt_two_pi)
        # -((x - mu) / (sqrt(2) * sigma))**2: the exponent term.
        scaled_residual = np.divide((x - mu), (sqrt_two * sigma))
        exponent_term = -np.square(scaled_residual)
        return log_norm_term + exponent_term

    # def normal_logpdf(self, x, mu, sigma):
    #     return norm.logpdf(x, loc=mu, scale=sigma) 
开发者ID:scikit-hep,项目名称:pyhf,代码行数:14,代码来源:jax_backend.py


注:本文中的jax.numpy.sqrt方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。