本文整理汇总了Python中jax.numpy.sqrt方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.sqrt方法的具体用法？Python jax.numpy.sqrt怎么用？Python jax.numpy.sqrt使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。
在下文中一共展示了jax.numpy.sqrt方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: BatchNorm
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
    """Layer construction function for a batch normalization layer.

    Args:
        axis: int or tuple of ints; the axes reduced over when computing the
            batch mean and variance. A scalar is wrapped into a 1-tuple.
        epsilon: added to the variance before taking the square root.
        center: if True, add a learned 'beta' offset after normalizing.
        scale: if True, multiply by a learned 'gamma' factor after normalizing.
        beta_init: initializer passed to ``parameter`` for 'beta'.
        gamma_init: initializer passed to ``parameter`` for 'gamma'.

    Returns:
        The ``parametrized`` function ``batch_norm(x)``.
    """
    if np.isscalar(axis):
        axis = (axis,)

    @parametrized
    def batch_norm(x):
        # Index that re-inserts a singleton dim at every reduced axis, so the
        # per-feature parameters broadcast against x.
        expand = tuple(None if dim in axis else slice(None)
                       for dim in range(np.ndim(x)))
        mean = np.mean(x, axis, keepdims=True)
        var = fastvar(x, axis, keepdims=True)
        normalized = (x - mean) / np.sqrt(var + epsilon)
        # Parameters cover only the non-reduced axes of x.
        param_shape = tuple(size for dim, size in enumerate(x.shape)
                            if dim not in axis)
        if scale:
            normalized = normalized * parameter(param_shape, gamma_init, 'gamma')[expand]
        if center:
            return normalized + parameter(param_shape, beta_init, 'beta')[expand]
        return normalized

    return batch_norm
示例2: ConvOrConvTranspose
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def ConvOrConvTranspose(out_chan, filter_shape=(3, 3), strides=None, padding='SAME', init_scale=1.,
                        transpose=False):
    """Weight-normalized convolution (or transposed convolution) layer.

    The layer computes ``conv(inputs, g * V / ||V||) - b``. The norm of ``V``
    is taken over axes (0, 1, 2). ``conv`` is ``lax.conv_transpose`` when
    ``transpose`` is True and ``_conv`` otherwise.

    ``g`` and ``b`` use a data-dependent initialization. The layer is first
    evaluated with unit gain and zero bias, giving ``example_out``. Then
    ``g = init_scale / sqrt(var(example_out) + 1e-10)`` and
    ``b = mean(example_out) * g``, with statistics over axes (0, 1, 2). As a
    result, the initial output ``g * conv - b`` is centred around zero.

    Args:
        out_chan: number of output channels.
        filter_shape: spatial shape of the filter.
        strides: conv strides; defaults to 1 along every filter dimension.
        padding: padding mode forwarded to the convolution.
        init_scale: target scale used when initializing ``g``.
        transpose: use a transposed convolution instead of a regular one.

    Returns:
        The ``parametrized`` function ``conv_or_conv_transpose(inputs)``.
    """
    strides = strides or (1,) * len(filter_shape)

    def apply(inputs, V, g, b):
        V = g * _l2_normalize(V, (0, 1, 2))
        return (lax.conv_transpose if transpose else _conv)(inputs, V, strides, padding) - b

    @parametrized
    def conv_or_conv_transpose(inputs):
        V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')
        example_out = apply(inputs, V=V, g=jnp.ones(out_chan), b=jnp.zeros(out_chan))
        # TODO remove need for `.aval.val` when capturing variables in initializer function:
        g = Parameter(lambda key: init_scale /
                      jnp.sqrt(jnp.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
        b = Parameter(lambda key: jnp.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val, 'b')()
        # Pass by keyword. The former positional call apply(inputs, V, b, g)
        # put the bias in the gain slot and the gain in the bias slot.
        return apply(inputs, V, g=g, b=b)

    return conv_or_conv_transpose
示例3: model
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def model(X, Y, D_H):
    """Bayesian neural network with two hidden layers of width ``D_H``.

    Every weight matrix gets an elementwise standard-normal prior. The
    observation-noise precision gets a Gamma(3.0, 1.0) prior. ``Y`` is
    observed under a Normal likelihood centred on the network output, with
    scale ``1 / sqrt(precision)``.
    """
    D_X, D_Y = X.shape[1], 1

    def unit_normal(shape):
        # elementwise N(0, 1) prior over a weight matrix of the given shape
        return dist.Normal(jnp.zeros(shape), jnp.ones(shape))

    # layer 1: (N, D_X) @ (D_X, D_H) -> hidden activations (N, D_H)
    w1 = numpyro.sample("w1", unit_normal((D_X, D_H)))
    z1 = nonlin(jnp.matmul(X, w1))
    # layer 2: (N, D_H) @ (D_H, D_H) -> hidden activations (N, D_H)
    w2 = numpyro.sample("w2", unit_normal((D_H, D_H)))
    z2 = nonlin(jnp.matmul(z1, w2))
    # output layer, no nonlinearity: (N, D_H) @ (D_H, D_Y) -> (N, D_Y)
    w3 = numpyro.sample("w3", unit_normal((D_H, D_Y)))
    z3 = jnp.matmul(z2, w3)

    # prior on the observation-noise precision, converted to a std-dev
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)
    # likelihood of the observed targets
    numpyro.sample("Y", dist.Normal(z3, sigma_obs), obs=Y)
# helper function for HMC inference
示例4: _get_proposal_loc_and_scale
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _get_proposal_loc_and_scale(samples, loc, scale, new_sample):
    """Compute loc/scale of the leave-one-out proposals q_{-n}, n = 1..N.

    This is Algorithm 1, line 5 of ref [1]. The update uses the numerical
    stability procedure in Appendix 6 of [1]. Results are stacked along the
    first dimension, so
    ``proposal_loc.shape[0] == proposal_scale.shape[0] == N``, where
    ``N = samples.shape[0]``.

    If ``scale`` has more dimensions than ``loc``, it is updated with three
    calls to ``cholesky_update``. Otherwise it is squared to a variance,
    updated elementwise, and square-rooted back.
    """
    weight = 1 / samples.shape[0]
    delta_new = new_sample - loc
    delta_old = samples - loc
    delta_swap = new_sample - samples
    if scale.ndim <= loc.ndim:
        # elementwise case: update the variance, then take the sqrt
        proposal_var = jnp.square(scale) + weight * jnp.square(delta_new)
        proposal_var = proposal_var - weight * jnp.square(delta_old)
        proposal_var = proposal_var - weight ** 2 * jnp.square(delta_swap)
        proposal_scale = jnp.sqrt(proposal_var)
    else:
        # matrix case: successive cholesky_update calls on the scale factor
        proposal_scale = cholesky_update(scale, delta_new, weight)
        proposal_scale = cholesky_update(proposal_scale, delta_old, -weight)
        proposal_scale = cholesky_update(proposal_scale, delta_swap, -(weight ** 2))
    proposal_loc = loc + weight * delta_swap
    return proposal_loc, proposal_scale
示例5: _get_tr_params
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _get_tr_params(n, p):
    """Precompute the per-(n, p) constants of the binomial 'tr' sampler.

    The names and constants (see Table 1) look like Hörmann's
    transformed-rejection (BTRS) method; NOTE(review): confirm against the
    referenced paper. Besides the Table 1 quantities, this also precomputes
    log(p), log(1 - p), and the constant part of log(f(k)) that depends only
    on (n, p, m) (bottom of page 5).

    Returns:
        ``_tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)``.
    """
    mu = n * p
    spq = jnp.sqrt(mu * (1 - p))  # sqrt(n * p * (1 - p))
    b = 1.15 + 2.53 * spq
    c = mu + 0.5
    a = -0.0873 + 0.0248 * b + 0.01 * p
    alpha = (2.83 + 5.1 / b) * spq
    u_r = 0.43
    v_r = 0.92 - 4.2 / b
    # floor((n + 1) * p), the mode of Binomial(n, p), in n's dtype
    m = jnp.floor((n + 1) * p).astype(n.dtype)
    log_p = jnp.log(p)
    log1_p = jnp.log1p(-p)  # log(1 - p)
    log_h = ((m + 0.5) * (jnp.log((m + 1.) / (n - m + 1.)) + log1_p - log_p)
             + (stirling_approx_tail(m) + stirling_approx_tail(n - m)))
    return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
示例6: _cvine
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _cvine(self, key, size):
    """Draw Cholesky factors of correlation matrices with the C-vine method.

    Partial correlations are sampled from ``self._beta`` and rescaled from
    (0, 1) to (-1, 1). Signed stick breaking then turns them into a
    Cholesky factor.

    Why this matches the C-vine construction of [1], shown for entry r_32.
    Notation follows [1]:
        p: partial correlation matrix,
        c: cholesky factor,
        r: correlation matrix.
    The recursive formula (2) in [1] gives
        r_32 = p_32 * sqrt{(1 - p_21^2)*(1 - p_31^2)} + p_21 * p_31 =: I
    while signed stick breaking gives
        l_21 = p_21, l_31 = p_31, l_22 = sqrt(1 - p_21^2),
        l_32 = p_32 * sqrt(1 - p_31^2)
    so that
        r_32 = l_21 * l_31 + l_22 * l_32
             = p_21 * p_31 + p_32 * sqrt{(1 - p_21^2)*(1 - p_31^2)} = I
    """
    beta_draw = self._beta.sample(key, size)
    # affine map (0, 1) -> (-1, 1)
    return signed_stick_breaking_tril(2 * beta_draw - 1)
示例7: _onion
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _onion(self, key, size):
    """Draw Cholesky factors of correlation matrices with the onion method.

    Follows Algorithm 3.2 of [1]. The lower-triangular part of the factor is
    ``sqrt(beta_sample)`` times a direction drawn uniformly on the
    hypersphere. The diagonal is then chosen so that every row has unit norm.

    Args:
        key: JAX PRNG key; split into one key for the beta draw and one for
            the normal draw.
        size: leading sample shape (a tuple, since it is concatenated with
            ``self.batch_shape``).

    Returns:
        Array of shape ``size + self.batch_shape + self.event_shape``.
    """
    key_beta, key_normal = random.split(key)
    # w term in Algorithm 3.2 of [1]: radius sqrt(beta) times a direction.
    beta_sample = self._beta.sample(key_beta, size)
    # Normalizing Gaussian vectors gives a uniform distribution on a
    # hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
    normal_sample = random.normal(
        key_normal,
        shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,)
    )
    normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
    u_hypersphere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
    w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypersphere

    # Put w below the diagonal (rows 1.., columns ..-1).
    # `.at[...].add` replaces `jax.ops.index_add`, which was removed from JAX.
    cholesky = jnp.zeros(size + self.batch_shape + self.event_shape).at[..., 1:, :-1].add(w)
    # Correct the diagonal so each row has unit norm. Rounding can push
    # 1 - sum(row ** 2) slightly below zero, so floor it at 0 (the same as
    # the former `jnp.clip(..., a_min=0.)`, whose keyword is deprecated).
    diag = jnp.sqrt(jnp.maximum(1 - jnp.sum(cholesky ** 2, axis=-1), 0.))
    cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
    return cholesky
示例8: clip_eta
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def clip_eta(eta, norm, eps):
    """
    Helper function to clip the perturbation to epsilon norm ball.
    :param eta: A tensor with the current perturbation. Axis 0 is kept
                separate; for norm=2 the norm is taken over all other axes.
    :param norm: Order of the norm (mimics Numpy).
                 Possible values: np.inf or 2.
    :param eps: Epsilon, bound of the perturbation.
    :return: The clipped perturbation, same shape as eta.
    :raises ValueError: if norm is neither np.inf nor 2.
    """
    if norm not in [np.inf, 2]:
        raise ValueError('norm must be np.inf or 2.')

    if norm == np.inf:
        # Elementwise clip to [-eps, eps]. Bounds are positional because
        # jax.numpy.clip deprecated the a_min/a_max keywords.
        return np.clip(eta, -eps, eps)

    # norm == 2: one L2 norm per slice along axis 0.
    reduce_axes = tuple(range(1, np.ndim(eta)))
    avoid_zero_div = 1e-12
    # avoid_zero_div must go inside sqrt to avoid a divide by zero in the
    # gradient through this operation.
    l2_norm = np.sqrt(np.maximum(avoid_zero_div,
                                 np.sum(np.square(eta), axis=reduce_axes, keepdims=True)))
    # We must *clip* to within the norm ball, not *normalize* onto the
    # surface of the ball, hence the factor is capped at 1.
    factor = np.minimum(1., np.divide(eps, l2_norm))
    return eta * factor
示例9: _cholesky
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _cholesky(x):
    """
    Like :func:`numpy.linalg.cholesky` but uses sqrt for scalar matrices.

    When the trailing dimension of ``x`` has size 1, the elementwise square
    root is returned instead of calling ``np.linalg.cholesky``.
    """
    if x.shape[-1] != 1:
        return np.linalg.cholesky(x)
    return np.sqrt(x)
示例10: gaussian_sample
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def gaussian_sample(key, mu, sigmasq):
    """Sample a diagonal Gaussian.

    Returns ``mu + sqrt(sigmasq) * eps``, where ``eps`` is standard-normal
    noise of shape ``mu.shape`` drawn with ``key``.
    """
    std = np.sqrt(sigmasq)
    noise = random.normal(key, mu.shape)
    return mu + std * noise
示例11: _l2_normalize
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def _l2_normalize(arr, axis):
    """Divide ``arr`` by its L2 norm over ``axis``; the norm keeps its dims so it broadcasts.

    No epsilon is added, so an all-zero slice produces NaN (0 / 0).
    """
    squared_norm = jnp.sum(arr ** 2, axis=axis, keepdims=True)
    return arr / jnp.sqrt(squared_norm)
示例12: sqrt
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def sqrt(self, x):
    """Return the elementwise square root of ``x`` via ``np.sqrt``."""
    root = np.sqrt(x)
    return root
示例13: wavelet_tensors
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def wavelet_tensors(request):
    """Returns the Hamiltonian and MERA tensors for the D=2 wavelet MERA.

    From Evenbly & White, Phys. Rev. Lett. 116, 140403 (2016).

    ``request`` is not used in the body (presumably a pytest fixture
    request; confirm).

    Returns:
        ``(h, w, u)``, each cast to complex128. ``h`` comes from
        ``simple_mera.ham_ising()``. ``w`` has shape (2, 2, 2) and ``u`` has
        shape (2, 2, 2, 2).
    """
    D = 2
    h = simple_mera.ham_ising()
    # identity and the Pauli matrices
    E = np.array([[1, 0], [0, 1]])
    X = np.array([[0, 1], [1, 0]])
    Y = np.array([[0, -1j], [1j, 0]])
    Z = np.array([[1, 0], [0, -1]])
    root2 = np.sqrt(2)
    root3 = np.sqrt(3)

    wmat_un = np.real(
        (root3 + root2) / 4 * np.kron(E, E)
        + (root3 - root2) / 4 * np.kron(Z, Z)
        + 1.j * (1 + root2) / 4 * np.kron(X, Y)
        + 1.j * (1 - root2) / 4 * np.kron(Y, X)
    )
    umat = np.real(
        (root3 + 2) / 4 * np.kron(E, E)
        + (root3 - 2) / 4 * np.kron(Z, Z)
        + 1.j / 4 * np.kron(X, Y)
        + 1.j / 4 * np.kron(Y, X)
    )

    four_legs = (D, D, D, D)
    # w keeps only index 0 of its second leg, then its legs are reordered
    w = np.transpose(np.reshape(wmat_un, four_legs)[:, 0, :, :], [1, 2, 0])
    u = np.transpose(np.reshape(umat, four_legs), [2, 3, 0, 1])
    return (h.astype(np.complex128),
            w.astype(np.complex128),
            u.astype(np.complex128))
示例14: sqrt
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def sqrt(self, tensor_in):
    """Elementwise square root of ``tensor_in``, computed with ``np.sqrt``."""
    result = np.sqrt(tensor_in)
    return result
示例15: normal_logpdf
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sqrt [as 别名]
def normal_logpdf(self, x, mu, sigma):
    """Log-density of Normal(mu, sigma) evaluated at ``x``.

    Evaluates the closed form
        -log(sigma * sqrt(2 * pi)) - ((x - mu) / (sqrt(2) * sigma)) ** 2
    directly. Per the original note, this is much faster than
    ``norm.logpdf(x, loc=mu, scale=sigma)`` (presumably scipy.stats.norm):
    https://codereview.stackexchange.com/questions/69718/fastest-computation-of-n-likelihoods-on-normal-distributions
    """
    sqrt_two = np.sqrt(2)
    sqrt_two_pi = np.sqrt(2 * np.pi)
    log_norm_const = -np.log(sigma * sqrt_two_pi)
    scaled_residual = np.divide(x - mu, sqrt_two * sigma)
    return log_norm_const - np.square(scaled_residual)