本文整理汇总了Python中jax.random.split方法的典型用法代码示例。如果您正苦于以下问题：Python random.split方法的具体用法？Python random.split怎么用？Python random.split使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.random的用法示例。
在下文中一共展示了random.split方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。（注：示例1中实际调用的是沿指定轴切分数组的jnp.split，而不是jax.random.split。）
示例1: GatedResnet
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def GatedResnet(Conv=None, nonlinearity=concat_elu, dropout_p=0.):
    """Build a gated residual block as a ``@parametrized`` module.

    Note: this example uses ``jnp.split`` (array splitting), not ``jax.random.split``.

    Args:
        Conv: layer factory, called as ``Conv(chan)`` and
            ``Conv(2 * chan, init_scale=0.1)``. It must be supplied: the default
            ``None`` is not callable, so applying the block with it raises.
        nonlinearity: activation applied to the inputs, to the optional auxiliary
            input, and to the first convolution's (summed) output.
        dropout_p: rate for the ``Dropout`` layer applied after the second
            nonlinearity; the layer is only added when ``dropout_p > 0``.

    Returns:
        The inner ``gated_resnet(inputs, aux=None)`` function, wrapped by
        ``@parametrized``.
    """
    @parametrized
    def gated_resnet(inputs, aux=None):
        # Width of the block, taken from the last axis (assumed to be the
        # channel axis -- TODO confirm against the Conv layer's layout).
        chan = inputs.shape[-1]
        c1 = Conv(chan)(nonlinearity(inputs))
        if aux is not None:
            # Add the auxiliary input through a layer of matching width
            # (NIN is presumably a network-in-network / 1x1 layer; verify).
            c1 = c1 + NIN(chan)(nonlinearity(aux))
        c1 = nonlinearity(c1)
        if dropout_p > 0:
            c1 = Dropout(rate=dropout_p)(c1)
        # Produce twice the channels, then split them into a value half `a`
        # and a gate half `b`.
        c2 = Conv(2 * chan, init_scale=0.1)(c1)
        a, b = jnp.split(c2, 2, axis=-1)
        # Gating: the value half is scaled elementwise by sigmoid(gate).
        c3 = a * sigmoid(b)
        # Residual connection back to the block input.
        return inputs + c3
    return gated_resnet
示例2: main
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def main(batch_size=32, nr_filters=8, epochs=10, step_size=.001, decay_rate=.999995,
         model_path=Path('./pixelcnn.params')):
    """Train a PixelCNN++ loss with Adam on an exponentially decaying schedule.

    The train and test losses are printed for the first 10 steps and for every
    100th step after that. Each report consumes one fresh test batch. The current
    parameters are saved to ``model_path`` at the end of every epoch.

    Args:
        batch_size: batch size handed to ``dataset``.
        nr_filters: forwarded to ``PixelCNNPP``.
        epochs: number of passes over ``get_train_batches()``.
        step_size: first argument of ``exponential_decay`` (the initial step size).
        decay_rate: decay rate given to ``exponential_decay``.
        model_path: file the parameters are written to via ``save``.
    """
    loss, _ = PixelCNNPP(nr_filters=nr_filters)
    get_train_batches, test_batches = dataset(batch_size)
    key, init_key = random.split(PRNGKey(0))
    opt = Adam(exponential_decay(step_size, 1, decay_rate))
    # One test batch is consumed as the example input for parameter initialization.
    example_batch = next(test_batches)
    state = opt.init(loss.init_parameters(example_batch, key=init_key))

    for epoch in range(epochs):
        for batch in get_train_batches():
            key, update_key = random.split(key)
            step = opt.get_step(state)  # step counter read *before* this update
            state, train_loss = opt.update_and_get_loss(
                loss.apply, state, batch, key=update_key, jit=True)

            if not (step % 100 == 0 or step < 10):
                continue

            key, test_key = random.split(key)
            current_params = opt.get_parameters(state)
            test_loss = loss.apply(current_params, next(test_batches), key=test_key, jit=True)
            print(f"Epoch {epoch}, iteration {step}, "
                  f"train loss {train_loss:.3f}, "
                  f"test loss {test_loss:.3f} ")

        save(opt.get_parameters(state), model_path)
示例3: make_dataset
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Simulate a sales-call experiment.

    Called customers buy again with probability 0.084 and uncalled customers
    with probability 0.061, i.e. roughly 2 percentage points higher when called.
    Gender is drawn independently with probability 0.5.

    Args:
        rng_key: PRNG key; it is split into three sub-keys, one per sampled column.

    Returns:
        ``(design_matrix, made_purchase)``. ``design_matrix`` has the columns
        [intercept, got_called, is_female], with all called customers first
        followed by all uncalled ones. ``made_purchase`` is ordered the same way.
    """
    num_calls, num_no_calls = 51342, 48658
    num_customers = num_calls + num_no_calls
    key_called, key_not_called, key_female = random.split(rng_key, 3)

    made_purchase = jnp.concatenate([
        dist.Bernoulli(0.084).sample(key_called, sample_shape=(num_calls,)),
        dist.Bernoulli(0.061).sample(key_not_called, sample_shape=(num_no_calls,)),
    ])
    is_female = dist.Bernoulli(0.5).sample(key_female, sample_shape=(num_customers,))
    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])

    # column_stack turns each 1-D vector into one column of the matrix.
    design_matrix = jnp.column_stack([jnp.ones(num_customers), got_called, is_female])
    return design_matrix, made_purchase
示例4: main
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def main(args):
    """Fit the UCBADMIT GLMM, print posterior predictive results, and save a plot.

    The posterior predictive admit-rate probabilities for the 12 training cases
    are plotted against the observed rates. The plot shows the mean +/- std and
    the 5th/95th percentiles (a 90% interval), and is written to
    ``ucbadmit_plot.pdf``.

    Args:
        args: options forwarded to ``run_inference``.
    """
    _, fetch_train = load_dataset(UCBADMIT, split='train', shuffle=False)
    dept, male, applications, admit = fetch_train()
    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)
    pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)['probs']
    header = '=' * 30 + 'glmm - TRAIN' + '=' * 30
    print_results(header, pred_probs, dept, male, admit / applications)

    # make plots
    fig, ax = plt.subplots(1, 1)
    cases = range(1, 13)
    ax.plot(cases, admit / applications, "o", ms=7, label="actual rate")
    ax.errorbar(cases, jnp.mean(pred_probs, 0), jnp.std(pred_probs, 0),
                fmt="o", c="k", mfc="none", ms=7, elinewidth=1, label=r"mean $\pm$ std")
    ax.plot(cases, jnp.percentile(pred_probs, 5, 0), "k+")
    ax.plot(cases, jnp.percentile(pred_probs, 95, 0), "k+")
    ax.set(xlabel="cases", ylabel="admit rate", title="Posterior Predictive Check with 90% CI")
    ax.legend()
    # Adjust the layout *before* saving: the original called tight_layout() after
    # savefig(), so the adjustment never reached the written file.
    plt.tight_layout()
    plt.savefig("ucbadmit_plot.pdf")
示例5: main
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def main(args):
    """Fit each baseball model and report predictions on the train and test splits.

    For every model, inference runs on the training at-bats/hits. The fitted
    samples ``zs`` are then used to predict both the training data and the
    held-out season data.

    Args:
        args: options forwarded to ``run_inference``.
    """
    _, fetch_train = load_dataset(BASEBALL, split='train', shuffle=False)
    train, player_names = fetch_train()
    _, fetch_test = load_dataset(BASEBALL, split='test', shuffle=False)
    test, _ = fetch_test()
    at_bats, hits = train[:, 0], train[:, 1]
    season_at_bats, season_hits = test[:, 0], test[:, 1]
    models = (fully_pooled,
              not_pooled,
              partially_pooled,
              partially_pooled_with_logit,
              )
    for i, model in enumerate(models):
        rng_key, rng_key_predict = random.split(random.PRNGKey(i + 1))
        zs = run_inference(model, at_bats, hits, rng_key, args)
        # The original passed the same key to both predict() calls. JAX keys must
        # not be reused, so derive independent keys for the train and test runs.
        rng_key_train_pred, rng_key_test_pred = random.split(rng_key_predict)
        predict(model, at_bats, hits, zs, rng_key_train_pred, player_names)
        predict(model, season_at_bats, season_hits, zs, rng_key_test_pred, player_names,
                train=False)
示例6: __call__
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def __call__(self, rng_key, *args, **kwargs):
    """
    Draw samples from the predictive distribution and return them as a dict.

    By default only sample sites that are not already in `posterior_samples` are
    returned; change the `return_sites` keyword argument of this
    :class:`Predictive` instance to alter that.

    If a guide is set, the guide (with `self.params` substituted) is first run
    on a sub-key of `rng_key` to produce all of its sites. Those then serve as
    the posterior samples for the model.

    :param jax.random.PRNGKey rng_key: random key to draw samples.
    :param args: model arguments.
    :param kwargs: model kwargs.
    """
    samples = self.posterior_samples
    if self.guide is not None:
        rng_key, guide_rng_key = random.split(rng_key)
        conditioned_guide = substitute(self.guide, self.params)
        # return_sites='' is the special signal that asks for every site.
        samples = _predictive(guide_rng_key, conditioned_guide, samples,
                              self.num_samples, return_sites='', parallel=self.parallel,
                              model_args=args, model_kwargs=kwargs)
    conditioned_model = substitute(self.model, self.params)
    return _predictive(rng_key, conditioned_model, samples, self.num_samples,
                       return_sites=self.return_sites, parallel=self.parallel,
                       model_args=args, model_kwargs=kwargs)
示例7: update
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def update(self, svi_state, *args, **kwargs):
    """
    Perform one optimizer step of SVI (possibly on a batch / minibatch of data).

    :param svi_state: current state of SVI.
    :param args: positional arguments for the model / guide (these may change
        between steps, e.g. from one minibatch to the next).
    :param kwargs: keyword arguments for the model / guide (these may also change
        between steps).
    :return: tuple of `(svi_state, loss)`. The new state carries the updated
        optimizer state and the advanced PRNG key. The loss is the value at the
        parameters *before* this step.
    """
    next_rng_key, step_rng_key = random.split(svi_state.rng_key)
    optim_params = self.optim.get_params(svi_state.optim_state)

    def objective(raw_params):
        # The loss is computed on the parameters after `constrain_fn`.
        return self.loss.loss(step_rng_key, self.constrain_fn(raw_params), self.model,
                              self.guide, *args, **kwargs, **self.static_kwargs)

    loss_val, grads = value_and_grad(objective)(optim_params)
    new_optim_state = self.optim.update(grads, svi_state.optim_state)
    return SVIState(new_optim_state, next_rng_key), loss_val
示例8: evaluate
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def evaluate(self, svi_state, *args, **kwargs):
    """
    Compute the ELBO loss at the current parameters without taking an SVI step.

    Neither the optimizer state nor the PRNG key in `svi_state` is modified.

    :param svi_state: current state of SVI.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide.
    :return: the ELBO loss for the parameter values held in
        `svi_state.optim_state`.
    """
    # Take the second half of the split, as `update` does, so that a given
    # svi_state evaluates with the same seed `update` would use.
    rng_key_eval = random.split(svi_state.rng_key)[1]
    current_params = self.get_params(svi_state)
    return self.loss.loss(rng_key_eval, current_params, self.model, self.guide,
                          *args, **kwargs, **self.static_kwargs)
示例9: _binomial_inversion
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def _binomial_inversion(key, p, n):
    """Draw one Binomial(n, p) sample by accumulating geometric waiting times.

    Each loop iteration draws a geometric variate (the number of trials up to
    and including the next success) by inverting a uniform draw. Iteration stops
    once the accumulated trial count exceeds ``n``. The number of completed
    successes is the sample.

    Args:
        key: ``jax.random`` PRNG key.
        p: success probability. It is expected to be a scalar, because the loop
            carry must keep a fixed shape.
        n: number of trials.

    Returns:
        Integer scalar sample. It is 0 when ``p <= 0`` (or ``p`` is NaN), and
        also when ``n < 0``.
    """
    # With p == 0, log1p(-p) == 0 and every geometric draw is -inf, so the
    # while_loop would never terminate. For such p, run the loop with a harmless
    # placeholder probability and mask the result afterwards. For p > 0 the
    # computation is unchanged.
    valid_p = p > 0
    log1_p = jnp.log1p(-jnp.where(valid_p, p, 0.5))

    def _binom_inv_cond_fn(val):
        _, _, geom_acc = val
        return geom_acc <= n

    def _binom_inv_body_fn(val):
        i, key, geom_acc = val
        key, key_u = random.split(key)
        u = random.uniform(key_u)
        # Inverse-CDF sample of a geometric variate with success probability p.
        geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
        return i + 1, key, geom_acc + geom

    # The counter starts at -1 because the iteration that pushes geom_acc past n
    # does not count as a success within n trials.
    num_successes, _, _ = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn,
                                         (-1, key, 0.))
    # num_successes stays -1 only when n < 0 (the body never runs); clamp to 0.
    return jnp.where(valid_p, jnp.maximum(num_successes, 0), 0)
示例10: _onion
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def _onion(self, key, size):
    """Sample lower-triangular Cholesky factors via the onion method.

    Follows Algorithm 3.2 of reference [1] cited by the enclosing distribution.
    Each row of the returned factor has unit Euclidean norm, so ``L @ L.T`` has
    a unit diagonal.

    Args:
        key: ``jax.random`` PRNG key; it is split into a Beta key and a Normal key.
        size: sample shape (a tuple), prepended to ``self.batch_shape``.

    Returns:
        Array of shape ``size + self.batch_shape + self.event_shape``.
    """
    key_beta, key_normal = random.split(key)
    # The w term of Algorithm 3.2 in [1]: its squared norm is a Beta draw.
    beta_sample = self._beta.sample(key_beta, size)
    # Normalizing i.i.d. standard normals gives points uniformly distributed on a
    # hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html).
    num_offdiag = self.dimension * (self.dimension - 1) // 2
    normal_sample = random.normal(
        key_normal,
        shape=size + self.batch_shape + (num_offdiag,)
    )
    normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
    u_hypersphere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
    w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypersphere
    # Put w into the strictly lower-triangular part (row j of w becomes row j+1).
    # `.at[...].add` replaces `jax.ops.index_add`, which has been removed from JAX.
    cholesky = jnp.zeros(size + self.batch_shape + self.event_shape).at[..., 1:, :-1].add(w)
    # Correct the diagonal so every row has unit norm.
    # NB: clamp at 0 due to numerical precision. jnp.maximum(x, 0.) equals
    # jnp.clip(x, a_min=0.), whose `a_min` keyword is deprecated in recent JAX.
    diag = jnp.sqrt(jnp.maximum(1 - jnp.sum(cholesky ** 2, axis=-1), 0.))
    cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
    return cholesky
示例11: test_functional_map
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def test_functional_map(algo, map_fn):
    """Run two HMC chains under `map_fn` and check their mean/std against a known Normal."""
    if map_fn is pmap and xla_bridge.device_count() == 1:
        pytest.skip('pmap test requires device_count greater than 1.')

    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        # Negative log-density (up to a constant) of Normal(true_mean, true_std).
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)

    def init_chain(init_param, rng_key):
        return init_kernel(init_param, trajectory_length=9,
                           num_warmup=warmup_steps, rng_key=rng_key)

    def collect_chain(hmc_state):
        return fori_collect(0, num_samples, sample_kernel, hmc_state,
                            transform=lambda state: state.z, progbar=False)

    # Two chains: one initial position and one PRNG key per chain.
    chain_inits = jnp.array([0., -1.])
    chain_keys = random.split(random.PRNGKey(0), 2)
    init_states = map_fn(init_chain)(chain_inits, chain_keys)
    chain_samples = map_fn(collect_chain)(init_states)

    assert_allclose(jnp.mean(chain_samples, axis=1), jnp.repeat(true_mean, 2), rtol=0.06)
    assert_allclose(jnp.std(chain_samples, axis=1), jnp.repeat(true_std, 2), rtol=0.06)
示例12: test_initialize_model_dirichlet_categorical
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def test_initialize_model_dirichlet_categorical(init_strategy):
    """Initializing with a batch of keys must match initializing with each key separately."""
    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    rng_keys = random.split(random.PRNGKey(1), 2)

    def init(keys):
        return initialize_model(keys, model, init_strategy=init_strategy, model_args=(data,))

    batched_params, _, _, _ = init(rng_keys)
    for i, rng_key in enumerate(rng_keys):
        single_params, _, _, _ = init(rng_key)
        for site_name, batched_value in batched_params[0].items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(batched_value[i], single_params[0][site_name], atol=1e-6)
示例13: evaluate_policy
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def evaluate_policy(eval_env,
                    get_predictions,
                    temperatures,
                    max_timestep=20000,
                    n_evals=1,
                    len_history_for_policy=32,
                    rng=None):
    """Evaluate the policy at several sampling temperatures.

    ``rng`` is split into ``n_evals`` keys. For each key and each temperature,
    ``eval_env.batch_size`` trajectories are played with Gumbel sampling. The
    same key is shared by all temperatures within one evaluation round.

    Args:
      eval_env: environment problem to play; its ``batch_size`` sets the number
        of trajectories per round.
      get_predictions: policy callable forwarded to the rollout helper.
      temperatures: iterable of sampling temperatures to evaluate.
      max_timestep: per-trajectory step limit.
      n_evals: number of evaluation rounds.
      len_history_for_policy: history length forwarded to the rollout helper.
      rng: PRNG key. The ``None`` default is passed straight to
        ``jax_random.split``, so callers are expected to supply a key
        (NOTE(review): confirm ``None`` is never intended to work).

    Returns:
      ``{"processed": stats, "raw": stats}``. Each ``stats`` maps a temperature
      to ``{"mean": ..., "std": ...}`` of the per-trajectory reward sums, taken
      from ``traj[2]`` (processed) and ``traj[3]`` (raw).
    """
    processed_reward_sums = collections.defaultdict(list)
    raw_reward_sums = collections.defaultdict(list)

    eval_rngs = jax_random.split(rng, num=n_evals)
    for eval_rng in eval_rngs:
        for temperature in temperatures:
            trajs, _, _ = env_problem_utils.play_env_problem_with_policy(
                eval_env,
                get_predictions,
                num_trajectories=eval_env.batch_size,
                max_timestep=max_timestep,
                reset=True,
                policy_sampling=env_problem_utils.GUMBEL_SAMPLING,
                temperature=temperature,
                rng=eval_rng,
                len_history_for_policy=len_history_for_policy)
            processed_reward_sums[temperature] += [sum(traj[2]) for traj in trajs]
            raw_reward_sums[temperature] += [sum(traj[3]) for traj in trajs]

    def summarize(reward_dict):
        # Mean and standard deviation of the reward sums for each temperature.
        return {
            temp: {"mean": onp.mean(rewards), "std": onp.std(rewards)}
            for temp, rewards in reward_dict.items()
        }

    return {
        "processed": summarize(processed_reward_sums),
        "raw": summarize(raw_reward_sums),
    }
示例14: split
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def split(self, prng, num=2):
    """Split ``prng`` into ``num`` keys using the active backend's ``random_split``."""
    random_split = backend()["random_split"]
    return random_split(prng, num)
示例15: get_batch
# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import split [as 别名]
def get_batch(input_size, output_size, batch_size, key):
    """Sample a random classification batch and return it with the advanced key.

    Args:
        input_size: number of features per example.
        output_size: number of classes; labels are drawn from ``[0, output_size)``.
        batch_size: number of examples in the batch.
        key: PRNG key; it is split once for the inputs and once for the labels.

    Returns:
        ``((xs, ys), key)``. ``xs`` holds standard-normal inputs of shape
        ``(batch_size, input_size)``, ``ys`` holds the labels passed through
        ``to_onehot``, and ``key`` is the key for the next call.
    """
    key, normal_key = random.split(key)
    # float64 is requested through canonicalize_dtype. NOTE(review): presumably
    # this resolves to float32 unless jax_enable_x64 is set -- confirm.
    xs = random.normal(normal_key, shape=(batch_size, input_size),
                       dtype=canonicalize_dtype(onp.float64))
    key, label_key = random.split(key)
    labels = random.randint(label_key, minval=0, maxval=output_size, shape=(batch_size,))
    return (xs, to_onehot(labels, output_size)), key