当前位置: 首页>>代码示例>>Python>>正文


Python random.bernoulli方法代码示例

本文整理汇总了Python中jax.random.bernoulli方法的典型用法代码示例。如果您正苦于以下问题：Python random.bernoulli方法的具体用法？Python random.bernoulli怎么用？Python random.bernoulli使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.random的用法示例。


在下文中一共展示了random.bernoulli方法的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: bernoulli

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def bernoulli(self, *args, **kwargs):
    """Draw Bernoulli samples through the currently active backend.

    Fetches the ``"random_bernoulli"`` entry from the mapping returned by
    ``backend()`` and forwards every positional and keyword argument to it
    unchanged. ``self`` is not used.
    """
    sampler = backend()["random_bernoulli"]
    return sampler(*args, **kwargs)
开发者ID:yyht,项目名称:BERT,代码行数:4,代码来源:backend.py

示例2: Dropout

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def Dropout(rate, test_mode=False):
    """Build a dropout layer that drops each input element with probability ``rate``.

    :param rate: fraction of elements to zero out; converted with ``np.array``.
    :param test_mode: when true, the returned layer is the identity.
    :return: a ``parametrized`` function mapping ``inputs`` to the dropped-out array.
    """
    rate = np.array(rate)

    @parametrized
    def dropout(inputs):
        # Only mask when training and something can actually be dropped.
        if not test_mode and not rate == 0:
            keep_prob = 1 - rate
            mask = random.bernoulli(random_key(), keep_prob, inputs.shape)
            # Inverted dropout: rescale survivors so the expectation is unchanged.
            return np.where(mask, inputs / keep_prob, 0)
        return inputs

    return dropout
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:16,代码来源:modules.py

示例3: image_sample_grid

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def image_sample_grid(nrow=10, ncol=10):
    """Sample binary images from the generative model and tile them into a grid.

    Draws ``nrow * ncol`` latent codes of size 10 from a standard normal,
    decodes them, and samples each pixel from a Bernoulli distribution whose
    logit is the decoder output.

    :param nrow: number of rows in the grid.
    :param ncol: number of columns in the grid.
    :return: the result of ``image_grid`` over the samples, using 28x28 tiles.
    """
    logits = decode(random.normal(random_key(), (nrow * ncol, 10)))
    # ``random.bernoulli`` expects a probability in [0, 1]. softplus(logits),
    # i.e. logaddexp(0, logits), is not one: it exceeds 1 once logits > ~0.54.
    # The Bernoulli mean for a logit is sigmoid(logits). exp(-softplus(-x)) is
    # the overflow-safe way to write it.
    probs = np.exp(-np.logaddexp(0., -logits))
    sampled_images = random.bernoulli(random_key(), probs)
    return image_grid(nrow, ncol, sampled_images, (28, 28))
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:7,代码来源:mnist_vae.py

示例4: evaluate

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def evaluate(images):
    """Evaluate the model on stochastically binarized ``images``.

    Each pixel is replaced by a Bernoulli draw whose probability is the pixel
    value (presumably intensities in [0, 1] — confirm with the data loader).

    :return: a pair ``(loss on the binarized images, image_sample_grid())``.
    """
    sample_key = random_key()
    binarized_test = random.bernoulli(sample_key, images)
    test_loss = loss(binarized_test)
    return test_loss, image_sample_grid()
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:5,代码来源:mnist_vae.py

示例5: dropout

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def dropout(self, x, p, seed=None):
        """Inverted dropout: zero each element of ``x`` with probability ``p``.

        Surviving elements are scaled by ``1 / (1 - p)``, so the expected value
        of the output equals ``x``.

        :param x: input array; the keep-mask is drawn with shape ``x.shape``.
        :param p: drop probability.
        :param seed: optional PRNG key. When it is ``None`` the next key is
            taken from ``self.rng``.
        :return: array shaped like ``x`` with dropped entries set to 0.
        """
        # Honor an explicitly supplied key; only advance ``self.rng`` when the
        # caller did not provide one.
        if seed is None:
            seed = next(self.rng)
        keep_prob = 1 - p
        keep = random.bernoulli(seed, keep_prob, x.shape)
        return np.where(keep, x / keep_prob, 0)
开发者ID:sharadmv,项目名称:deepx,代码行数:7,代码来源:jax.py

示例6: binarize

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def binarize(rng_key, batch):
    """Stochastically binarize ``batch``, treating each value as a Bernoulli probability.

    The boolean draws are cast back to ``batch.dtype``, so the result has the
    same shape and dtype as the input.
    """
    target_dtype = batch.dtype
    draws = random.bernoulli(rng_key, batch)
    return draws.astype(target_dtype)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:4,代码来源:vae.py

示例7: _combine_tree

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right, rng_key, biased_transition):
    """Merge ``new_tree`` with ``current_tree`` into a tree one level deeper.

    The outer leaves of the merged tree depend on ``going_right``. When it is
    true, the left edge comes from ``current_tree`` and the right edge from
    ``new_tree``; otherwise the roles are swapped.

    Momentum sums are added leafwise. The proposal is taken from ``new_tree``
    with a probability given by the biased or uniform transition kernel. The
    log-space weights are combined with ``logaddexp``.

    The turning flag depends on the kernel. With the biased kernel it is
    ``new_tree.turning`` OR a U-turn check over the merged tree. With the
    uniform kernel it is ``current_tree.turning``. ``diverging`` is taken from
    ``new_tree``.

    :return: a new ``TreeInfo`` describing the merged tree.
    """
    def _outer_leaves(trees):
        # ``trees`` is (left_tree, right_tree): left edge from the first,
        # right edge from the second.
        left, right = trees[0], trees[1]
        return (left.z_left, left.r_left, left.z_left_grad,
                right.z_right, right.r_right, right.z_right_grad)

    def _proposal_of(tree):
        return (tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad, tree.z_proposal_energy)

    # ``cond`` is called in its five-argument form:
    # (pred, true_operand, true_fun, false_operand, false_fun).
    z_left, r_left, z_left_grad, z_right, r_right, z_right_grad = cond(
        going_right,
        (current_tree, new_tree), _outer_leaves,
        (new_tree, current_tree), _outer_leaves,
    )
    r_sum = tree_multimap(jnp.add, current_tree.r_sum, new_tree.r_sum)

    if biased_transition:
        transition_prob = _biased_transition_kernel(current_tree, new_tree)
        turning = new_tree.turning | _is_turning(inverse_mass_matrix, r_left, r_right, r_sum)
    else:
        transition_prob = _uniform_transition_kernel(current_tree, new_tree)
        turning = current_tree.turning

    take_new_proposal = random.bernoulli(rng_key, transition_prob)
    z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy = cond(
        take_new_proposal,
        new_tree, _proposal_of,
        current_tree, _proposal_of,
    )

    return TreeInfo(
        z_left, r_left, z_left_grad, z_right, r_right, z_right_grad,
        z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy,
        current_tree.depth + 1,
        jnp.logaddexp(current_tree.weight, new_tree.weight),
        r_sum,
        turning,
        new_tree.diverging,
        current_tree.sum_accept_probs + new_tree.sum_accept_probs,
        current_tree.num_proposals + new_tree.num_proposals,
    )
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:43,代码来源:hmc_util.py

示例8: sample

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def sample(self, key, sample_shape=()):
        """Draw boolean Bernoulli samples with success probability ``self.probs``.

        :param key: PRNG key.
        :param sample_shape: leading shape prepended to ``self.batch_shape``.
        :return: boolean array of shape ``sample_shape + self.batch_shape``.
        """
        out_shape = sample_shape + self.batch_shape
        return random.bernoulli(key, self.probs, shape=out_shape)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:4,代码来源:discrete.py

示例9: gen_values_within_bounds

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Draw a test sample of shape ``size`` that satisfies ``constraint``.

    Dispatches on the concrete constraint class and builds values inside that
    constraint's support.

    :param constraint: constraint instance to satisfy.
    :param tuple size: shape of the returned sample; for vector/matrix
        constraints the trailing dimension(s) carry the event structure.
    :param key: PRNG key; the same key is reused for every draw in a branch.
    :raises NotImplementedError: if the constraint type is not handled.
    """
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(.) > 0, and eps guards against landing exactly on the bound.
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        # randint's upper limit is exclusive, so +1 makes upper_bound reachable.
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        # NOTE: ``osp.dirichlet.rvs`` does not receive ``key``, so this branch is
        # not controlled by it.
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        # A cumulative sum of non-negative increments is ordered along the last
        # axis. The random shift must be constant along that axis, so it gets a
        # trailing singleton dim. A shift shaped ``size[:-1]`` would align with
        # the trailing axes instead: it fails to broadcast for e.g. (2, 3) and,
        # for square sizes, subtracts per-column offsets that break the ordering.
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x - random.normal(key, size[:-1] + (1,))
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:43,代码来源:test_distributions.py

示例10: gen_values_outside_bounds

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Draw a test sample of shape ``size`` that violates ``constraint``.

    Dispatches on the concrete constraint class and builds values outside that
    constraint's support.

    :param constraint: constraint instance to violate.
    :param tuple size: shape of the returned sample.
    :param key: PRNG key; the same key is reused for every draw in a branch.
    :raises NotImplementedError: if the constraint type is not handled.
    """
    if isinstance(constraint, constraints._Boolean):
        # bernoulli yields 0/1; shifting by -2 gives -2/-1, never a valid boolean.
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - jnp.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        # randint's upper limit is exclusive, so this always yields lower_bound - 1.
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        # A Poisson draw can be 0, and the within-bounds generator treats
        # ``lower_bound`` itself as valid. Subtract an extra 1 so every value is
        # strictly below the bound.
        return constraint.lower_bound - 1 - random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, jnp.nan)
    elif isinstance(constraint, constraints._Simplex):
        # Shifting every component by 1e-2 breaks the sum-to-one property.
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = 1e-2 + signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._OrderedVector):
        # Reverse an increasing cumulative sum along the last axis.
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x[..., ::-1]
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:39,代码来源:test_distributions.py

示例11: main

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def main():
    """Train the MNIST VAE and display sampled images after every epoch.

    Each epoch applies one Momentum update per batch of stochastically
    binarized training images. It then evaluates on the test images with a
    fixed key, prints the test value and the epoch time, and shows the sampled
    image grid with matplotlib.
    """
    # Imported up front so a missing or broken matplotlib install fails
    # immediately instead of after the first full training epoch.
    from matplotlib import pyplot as plt

    step_size = 0.001
    num_epochs = 100
    batch_size = 32
    test_key = PRNGKey(1)  # get reconstructions for a *fixed* latent variable sample over time

    train_images, test_images = mnist_images()
    num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
    # Count a trailing partial batch as one extra batch.
    num_batches = num_complete_batches + bool(leftover)
    opt = optimizers.Momentum(step_size, mass=0.9)

    @jit
    def binarize_batch(key, i, images):
        """Take batch ``i`` (wrapped modulo ``num_batches``) and binarize it."""
        i = i % num_batches
        # NOTE(review): lax.dynamic_slice clamps the start index to keep the
        # slice in bounds, so the trailing partial batch presumably overlaps the
        # previous one rather than being shorter — confirm if exact coverage matters.
        batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
        return random.bernoulli(key, batch)

    @jit
    def run_epoch(key, state):
        """Apply one optimizer update per batch inside a ``fori_loop``."""
        def body_fun(i, state):
            # Independent per-step keys for the loss and for binarization.
            loss_key, data_key = random.split(random.fold_in(key, i))
            batch = binarize_batch(data_key, i, train_images)
            return opt.update(loss.apply, state, batch, key=loss_key)

        return lax.fori_loop(0, num_batches, body_fun, state)

    # An example batch is needed by ``loss.shaped`` before parameters can be initialized.
    example_key = PRNGKey(0)
    example_batch = binarize_batch(example_key, 0, images=train_images)
    shaped_elbo = loss.shaped(example_batch)
    init_parameters = shaped_elbo.init_parameters(key=PRNGKey(2))
    state = opt.init(init_parameters)

    for epoch in range(num_epochs):
        tic = time.time()
        state = run_epoch(PRNGKey(epoch), state)
        params = opt.get_parameters(state)
        test_elbo, samples = evaluate.apply_from({shaped_elbo: params}, test_images, key=test_key,
                                                 jit=True)
        print(f'Epoch {epoch: 3d} {test_elbo:.3f} ({time.time() - tic:.3f} sec)')
        plt.imshow(samples, cmap=plt.cm.gray)
        plt.show()
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:44,代码来源:mnist_vae.py

示例12: build_tree

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import bernoulli [as 别名]
def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key,
               max_delta_energy=1000., max_tree_depth=10):
    """
    Grow a NUTS binary tree by repeated doubling, starting from ``verlet_state``.

    Starting from a depth-0 tree at the initial state, each iteration picks a
    direction with a fair coin and doubles the tree via ``_double_tree``.
    Doubling continues while the depth is below ``max_tree_depth`` and the tree
    is neither turning nor diverging.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman
    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param verlet_update: callable returning the next integrator state from the current one.
    :param kinetic_fn: callable computing kinetic energy from
        ``(inverse_mass_matrix, r)``.
    :param verlet_state: initial integrator state ``(z, r, potential_energy, z_grad)``.
    :param inverse_mass_matrix: inverse of the mass matrix; its last dimension
        sizes the momentum checkpoint buffers.
    :param float step_size: step size for the current trajectory.
    :param jax.random.PRNGKey rng_key: source of randomness for the tree building.
    :param float max_delta_energy: energy-difference threshold, relative to the
        initial state, beyond which a new state counts as diverging.
    :param int max_tree_depth: maximum number of doublings; also the number of
        rows in the momentum checkpoint buffers.
    :return: information of the tree.
    :rtype: :data:`TreeInfo`
    """
    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)
    ckpt_shape = (max_tree_depth, inverse_mass_matrix.shape[-1])
    r_ckpts = jnp.zeros(ckpt_shape)
    r_sum_ckpts = jnp.zeros(ckpt_shape)

    # Depth-0 tree: every leaf and the proposal are the initial state.
    initial_tree = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad, energy_current,
                            depth=0, weight=0., r_sum=r, turning=False, diverging=False,
                            sum_accept_probs=0., num_proposals=0)

    def _keep_doubling(loop_state):
        current, _ = loop_state
        return (current.depth < max_tree_depth) & ~current.turning & ~current.diverging

    def _double_once(loop_state):
        current, key = loop_state
        key, direction_key, doubling_key = random.split(key, 3)
        going_right = random.bernoulli(direction_key)
        current = _double_tree(current, verlet_update, kinetic_fn, inverse_mass_matrix, step_size,
                               going_right, doubling_key, energy_current, max_delta_energy,
                               r_ckpts, r_sum_ckpts)
        return current, key

    final_tree, _ = while_loop(_keep_doubling, _double_once, (initial_tree, rng_key))
    return final_tree
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:52,代码来源:hmc_util.py


注:本文中的jax.random.bernoulli方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。