当前位置: 首页>>代码示例>>Python>>正文


Python random.split方法代码示例

本文整理汇总了Python中jax.random.split方法的典型用法代码示例。如果您正苦于以下问题：Python jax.random.split方法的具体用法是什么？jax.random.split怎么用？有哪些jax.random.split的使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.random的其他用法示例。


在下文中一共展示了jax.random.split方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: GatedResnet

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
# 注意：本示例实际调用的是 jnp.split（按轴切分数组），并未使用 jax.random.split。
def GatedResnet(Conv=None, nonlinearity=concat_elu, dropout_p=0.):
    """Build a gated residual block.

    The returned ``@parametrized`` function computes
    ``inputs + value * sigmoid(gate)``. ``value`` and ``gate`` are the two halves,
    split along the last axis, of a second convolution. That convolution is
    applied to the (optionally aux-conditioned, optionally dropped-out) output
    of a first convolution.

    :param Conv: convolution layer factory, called as ``Conv(channels)`` and
        ``Conv(channels, init_scale=...)``.
    :param nonlinearity: activation applied before each layer.
    :param dropout_p: dropout rate; dropout is skipped unless it is positive.
    :return: the parametrized ``gated_resnet(inputs, aux=None)`` function.
    """
    @parametrized
    def gated_resnet(inputs, aux=None):
        channels = inputs.shape[-1]
        hidden = Conv(channels)(nonlinearity(inputs))
        if aux is not None:
            # condition on the auxiliary input through an NIN layer
            hidden = hidden + NIN(channels)(nonlinearity(aux))
        hidden = nonlinearity(hidden)
        if dropout_p > 0:
            hidden = Dropout(rate=dropout_p)(hidden)
        # The second conv doubles the channel count so the result can be cut
        # into a value half and a gate half.
        gated = Conv(2 * channels, init_scale=0.1)(hidden)
        value, gate = jnp.split(gated, 2, axis=-1)
        return inputs + value * sigmoid(gate)

    return gated_resnet
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:18,代码来源:pixelcnn.py

示例2: main

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def main(batch_size=32, nr_filters=8, epochs=10, step_size=.001, decay_rate=.999995,
         model_path=Path('./pixelcnn.params')):
    """Train the PixelCNNPP model, logging losses and checkpointing every epoch.

    :param batch_size: batch size passed to ``dataset``.
    :param nr_filters: number of filters of the PixelCNNPP model.
    :param epochs: number of passes over the training batches.
    :param step_size: initial Adam step size.
    :param decay_rate: decay rate given to ``exponential_decay`` for the step size.
    :param model_path: file the parameters are saved to after every epoch.
    """
    loss, _ = PixelCNNPP(nr_filters=nr_filters)
    get_train_batches, test_batches = dataset(batch_size)
    key, init_key = random.split(PRNGKey(0))
    opt = Adam(exponential_decay(step_size, 1, decay_rate))
    example_batch = next(test_batches)
    state = opt.init(loss.init_parameters(example_batch, key=init_key))

    def should_log(step):
        # report each of the first 10 steps, then every 100th step
        return step < 10 or step % 100 == 0

    for epoch in range(epochs):
        for batch in get_train_batches():
            key, update_key = random.split(key)
            i = opt.get_step(state)
            state, train_loss = opt.update_and_get_loss(
                loss.apply, state, batch, key=update_key, jit=True)
            if not should_log(i):
                continue

            key, test_key = random.split(key)
            params = opt.get_parameters(state)
            test_loss = loss.apply(params, next(test_batches), key=test_key, jit=True)
            print(f"Epoch {epoch}, iteration {i}, "
                  f"train loss {train_loss:.3f}, "
                  f"test loss {test_loss:.3f} ")

        save(opt.get_parameters(state), model_path)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:27,代码来源:pixelcnn.py

示例3: make_dataset

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Simulate a dataset in which potential customers who receive a sales call
    are ~2% more likely (8.4% vs. 6.1%) to make another purchase.

    :param rng_key: JAX PRNG key, split into three subkeys for the three draws.
    :return: ``(design_matrix, made_purchase)``. ``design_matrix`` has the
        columns ``[intercept, got_called, is_female]`` and one row per customer,
        called customers first. ``made_purchase`` holds the matching Bernoulli
        outcomes in the same row order.
    """
    key_called, key_not_called, key_gender = random.split(rng_key, 3)

    num_calls = 51342
    num_no_calls = 48658
    num_total = num_calls + num_no_calls

    purchases_called = dist.Bernoulli(0.084).sample(key_called, sample_shape=(num_calls,))
    purchases_not_called = dist.Bernoulli(0.061).sample(key_not_called,
                                                        sample_shape=(num_no_calls,))
    made_purchase = jnp.concatenate([purchases_called, purchases_not_called])

    is_female = dist.Bernoulli(0.5).sample(key_gender, sample_shape=(num_total,))
    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])

    columns = [
        jnp.ones((num_total, 1)),   # intercept
        got_called.reshape(-1, 1),
        is_female.reshape(-1, 1),
    ]
    design_matrix = jnp.hstack(columns)
    return design_matrix, made_purchase
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:proportion_test.py

示例4: main

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def main(args):
    """Fit the UCBADMIT GLMM, report posterior predictions, and save a plot.

    Runs inference on the training split and draws posterior predictive
    admission probabilities. It prints them with ``print_results`` and writes a
    posterior predictive check plot to ``ucbadmit_plot.pdf``.

    :param args: parsed command-line arguments forwarded to ``run_inference``.
    """
    _, fetch_train = load_dataset(UCBADMIT, split='train', shuffle=False)
    dept, male, applications, admit = fetch_train()
    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)
    pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)['probs']
    header = '=' * 30 + 'glmm - TRAIN' + '=' * 30
    print_results(header, pred_probs, dept, male, admit / applications)

    # make plots
    fig, ax = plt.subplots(1, 1)

    # One x position per row ("case") of the data. Deriving it from the data
    # length instead of hard-coding 12 keeps x and y lengths in sync.
    cases = range(1, len(applications) + 1)
    ax.plot(cases, admit / applications, "o", ms=7, label="actual rate")
    ax.errorbar(cases, jnp.mean(pred_probs, 0), jnp.std(pred_probs, 0),
                fmt="o", c="k", mfc="none", ms=7, elinewidth=1, label=r"mean $\pm$ std")
    ax.plot(cases, jnp.percentile(pred_probs, 5, 0), "k+")
    ax.plot(cases, jnp.percentile(pred_probs, 95, 0), "k+")
    ax.set(xlabel="cases", ylabel="admit rate", title="Posterior Predictive Check with 90% CI")
    ax.legend()

    # tight_layout only affects what is rendered afterwards, so it must run
    # before savefig for the saved PDF to use the adjusted layout.
    plt.tight_layout()
    plt.savefig("ucbadmit_plot.pdf")
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:ucbadmit.py

示例5: main

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def main(args):
    """Fit each baseball model and report train / test posterior predictions.

    For every model, inference runs on the training split. The resulting
    posterior samples are then used to predict both the training data and the
    test-split season data.

    :param args: parsed command-line arguments forwarded to ``run_inference``.
    """
    _, fetch_train = load_dataset(BASEBALL, split='train', shuffle=False)
    train, player_names = fetch_train()
    _, fetch_test = load_dataset(BASEBALL, split='test', shuffle=False)
    test, _ = fetch_test()
    at_bats, hits = train[:, 0], train[:, 1]
    season_at_bats, season_hits = test[:, 0], test[:, 1]
    for i, model in enumerate((fully_pooled,
                               not_pooled,
                               partially_pooled,
                               partially_pooled_with_logit,
                               )):
        rng_key, rng_key_predict = random.split(random.PRNGKey(i + 1))
        zs = run_inference(model, at_bats, hits, rng_key, args)
        # Give each prediction call its own subkey. Reusing one key for both
        # calls would make the train and test predictive draws share
        # randomness. The inference key above is left unchanged.
        rng_key_train, rng_key_test = random.split(rng_key_predict)
        predict(model, at_bats, hits, zs, rng_key_train, player_names)
        predict(model, season_at_bats, season_hits, zs, rng_key_test, player_names, train=False)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:18,代码来源:baseball.py

示例6: __call__

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def __call__(self, rng_key, *args, **kwargs):
        """
        Draw samples from the predictive distribution.

        By default only the sample sites that are absent from `posterior_samples`
        are returned. To choose other sites, change the `return_sites` keyword
        argument of this :class:`Predictive` instance.

        :param jax.random.PRNGKey rng_key: random key to draw samples.
        :param args: model arguments.
        :param kwargs: model kwargs.
        :return: dict of samples from the predictive distribution.
        """
        samples = self.posterior_samples
        if self.guide is not None:
            # the guide gets its own subkey; the remaining key drives the model
            rng_key, guide_key = random.split(rng_key)
            conditioned_guide = substitute(self.guide, self.params)
            # return_sites='' is a special signal meaning "return every site"
            samples = _predictive(guide_key, conditioned_guide, samples,
                                  self.num_samples, return_sites='', parallel=self.parallel,
                                  model_args=args, model_kwargs=kwargs)
        conditioned_model = substitute(self.model, self.params)
        return _predictive(rng_key, conditioned_model, samples, self.num_samples,
                           return_sites=self.return_sites, parallel=self.parallel,
                           model_args=args, model_kwargs=kwargs)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:util.py

示例7: update

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def update(self, svi_state, *args, **kwargs):
        """
        Run one SVI optimization step, possibly on a batch / minibatch of data.

        :param svi_state: current state of SVI.
        :param args: arguments to the model / guide (these may change from one
            step to the next during fitting).
        :param kwargs: keyword arguments to the model / guide (these may change
            from one step to the next during fitting).
        :return: tuple of `(svi_state, loss)`.
        """
        # The first half of the split is carried forward in the new state; the
        # second half seeds this step's loss estimate.
        next_rng_key, step_rng_key = random.split(svi_state.rng_key)
        params = self.optim.get_params(svi_state.optim_state)

        def objective(unconstrained_params):
            constrained_params = self.constrain_fn(unconstrained_params)
            return self.loss.loss(step_rng_key, constrained_params, self.model, self.guide,
                                  *args, **kwargs, **self.static_kwargs)

        loss_val, grads = value_and_grad(objective)(params)
        new_optim_state = self.optim.update(grads, svi_state.optim_state)
        return SVIState(new_optim_state, next_rng_key), loss_val
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:svi.py

示例8: evaluate

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def evaluate(self, svi_state, *args, **kwargs):
        """
        Evaluate the ELBO loss at the current parameter values.

        Unlike `update`, this takes no optimization step: it only reads
        `svi_state` and returns the loss.

        :param svi_state: current state of SVI.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide.
        :return: the ELBO loss evaluated at the current parameter values
            (held within `svi_state.optim_state`).
        """
        # Split the same way `update` does, so that for a given svi_state the
        # loss is evaluated with the same random key an update step would use.
        _, rng_key_eval = random.split(svi_state.rng_key)
        params = self.get_params(svi_state)
        return self.loss.loss(rng_key_eval, params, self.model, self.guide,
                              *args, **kwargs, **self.static_kwargs)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:18,代码来源:svi.py

示例9: _binomial_inversion

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def _binomial_inversion(key, p, n):
    """Draw one Binomial(n, p) sample by accumulating geometric waiting times.

    Geometric(p) waiting times are drawn by inversion,
    ``floor(log(1 - u) / log(1 - p)) + 1``. They are summed until the running
    total exceeds ``n``. The number of waiting times whose cumulative sum
    stays within ``n`` is returned.

    :param key: JAX PRNG key.
    :param p: success probability. It must be strictly positive: with
        ``p == 0`` the waiting times become ``-inf`` and the loop does not
        terminate in general.
    :param n: number of trials.
    :return: the sampled number of successes.
    """
    log_q = jnp.log1p(-p)  # log(1 - p), loop invariant

    def within_trials(carry):
        _, _, total_wait = carry
        return total_wait <= n

    def draw_waiting_time(carry):
        count, rng, total_wait = carry
        rng, rng_u = random.split(rng)
        u = random.uniform(rng_u)
        # inverse-CDF draw of a Geometric(p) variate with support 1, 2, ...
        waiting_time = jnp.floor(jnp.log1p(-u) / log_q) + 1
        return count + 1, rng, total_wait + waiting_time

    # The count starts at -1: the final draw, the one that pushes the total
    # past n, is not counted as a success.
    count, _, _ = lax.while_loop(within_trials, draw_waiting_time, (-1, key, 0.))
    return count
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:19,代码来源:util.py

示例10: _onion

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def _onion(self, key, size):
        """Sample Cholesky factors with the onion method (Algorithm 3.2 of [1]).

        :param key: JAX PRNG key, split into one key for the Beta draws and one
            for the Normal draws.
        :param tuple size: sample shape, prepended to
            ``self.batch_shape + self.event_shape``.
        :return: array of shape ``size + self.batch_shape + self.event_shape``
            whose last two axes hold a lower-triangular factor with unit-norm
            rows (up to the clipping noted below).
        """
        key_beta, key_normal = random.split(key)
        # Now we generate w term in Algorithm 3.2 of [1].
        beta_sample = self._beta.sample(key_beta, size)
        # Normalizing i.i.d. Normal draws gives points distributed uniformly on a
        # hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
        normal_sample = random.normal(
            key_normal,
            shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,)
        )
        normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
        u_hypersphere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
        w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypersphere

        # Put w into the off-diagonal triangular part (rows 1.., columns ..-1).
        # `jax.ops.index_add` was removed from JAX; `.at[...].add` is the
        # supported equivalent.
        cholesky = jnp.zeros(size + self.batch_shape + self.event_shape)
        cholesky = cholesky.at[..., 1:, :-1].add(w)
        # Correct the diagonal so that every row has unit norm.
        # NB: we clip due to numerical precision. The min bound is passed
        # positionally because the `a_min` keyword is deprecated in recent JAX.
        diag = jnp.sqrt(jnp.clip(1 - jnp.sum(cholesky ** 2, axis=-1), 0.))
        cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
        return cholesky
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:continuous.py

示例11: test_functional_map

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def test_functional_map(algo, map_fn):
    """HMC kernels mapped with ``map_fn`` over two chains recover Normal(1, 2).

    ``map_fn`` is applied to both the initialization kernel and
    ``fori_collect``. The per-chain sample mean and std are compared against the
    true values.
    """
    if map_fn is pmap and xla_bridge.device_count() == 1:
        pytest.skip('pmap test requires device_count greater than 1.')

    true_mean, true_std = 1., 2.
    warmup_steps = 1000
    num_samples = 8000

    def potential_fn(z):
        standardized = (z - true_mean) / true_std
        return 0.5 * jnp.sum(standardized ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = jnp.array([0., -1.])
    rng_keys = random.split(random.PRNGKey(0), 2)

    def init_one_chain(init_param, rng_key):
        return init_kernel(init_param, trajectory_length=9, num_warmup=warmup_steps,
                           rng_key=rng_key)

    def collect_one_chain(hmc_state):
        return fori_collect(0, num_samples, sample_kernel, hmc_state,
                            transform=lambda x: x.z, progbar=False)

    init_states = map_fn(init_one_chain)(init_params, rng_keys)
    chain_samples = map_fn(collect_one_chain)(init_states)

    assert_allclose(jnp.mean(chain_samples, axis=1), jnp.repeat(true_mean, 2), rtol=0.06)
    assert_allclose(jnp.std(chain_samples, axis=1), jnp.repeat(true_std, 2), rtol=0.06)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:26,代码来源:test_mcmc.py

示例12: test_initialize_model_dirichlet_categorical

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def test_initialize_model_dirichlet_categorical(init_strategy):
    """``initialize_model`` batched over two keys must match per-key calls."""
    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))

    rng_keys = random.split(random.PRNGKey(1), 2)
    batched_init, _, _, _ = initialize_model(rng_keys, model,
                                             init_strategy=init_strategy,
                                             model_args=(data,))
    batched_values = batched_init[0]
    for idx, rng_key in enumerate(rng_keys):
        single_init, _, _, _ = initialize_model(rng_key, model,
                                                init_strategy=init_strategy,
                                                model_args=(data,))
        for name, value in batched_values.items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(value[idx], single_init[0][name], atol=1e-6)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:23,代码来源:test_infer_util.py

示例13: evaluate_policy

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def evaluate_policy(eval_env,
                    get_predictions,
                    temperatures,
                    max_timestep=20000,
                    n_evals=1,
                    len_history_for_policy=32,
                    rng=None):
  """Evaluate the policy.

  For each of `n_evals` rounds and each temperature, plays
  `eval_env.batch_size` trajectories with the policy. It then collects the
  per-trajectory sums of `traj[2]` ("processed") and `traj[3]` ("raw").

  Args:
    eval_env: environment to play; its `batch_size` is used as the number of
      trajectories per round and temperature.
    get_predictions: policy prediction function, forwarded to
      `env_problem_utils.play_env_problem_with_policy`.
    temperatures: iterable of sampling temperatures to evaluate.
    max_timestep: forwarded as `max_timestep` to the environment rollout.
    n_evals: number of evaluation rounds; `rng` is split into this many keys,
      one per round.
    len_history_for_policy: forwarded to the environment rollout.
    rng: JAX PRNG key. It is required despite the `None` default.

  Returns:
    Dict with keys "processed" and "raw". Each maps temperature to
    {"mean": ..., "std": ...} over the collected per-trajectory reward sums.

  Raises:
    ValueError: if `rng` is None.
  """
  if rng is None:
    # jax_random.split cannot split None; fail early with a clear message
    # instead of an opaque error from inside JAX.
    raise ValueError("evaluate_policy requires an explicit `rng` PRNG key.")

  processed_reward_sums = collections.defaultdict(list)
  raw_reward_sums = collections.defaultdict(list)
  for eval_rng in jax_random.split(rng, num=n_evals):
    for temperature in temperatures:
      trajs, _, _ = env_problem_utils.play_env_problem_with_policy(
          eval_env,
          get_predictions,
          num_trajectories=eval_env.batch_size,
          max_timestep=max_timestep,
          reset=True,
          policy_sampling=env_problem_utils.GUMBEL_SAMPLING,
          temperature=temperature,
          rng=eval_rng,
          len_history_for_policy=len_history_for_policy)
      processed_reward_sums[temperature].extend(sum(traj[2]) for traj in trajs)
      raw_reward_sums[temperature].extend(sum(traj[3]) for traj in trajs)

  # Return the mean and standard deviation for each temperature.
  def compute_stats(reward_dict):
    return {
        temperature: {"mean": onp.mean(rewards), "std": onp.std(rewards)}
        for (temperature, rewards) in reward_dict.items()
    }
  return {
      "processed": compute_stats(processed_reward_sums),
      "raw": compute_stats(raw_reward_sums),
  }
开发者ID:yyht,项目名称:BERT,代码行数:38,代码来源:ppo.py

示例14: split

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def split(self, prng, num=2):
    """Split ``prng`` into ``num`` keys via the ``random_split`` entry of ``backend()``."""
    random_split = backend()["random_split"]
    return random_split(prng, num)
开发者ID:yyht,项目名称:BERT,代码行数:4,代码来源:backend.py

示例15: get_batch

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def get_batch(input_size, output_size, batch_size, key):
  """Draw one random batch of Gaussian inputs and integer labels.

  Returns ``((xs, ys), key)``:
  - ``xs`` has shape ``(batch_size, input_size)``.
  - ``ys`` holds ``batch_size`` labels drawn from ``[0, output_size)`` and
    passed through ``to_onehot``.
  - ``key`` is the advanced PRNG key for the next call.
  """
  key, inputs_key = random.split(key)
  # Per the original authors, jax.random will always generate float32 even if
  # jax_enable_x64==True, so the dtype is requested explicitly.
  inputs = random.normal(inputs_key, shape=(batch_size, input_size),
                         dtype=canonicalize_dtype(onp.float64))
  key, labels_key = random.split(key)
  labels = random.randint(labels_key, minval=0, maxval=output_size, shape=(batch_size,))
  return (inputs, to_onehot(labels, output_size)), key
开发者ID:google,项目名称:spectral-density,代码行数:11,代码来源:spectral_density_test.py


注:本文中的jax.random.split方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。