本文整理汇总了 Python 中 jax.random.uniform 方法的典型用法代码示例。如果您正想了解以下问题:random.uniform 方法的具体用法是什么?random.uniform 怎么用?有哪些 random.uniform 的使用示例?那么这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 jax.random 的用法示例。
在下文中一共展示了random.uniform方法的14个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: jax_randint
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def jax_randint(key, shape, minval, maxval, dtype=np.int32):
    """Draw uniformly distributed random integers from ``[minval, maxval)``.

    Thin wrapper that forwards every argument to ``jax_random.randint``.

    Args:
        key: PRNGKey used as the source of randomness.
        shape: tuple of nonnegative integers giving the shape of the result.
        minval: int, or array of ints broadcast-compatible with ``shape``;
            inclusive lower bound of the range.
        maxval: int, or array of ints broadcast-compatible with ``shape``;
            exclusive upper bound of the range.
        dtype: optional integer dtype of the returned values (default int32).

    Returns:
        A random array with the requested shape and dtype.
    """
    sampled = jax_random.randint(
        key, shape, minval=minval, maxval=maxval, dtype=dtype)
    return sampled
示例2: _binomial_inversion
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def _binomial_inversion(key, p, n):
    """Sample a count by accumulating geometric waiting times.

    Each loop iteration draws one uniform variate ``u``, converts it to a
    geometric variate ``floor(log1p(-u) / log1p(-p)) + 1`` and adds it to a
    running total. The loop keeps going while that total is ``<= n``.

    Args:
        key: PRNGKey; split once per loop iteration.
        p: success probability; only used through ``log1p(-p)``.
        n: bound the accumulated geometric variates are compared against.

    Returns:
        The iteration counter carried through the loop. It starts at -1 and
        is incremented once per iteration.
    """
    log_q = jnp.log1p(-p)  # loop invariant: log(1 - p)

    def keep_going(state):
        _, _, waited = state
        return waited <= n

    def step(state):
        count, rng, waited = state
        rng, rng_u = random.split(rng)
        u = random.uniform(rng_u)
        gap = jnp.floor(jnp.log1p(-u) / log_q) + 1
        return count + 1, rng, waited + gap

    # Loop carry is (iteration counter, key, accumulated waiting time).
    count, _, _ = lax.while_loop(keep_going, step, (-1, key, 0.))
    return count
示例3: _onion
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def _onion(self, key, size):
    """Build lower-triangular factors with unit-norm rows from random draws.

    Combines samples from ``self._beta`` with normalised Gaussian directions
    to fill the strictly-lower part of each matrix, then sets the diagonal so
    that every row has unit Euclidean norm. The original comments attribute
    the ``w`` term to "Algorithm 3.2 of [1]"; that reference is not visible
    in this snippet.

    Args:
        key: PRNGKey, split into one key for the beta draw and one for the
            normal draw.
        size: leading sample shape (a tuple, since it is concatenated with
            ``self.batch_shape``).

    Returns:
        An array of shape ``size + self.batch_shape + self.event_shape``.
    """
    rng_beta, rng_normal = random.split(key)

    # The "w" term of Algorithm 3.2 in [1] starts from a beta draw ...
    beta_draw = self._beta.sample(rng_beta, size)

    # ... and a direction that is uniform on a hypersphere, obtained by
    # normalising Gaussian draws
    # (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html).
    leading_shape = size + self.batch_shape
    n_offdiag = self.dimension * (self.dimension - 1) // 2
    gauss = random.normal(rng_normal, shape=leading_shape + (n_offdiag,))
    gauss = vec_to_tril_matrix(gauss, diagonal=0)
    row_norm = jnp.linalg.norm(gauss, axis=-1, keepdims=True)
    u_hypersphere = gauss / row_norm
    w = jnp.expand_dims(jnp.sqrt(beta_draw), axis=-1) * u_hypersphere

    # Place w below the diagonal: rows 1.., columns ..-1.
    blank = jnp.zeros(leading_shape + self.event_shape)
    cholesky = ops.index_add(blank, ops.index[..., 1:, :-1], w)

    # Fix up the diagonal so each row has unit norm.
    # NB: clip at zero because rounding can push 1 - sum slightly negative.
    row_sq_sum = jnp.sum(cholesky ** 2, axis=-1)
    diag = jnp.sqrt(jnp.clip(1 - row_sq_sum, a_min=0.))
    return cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
示例4: uniform
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def uniform(self, *args, **kwargs):
    """Forward all arguments to the active backend's ``"random_uniform"`` entry.

    Looks up ``"random_uniform"`` in the mapping returned by ``backend()`` and
    calls it with the given positional and keyword arguments unchanged.
    """
    random_uniform_fn = backend()["random_uniform"]
    return random_uniform_fn(*args, **kwargs)
示例5: random_inputs
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def random_inputs(input_shape, key=PRNGKey(0)):
    """Create uniform float32 inputs for a shape or a (nested) list of shapes.

    Args:
        input_shape: a tuple describing a single array shape, or a list whose
            elements are themselves tuples or lists of shapes.
        key: PRNGKey used to draw the values. For a list, the key is split so
            that every element is drawn from its own sub-key.

    Returns:
        An array of shape ``input_shape`` for a tuple, or a list with one
        entry per element of ``input_shape`` for a list.

    Raises:
        TypeError: if ``input_shape`` is neither a tuple nor a list.
    """
    if isinstance(input_shape, tuple):
        return random.uniform(key, input_shape, np.float32)
    if isinstance(input_shape, list):
        if not input_shape:
            return []
        # The recursive call must pass (shape, key) in that order; with the
        # arguments swapped the key is treated as a shape and every list
        # input ends in TypeError. Separate sub-keys keep equally shaped
        # entries from receiving identical values.
        subkeys = random.split(key, len(input_shape))
        return [random_inputs(shape, subkey)
                for shape, subkey in zip(input_shape, subkeys)]
    raise TypeError(type(input_shape))
示例6: test_rng_injection
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def test_rng_injection():
    """Check that a ``parametrized`` function can draw from ``random_key()``.

    Wraps a zero-argument function that samples one uniform value using the
    key returned by ``random_key()``, initialises its parameters, applies it
    with an explicit key and asserts that the result is a scalar.
    """
    @parametrized
    def rand():
        return random.uniform(random_key())

    params = rand.init_parameters(key=PRNGKey(0))
    sample = rand.apply(params, key=PRNGKey(0))
    assert sample.shape == ()
示例7: random_uniform
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def random_uniform(self, shape, minval=1.0, maxval=None, dtype=None, seed=None):
    """Sample uniformly from ``[minval, maxval)`` with the given shape.

    Args:
        shape: iterable of dimensions; converted to a list before sampling.
        minval: lower bound. When ``maxval`` is None this value is used as
            the upper bound instead and the lower bound becomes 0.0.
        maxval: upper bound, or None for the single-bound form above.
        dtype: dtype of the result; falls back to ``self.floatx()`` when
            falsy.
        seed: accepted but not used. The key is always taken from
            ``next(self.rng)``, whatever is passed here.

    Returns:
        Unit-interval uniform samples rescaled as
        ``samples * (maxval - minval) + minval``.
    """
    if not dtype:
        dtype = self.floatx()
    if maxval is None:
        low, high = 0.0, minval
    else:
        low, high = minval, maxval
    rng_key = next(self.rng)
    unit_samples = random.uniform(rng_key, list(shape), dtype=dtype)
    return unit_samples * (high - low) + low
示例8: _categorical
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def _categorical(key, p, shape):
    """Sample category indices by comparing a uniform draw with the CDF.

    The cumulative sum of ``p`` along the last axis is compared with one
    uniform draw per sample; the number of cumulative values strictly below
    the draw is returned as the index.

    Args:
        key: PRNGKey for the uniform draw.
        p: probabilities; categories lie along the last axis.
        shape: shape of the result; when falsy, ``p.shape[:-1]`` is used.

    Returns:
        An integer array of counts with shape ``shape``.
    """
    # This implementation is fast when the event shape is small, and slow
    # otherwise. Ref: https://stackoverflow.com/a/34190035
    if not shape:
        shape = p.shape[:-1]
    cdf = jnp.cumsum(p, axis=-1)
    # Trailing singleton axis lets each draw broadcast against the CDF.
    draws = random.uniform(key, shape=shape + (1,))
    # FIXME: replace this computation by a binary search as suggested in the
    # reference above. A while_loop + vmap over a reshaped 2D array would be
    # enough.
    return jnp.sum(cdf < draws, axis=-1)
示例9: sample
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def sample(self, key, sample_shape=()):
    """Draw samples with the inverse transform method.

    A uniform variate is drawn on ``[-arctan(self.base_loc), pi / 2)`` and
    mapped through ``tan``; the result is shifted by ``self.base_loc``.

    Args:
        key: PRNGKey for the uniform draw.
        sample_shape: leading shape, prepended to ``self.batch_shape``.

    Returns:
        An array of shape ``sample_shape + self.batch_shape``.
    """
    # Inverse transform method:
    #   z ~ inv_cdf(U), where U ~ Uniform(cdf(low), cdf(high))
    #                           ~ Uniform(arctan(low), arctan(high)) / pi + 1/2
    out_shape = sample_shape + self.batch_shape
    lo = -jnp.arctan(self.base_loc)
    hi = jnp.pi / 2
    angle = lo + random.uniform(key, shape=out_shape) * (hi - lo)
    return self.base_loc + jnp.tan(angle)
示例10: test_param
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def test_param():
    """Check SVI with constrained ``numpyro.param`` sites in model and guide.

    The model and the guide each declare two parameters with constraints
    (greater-than, positive, interval, unit interval). The test verifies that
    the initial values survive ``svi.init`` / ``svi.get_params`` and that the
    evaluated loss is finite and matches the value computed by hand.
    """
    key_a, key_b, key_c, key_d, key_obs = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval, c_maxval = -2, -1

    a_init = jnp.exp(random.normal(key_a)) + a_minval
    b_init = jnp.exp(random.normal(key_b))
    c_init = random.uniform(key_c, minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(key_d)
    obs = random.normal(key_obs)

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(c, d), obs=obs)

    adam = optim.Adam(0.01)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    # Parameters read back from the state must equal their initial values.
    params = svi.get_params(svi_state)
    for site, expected in (('a', a_init), ('b', b_init), ('c', c_init), ('d', d_init)):
        assert_allclose(params[site], expected)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)

    guide_log_prob = dist.Normal(c_init, d_init).log_prob(obs)
    model_log_prob = dist.Normal(a_init, b_init).log_prob(obs)
    expected_loss = guide_log_prob - model_log_prob
    # Loose tolerance: values go through transform / inverse-transform steps.
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
示例11: sample
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def sample(self, key, sample_shape=()):
    """Sample by drawing uniformly in unconstrained space and transforming.

    Values are drawn uniformly from ``[-2, 2)`` and pushed through
    ``biject_to(self.support)``.

    Args:
        key: PRNGKey for the uniform draw.
        sample_shape: leading shape, prepended to ``self.batch_shape``.

    Returns:
        The transformed samples.
    """
    to_support = biject_to(self.support)
    # Invert a zero-valued event only to learn the shape of one
    # unconstrained event.
    zero_event = jnp.zeros(self.event_shape)
    unconstrained_event_shape = jnp.shape(to_support.inv(zero_event))
    full_shape = sample_shape + self.batch_shape + unconstrained_event_shape
    raw = random.uniform(key, full_shape, minval=-2, maxval=2)
    return to_support(raw)
示例12: gen_values_within_bounds
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate random values of shape ``size`` for the given constraint type.

    Dispatches on the type of ``constraint`` and returns values built so that
    they fall inside the corresponding bounds.

    Args:
        constraint: a constraint object; its class selects the branch and,
            where needed, its ``lower_bound`` / ``upper_bound`` are read.
        size: tuple giving the shape of the values to generate.
        key: PRNGKey used by the JAX samplers. The simplex branch calls
            ``osp.dirichlet.rvs`` and does not receive ``key``.

    Returns:
        An array of generated values.

    Raises:
        NotImplementedError: if the constraint type is not handled.
    """
    eps = 1e-6
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(.) is positive; eps keeps the value strictly above the bound.
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        # randint's maxval is exclusive, hence the + 1.
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        # Cumulative sums of positive draws increase along the last axis.
        x = jnp.cumsum(random.exponential(key, size), -1)
        # Shift each vector by a single offset. The trailing singleton axis
        # makes the offset broadcast along the last dimension; an offset of
        # shape size[:-1] would either fail to broadcast for batched sizes or
        # line up with the wrong axis and could break the ordering.
        return x - random.normal(key, size[:-1] + (1,))
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
示例13: gen_values_outside_bounds
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate random values of shape ``size`` meant to violate a constraint.

    Dispatches on the type of ``constraint``; each branch perturbs or shifts
    a random draw so that it lands outside the corresponding bounds.

    Args:
        constraint: a constraint object; its class selects the branch and,
            where needed, its ``lower_bound`` / ``upper_bound`` are read.
        size: tuple giving the shape of the values to generate.
        key: PRNGKey used by the JAX samplers. The simplex branch calls
            ``osp.dirichlet.rvs`` and does not receive ``key``.

    Returns:
        An array of generated values.

    Raises:
        NotImplementedError: if the constraint type is not handled.
    """
    if isinstance(constraint, constraints._Boolean):
        # Bernoulli draws shifted down by 2.
        return random.bernoulli(key, shape=size) - 2

    if isinstance(constraint, constraints._GreaterThan):
        # Subtract a positive amount from the lower bound.
        return constraint.lower_bound - jnp.exp(random.normal(key, size))

    if isinstance(constraint, constraints._IntegerInterval):
        lower = jnp.broadcast_to(constraint.lower_bound, size)
        # Range [lower - 1, lower) excludes the lower bound itself.
        return random.randint(key, size, lower - 1, lower)

    if isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - random.poisson(key, np.array(5), shape=size)

    if isinstance(constraint, constraints._Interval):
        upper = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper, maxval=upper + 1.)

    if isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, jnp.nan)

    if isinstance(constraint, constraints._Simplex):
        # Adding a constant to every entry changes the row sums.
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2

    if isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]) + 1

    if isinstance(constraint, constraints._CorrCholesky):
        vec_shape = size[:-2] + (size[-1] * (size[-1] - 1) // 2,)
        return signed_stick_breaking_tril(
            random.uniform(key, vec_shape, minval=-1, maxval=1)) + 1e-2

    if isinstance(constraint, constraints._CorrMatrix):
        vec_shape = size[:-2] + (size[-1] * (size[-1] - 1) // 2,)
        cholesky = 1e-2 + signed_stick_breaking_tril(
            random.uniform(key, vec_shape, minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))

    if isinstance(constraint, constraints._LowerCholesky):
        # Full uniform matrix, upper triangle not zeroed.
        return random.uniform(key, size)

    if isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)

    if isinstance(constraint, constraints._OrderedVector):
        # Increasing cumulative sums, reversed along the last axis.
        increasing = jnp.cumsum(random.exponential(key, size), -1)
        return increasing[..., ::-1]

    raise NotImplementedError('{} not implemented.'.format(constraint))
示例14: _binomial_btrs
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import uniform [as 别名]
def _binomial_btrs(key, p, n):
    """Sample with the transformed rejection sampling algorithm (BTRS).

    Based on the following reference:

        Hormann, "The Generation of Binomial Random Variates"
        (https://core.ac.uk/download/pdf/11007254.pdf)

    Proposals are generated in a ``lax.while_loop`` until one is accepted.

    Args:
        key: PRNGKey; split into three on every proposal.
        p: success probability, passed to ``_get_tr_params``.
        n: number of trials; must expose ``.dtype``, which is used to cast
            the proposal.

    Returns:
        The accepted proposal ``k`` (first element of the loop carry).
    """
    tr_params = _get_tr_params(n, p)

    def propose(state):
        # Only the key is carried over from the previous proposal.
        _, rng, _, _ = state
        rng, rng_u, rng_v = random.split(rng, 3)
        u = random.uniform(rng_u)
        v = random.uniform(rng_v)
        u = u - 0.5
        k = jnp.floor(
            (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
        ).astype(n.dtype)
        return k, rng, u, v

    def accept_fn(k, u, v):
        # See acceptance condition in Step 3. (Page 3) of TRS algorithm:
        #   v <= f(k) * g_grad(u) / alpha
        m = tr_params.m
        log_p = tr_params.log_p
        log1_p = tr_params.log1_p
        # See: formula for log(f(k)) at bottom of Page 5.
        log_f = (n + 1.) * jnp.log((n - m + 1.) / (n - k + 1.)) + \
            (k + 0.5) * (jnp.log((n - k + 1.) / (k + 1.)) + log_p - log1_p) + \
            (stirling_approx_tail(k) - stirling_approx_tail(n - k)) + tr_params.log_h
        g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
        return jnp.log((v * tr_params.alpha) / g) <= log_f

    def should_continue(state):
        k, _, u, v = state
        early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
        early_reject = (k < 0) | (k > n)
        # When either shortcut fires, keep looping unless it was an early
        # accept; otherwise fall back to the full acceptance test.
        return lax.cond(early_accept | early_reject,
                        (),
                        lambda _: ~early_accept,
                        (k, u, v),
                        lambda x: ~accept_fn(*x))

    # Start from k = -1 so that the first condition check returns True.
    final_state = lax.while_loop(should_continue, propose, (-1, key, 1., 1.))
    return final_state[0]