This page collects typical usage examples of Python's jax.random.bernoulli method. If you are wondering what random.bernoulli does, how to call it, or where to find worked examples, the curated code samples below may help. You can also explore the jax.random module that the method belongs to.
The following presents 12 code examples of random.bernoulli, drawn from open-source projects and ordered by popularity by default.
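Before the examples, a minimal sketch of the call itself: given a PRNG key, a probability p, and an optional shape, jax.random.bernoulli returns a boolean array.

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
mask = random.bernoulli(key, p=0.3, shape=(2, 4))  # boolean array, each entry True w.p. 0.3
print(mask.astype(jnp.float32))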
Example 1: bernoulli
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def bernoulli(self, *args, **kwargs):
    return backend()["random_bernoulli"](*args, **kwargs)
Example 2: Dropout
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def Dropout(rate, test_mode=False):
    """Constructor for a dropout function with the given rate."""
    rate = np.array(rate)

    @parametrized
    def dropout(inputs):
        if test_mode or rate == 0:
            return inputs
        keep_rate = 1 - rate
        keep = random.bernoulli(random_key(), keep_rate, inputs.shape)
        return np.where(keep, inputs / keep_rate, 0)

    return dropout
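For comparison, a minimal standalone sketch of the same inverted-dropout logic in plain JAX, without jaxnet's @parametrized/random_key machinery (dropout_apply is a hypothetical name):

import jax.numpy as jnp
from jax import random

def dropout_apply(key, inputs, rate):
    # Keep each unit with probability 1 - rate and rescale survivors,
    # so the expected activation matches the no-dropout case.
    keep_rate = 1 - rate
    keep = random.bernoulli(key, keep_rate, inputs.shape)
    return jnp.where(keep, inputs / keep_rate, 0)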
Example 3: image_sample_grid
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def image_sample_grid(nrow=10, ncol=10):
    """Sample images from the generative model."""
    logits = decode(random.normal(random_key(), (nrow * ncol, 10)))
    sampled_images = random.bernoulli(random_key(), np.logaddexp(0., logits))
    return image_grid(nrow, ncol, sampled_images, (28, 28))
Example 4: evaluate
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def evaluate(images):
    binarized_test = random.bernoulli(random_key(), images)
    return loss(binarized_test), image_sample_grid()
Example 5: dropout
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def dropout(self, x, p, seed=None):
    # Use an explicit seed when one is given; otherwise draw the next key
    # from the backend's key stream.
    if seed is None:
        seed = next(self.rng)
    keep_rate = 1 - p
    keep = random.bernoulli(seed, keep_rate, x.shape)
    # Inverted dropout: rescale surviving units so the expected value is unchanged.
    return np.where(keep, x / keep_rate, 0)
Example 6: binarize
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def binarize(rng_key, batch):
    return random.bernoulli(rng_key, batch).astype(batch.dtype)
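A quick check of binarize on a hypothetical toy batch: pixels in [0, 1] are treated as Bernoulli probabilities and come back as 0./1. values in the batch's dtype.

import jax.numpy as jnp
from jax import random

batch = jnp.array([[0.1, 0.9], [0.5, 0.0]], dtype=jnp.float32)
print(binarize(random.PRNGKey(0), batch))  # e.g. [[0. 1.] [1. 0.]]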
Example 7: _combine_tree
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right,
                  rng_key, biased_transition):
    # Now we combine the current tree and the new tree. Note that outside
    # leaves of the combined tree are determined by the direction.
    z_left, r_left, z_left_grad, z_right, r_right, r_right_grad = cond(
        going_right,
        (current_tree, new_tree),
        lambda trees: (trees[0].z_left, trees[0].r_left, trees[0].z_left_grad,
                       trees[1].z_right, trees[1].r_right, trees[1].z_right_grad),
        (new_tree, current_tree),
        lambda trees: (trees[0].z_left, trees[0].r_left, trees[0].z_left_grad,
                       trees[1].z_right, trees[1].r_right, trees[1].z_right_grad)
    )
    r_sum = tree_multimap(jnp.add, current_tree.r_sum, new_tree.r_sum)

    if biased_transition:
        transition_prob = _biased_transition_kernel(current_tree, new_tree)
        turning = new_tree.turning | _is_turning(inverse_mass_matrix, r_left, r_right, r_sum)
    else:
        transition_prob = _uniform_transition_kernel(current_tree, new_tree)
        turning = current_tree.turning

    transition = random.bernoulli(rng_key, transition_prob)
    z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy = cond(
        transition,
        new_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe,
                                tree.z_proposal_grad, tree.z_proposal_energy),
        current_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe,
                                    tree.z_proposal_grad, tree.z_proposal_energy)
    )

    tree_depth = current_tree.depth + 1
    tree_weight = jnp.logaddexp(current_tree.weight, new_tree.weight)
    diverging = new_tree.diverging
    sum_accept_probs = current_tree.sum_accept_probs + new_tree.sum_accept_probs
    num_proposals = current_tree.num_proposals + new_tree.num_proposals

    return TreeInfo(z_left, r_left, z_left_grad, z_right, r_right, r_right_grad,
                    z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy,
                    tree_depth, tree_weight, r_sum, turning, diverging,
                    sum_accept_probs, num_proposals)
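The helper kernels _uniform_transition_kernel and _biased_transition_kernel are not shown in this snippet. As a hedged sketch of what they compute (progressive sampling over log-space subtree weights, in the spirit of Betancourt's conceptual introduction, not necessarily the library's exact code):

import jax.numpy as jnp

def _uniform_transition_kernel(current_tree, new_tree):
    # Transition to the new subtree with probability proportional to its weight.
    return jnp.exp(new_tree.weight - jnp.logaddexp(current_tree.weight, new_tree.weight))

def _biased_transition_kernel(current_tree, new_tree):
    # Bias toward the new subtree, but never move into a turning/diverging one.
    prob = jnp.minimum(jnp.exp(new_tree.weight - current_tree.weight), 1.0)
    return jnp.where(new_tree.turning | new_tree.diverging, 0.0, prob)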
Example 8: sample
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def sample(self, key, sample_shape=()):
    return random.bernoulli(key, self.probs, shape=sample_shape + self.batch_shape)
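A quick usage sketch, assuming this sample method belongs to numpyro's Bernoulli distribution:

from jax import random
import numpyro.distributions as dist

d = dist.Bernoulli(probs=0.7)
draws = d.sample(random.PRNGKey(0), sample_shape=(5,))  # five 0/1 samples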
Example 9: gen_values_within_bounds
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    eps = 1e-6
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x - random.normal(key, size[:-1])
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
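As a usage sketch, assuming numpyro's public constraint singletons (e.g. constraints.boolean), the Boolean branch above would be exercised like this:

from numpyro.distributions import constraints

vals = gen_values_within_bounds(constraints.boolean, size=(3,))  # boolean array of shape (3,)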
Example 10: gen_values_outside_bounds
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - jnp.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, jnp.nan)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound,
                           shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = 1e-2 + signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._OrderedVector):
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x[..., ::-1]
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example 11: main
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def main():
    step_size = 0.001
    num_epochs = 100
    batch_size = 32
    test_key = PRNGKey(1)  # get reconstructions for a *fixed* latent variable sample over time

    train_images, test_images = mnist_images()
    num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
    num_batches = num_complete_batches + bool(leftover)
    opt = optimizers.Momentum(step_size, mass=0.9)

    @jit
    def binarize_batch(key, i, images):
        i = i % num_batches
        batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
        return random.bernoulli(key, batch)

    @jit
    def run_epoch(key, state):
        def body_fun(i, state):
            loss_key, data_key = random.split(random.fold_in(key, i))
            batch = binarize_batch(data_key, i, train_images)
            return opt.update(loss.apply, state, batch, key=loss_key)

        return lax.fori_loop(0, num_batches, body_fun, state)

    example_key = PRNGKey(0)
    example_batch = binarize_batch(example_key, 0, images=train_images)
    shaped_elbo = loss.shaped(example_batch)
    init_parameters = shaped_elbo.init_parameters(key=PRNGKey(2))
    state = opt.init(init_parameters)

    for epoch in range(num_epochs):
        tic = time.time()
        state = run_epoch(PRNGKey(epoch), state)
        params = opt.get_parameters(state)
        test_elbo, samples = evaluate.apply_from({shaped_elbo: params}, test_images,
                                                 key=test_key, jit=True)
        print(f'Epoch {epoch: 3d} {test_elbo:.3f} ({time.time() - tic:.3f} sec)')
        from matplotlib import pyplot as plt
        plt.imshow(samples, cmap=plt.cm.gray)
        plt.show()
Example 12: build_tree
# Required import: from jax import random [as alias]
# Or: from jax.random import bernoulli [as alias]
def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size,
               rng_key, max_delta_energy=1000., max_tree_depth=10):
    """
    Builds a binary tree from the `verlet_state`. This is used in the NUTS sampler.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman
    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param verlet_update: A callable to get a new integrator state given a current
        integrator state.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param verlet_state: Initial integrator state.
    :param inverse_mass_matrix: Inverse of the mass matrix.
    :param float step_size: Step size for the current trajectory.
    :param jax.random.PRNGKey rng_key: random key to be used as the source of
        randomness.
    :param float max_delta_energy: A threshold to decide if the new state diverges
        (based on the energy difference) too much from the initial integrator state.
    :param int max_tree_depth: Maximum depth of the binary tree created during the
        doubling scheme.
    :return: information of the tree.
    :rtype: :data:`TreeInfo`
    """
    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)
    r_ckpts = jnp.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]))
    r_sum_ckpts = jnp.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]))

    tree = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad, energy_current,
                    depth=0, weight=0., r_sum=r, turning=False, diverging=False,
                    sum_accept_probs=0., num_proposals=0)

    def _cond_fn(state):
        tree, _ = state
        return (tree.depth < max_tree_depth) & ~tree.turning & ~tree.diverging

    def _body_fn(state):
        tree, key = state
        key, direction_key, doubling_key = random.split(key, 3)
        going_right = random.bernoulli(direction_key)
        tree = _double_tree(tree, verlet_update, kinetic_fn, inverse_mass_matrix, step_size,
                            going_right, doubling_key, energy_current, max_delta_energy,
                            r_ckpts, r_sum_ckpts)
        return tree, key

    state = (tree, rng_key)
    tree, _ = while_loop(_cond_fn, _body_fn, state)
    return tree
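Note the unadorned random.bernoulli(direction_key) in _body_fn: with no p argument it defaults to p=0.5, i.e. a fair coin deciding whether the trajectory doubles to the right or to the left. A minimal standalone check:

from jax import random

going_right = random.bernoulli(random.PRNGKey(0))  # scalar bool, True with probability 0.5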