本文整理汇总了Python中jax.numpy.exp方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.exp方法的具体用法？Python jax.numpy.exp怎么用？Python jax.numpy.exp使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。
在下文中一共展示了jax.numpy.exp方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: masked_entropy
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def masked_entropy(log_probs, mask):
    """Computes the masked mean entropy for the given log-probs.

    Args:
        log_probs: (B, T+1, A) log probs; the last time-step is dropped.
        mask: (B, T) mask; positions where it is 0 contribute nothing.

    Returns:
        Entropy: ``-sum(p * log p)`` over the kept positions, divided by
        ``sum(mask)``. An all-zero mask gives 0 / 0, i.e. NaN.
    """
    mask_3d = mask[:, :, np.newaxis]  # (B, T, 1), broadcasts over actions
    # Cut the last time-step out and mask out the irrelevant part. This builds
    # a new array rather than using ``*=`` on the slice: for a NumPy input the
    # slice is a view, and an in-place multiply would overwrite the caller's
    # ``log_probs``.
    lp = log_probs[:, :-1] * mask_3d  # (B, T, A)
    # Masked entries have lp == 0, so exp(lp) == 1 there; re-mask p.
    p = np.exp(lp) * mask_3d  # (B, T, A)
    # Average on non-masked part and take negative.
    return -(np.sum(lp * p) / np.sum(mask))
示例2: print_results
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def print_results(posterior, dates):
    """Print posterior quantile tables for 'sigma', 'nu' and the volatility.

    Args:
        posterior: mapping with sample arrays under 'sigma', 'nu' and 's';
            quantiles are taken along axis 0, and 's' is indexed as
            ``posterior['s'][:, i]``.
        dates: sequence of row labels; a volatility row (``exp(s)``) is printed
            for every 180th entry.
    """
    # Labels are built from plain Python floats. Formatting elements of a
    # float32 array gave headers such as "(p60.000004)".
    quantile_levels = [0.2, 0.4, 0.5, 0.6, 0.8]
    quantiles = jnp.array(quantile_levels)
    num_cols = len(quantile_levels)
    row_name_fmt = '{:>8}'
    header_format = row_name_fmt + '{:>12}' * num_cols
    row_format = row_name_fmt + '{:>12.3f}' * num_cols
    columns = ['(p{:g})'.format(q * 100) for q in quantile_levels]

    def _print_row(values, row_name=''):
        # One header line plus one row of quantile values for `values`.
        q_values = jnp.quantile(values, quantiles, axis=0)
        print(header_format.format('', *columns))
        print(row_format.format(row_name, *q_values))
        print('\n')

    print('=' * 20, 'sigma', '=' * 20)
    _print_row(posterior['sigma'])
    print('=' * 20, 'nu', '=' * 20)
    _print_row(posterior['nu'])
    print('=' * 20, 'volatility', '=' * 20)
    for i in range(0, len(dates), 180):
        _print_row(jnp.exp(posterior['s'][:, i]), dates[i])
示例3: _build_basetree
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right,
                    energy_current, max_delta_energy):
    """Build a depth-0 tree by taking a single integrator step.

    Args:
        vv_update: integrator step, called as
            ``vv_update(step_size, inverse_mass_matrix, (z, r, energy, z_grad))``
            and returning ``(z_new, r_new, potential_energy_new, z_new_grad)``.
        kinetic_fn: ``kinetic_fn(inverse_mass_matrix, r)`` -> kinetic energy.
        z, r, z_grad: current position, momentum and potential gradient.
        inverse_mass_matrix: passed through to ``vv_update`` and ``kinetic_fn``.
        step_size: integrator step size; negated when ``going_right`` is False.
        going_right: direction in which the tree is extended.
        energy_current: energy of the starting state.
        max_delta_energy: energy increase above which the step is flagged
            as diverging.

    Returns:
        A ``TreeInfo`` whose left, right and proposal states are all the new
        state, with ``depth=0``, ``weight=-delta_energy``,
        ``sum_accept_probs=min(1, exp(-delta_energy))`` and ``num_proposals=1``.
    """
    # Integrate backwards in time when extending the tree to the left.
    step_size = jnp.where(going_right, step_size, -step_size)
    # NOTE(review): the third slot of the state tuple receives energy_current;
    # presumably vv_update ignores that slot on input -- confirm.
    z_new, r_new, potential_energy_new, z_new_grad = vv_update(
        step_size,
        inverse_mass_matrix,
        (z, r, energy_current, z_grad),
    )
    energy_new = potential_energy_new + kinetic_fn(inverse_mass_matrix, r_new)
    delta_energy = energy_new - energy_current
    # Handles the NaN case: treat it as an infinitely bad step (weight -inf,
    # diverging, acceptance probability 0).
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    tree_weight = -delta_energy
    diverging = delta_energy > max_delta_energy
    # min(1, exp(-delta_energy)). jnp.minimum is used because the `a_max`
    # keyword of jnp.clip is deprecated in recent JAX releases.
    accept_prob = jnp.minimum(jnp.exp(-delta_energy), 1.0)
    return TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad,
                    z_new, potential_energy_new, z_new_grad, energy_new,
                    depth=0, weight=tree_weight, r_sum=r_new, turning=False,
                    diverging=diverging, sum_accept_probs=accept_prob, num_proposals=1)
示例4: test_correlated_mvn
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def test_correlated_mvn():
    """NUTS with a dense mass matrix should recover a correlated 5-D Gaussian.

    The target is a zero-mean Gaussian whose covariance is ``L @ L.T`` for a
    random lower-triangular ``L``. This requires dense mass matrix estimation.
    """
    dim = 5
    warmup_steps, num_samples = 5000, 8000
    true_mean = 0.
    noise = random.normal(random.PRNGKey(0), shape=(dim, dim))
    anti_diagonal = jnp.fliplr(jnp.eye(dim))
    factor = jnp.tril(0.5 * anti_diagonal + 0.1 * jnp.exp(noise))
    true_cov = jnp.dot(factor, factor.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # Quadratic form 0.5 * z^T P z with P the precision matrix.
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    sampler = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(sampler, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=jnp.zeros(dim))
    samples = mcmc.get_samples()

    # Grand mean over all samples and dimensions.
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    # Mean absolute entry-wise error of the empirical covariance.
    cov_error = np.sum(np.abs(np.cov(samples.T) - true_cov)) / dim ** 2
    assert cov_error < 0.02
示例5: compute_probab_ratios
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def compute_probab_ratios(p_new, p_old, actions, reward_mask):
    """Computes the probability ratios for each time-step in a trajectory.

    Args:
        p_new: ndarray of shape [B, T+1, A] of the log-probabilities that the
            policy network assigns to all the actions at each time-step in
            each batch using the new parameters.
        p_old: ndarray of shape [B, T+1, A], same as above, but using old
            policy network parameters.
        actions: ndarray of shape [B, T] where each element is from [0, A).
        reward_mask: ndarray of shape [B, T] masking over probabilities.

    Returns:
        probab_ratios: ndarray of shape [B, T], where
        probab_ratios_{b,t} = reward_mask_{b,t} *
            exp(p_new_{b,t,action_{b,t}} - p_old_{b,t,action_{b,t}}),
        i.e. the new/old probability ratio of the taken action, zeroed where
        the mask is zero.
    """
    B, T = actions.shape  # pylint: disable=invalid-name
    assert (B, T + 1) == p_old.shape[:2]
    assert (B, T + 1) == p_new.shape[:2]

    # NOTE(review): chosen_probabs presumably gathers the log-prob of each
    # taken action; the asserts below pin its output to shape (B, T).
    logp_old = chosen_probabs(p_old, actions)
    logp_new = chosen_probabs(p_new, actions)

    assert (B, T) == logp_old.shape
    assert (B, T) == logp_new.shape

    # Since these are log-probabilities, the ratio is exp of the difference.
    probab_ratios = np.exp(logp_new - logp_old) * reward_mask
    assert (B, T) == probab_ratios.shape
    return probab_ratios
示例6: tanh
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def tanh(self, x):
    """Element-wise hyperbolic tangent of ``x``.

    Delegates to ``jax.numpy.tanh``. The previous closed form
    ``(1 - e^{-2x}) / (1 + e^{-2x})`` overflows for large negative ``x``
    (``e^{-2x} -> inf``) and returned ``inf / inf = nan`` instead of -1.
    """
    import jax.numpy as np
    return np.tanh(x)
示例7: sigmoid
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def sigmoid(self, x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``.

    For very negative ``x``, ``exp(-x)`` overflows to inf and the result
    becomes 0.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator
示例8: softmax
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def softmax(self, x, T=1.0):
    """Softmax along the last axis of ``x`` with temperature ``T``.

    Args:
        x: array of logits; normalised along axis -1.
        T: temperature. Logits are divided by ``T`` before exponentiating, so
            the default 1.0 gives the plain softmax. ``T == 0`` yields inf/nan.

    Returns:
        Array with the same shape as ``x`` whose last-axis slices sum to 1.
    """
    scaled = x / T
    # Subtract the per-slice max so exp() cannot overflow; softmax is
    # invariant to this shift.
    unnormalized = np.exp(scaled - scaled.max(-1, keepdims=True))
    return unnormalized / unnormalized.sum(-1, keepdims=True)
示例9: exp
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def exp(self, x):
    """Return the element-wise exponential of ``x`` via ``np.exp``."""
    result = np.exp(x)
    return result
示例10: exp
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def exp(self, tensor_in):
    """Element-wise exponential of ``tensor_in``, computed with ``np.exp``."""
    exponentiated = np.exp(tensor_in)
    return exponentiated
示例11: poisson
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def poisson(self, n, lam):
r"""
The continous approximation, using :math:`n! = \Gamma\left(n+1\right)`,
to the probability mass function of the Poisson distribution evaluated
at :code:`n` given the parameter :code:`lam`.
Example:
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.poisson(5., 6.)
DeviceArray(0.16062314, dtype=float64)
>>> values = pyhf.tensorlib.astensor([5., 9.])
>>> rates = pyhf.tensorlib.astensor([6., 8.])
>>> pyhf.tensorlib.poisson(values, rates)
DeviceArray([0.16062314, 0.12407692], dtype=float64)
Args:
n (`tensor` or `float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
(the observed number of events)
lam (`tensor` or `float`): The mean of the Poisson distribution p.m.f.
(the expected number of events)
Returns:
JAX ndarray: Value of the continous approximation to Poisson(n|lam)
"""
n = np.asarray(n)
lam = np.asarray(lam)
return np.exp(n * np.log(lam) - lam - gammaln(n + 1.0))
示例12: main
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def main(args):
    """Fit the predator-prey model to the lynx/hare data and plot predictions.

    Args:
        args: parsed command-line options; reads ``num_warmup``,
            ``num_samples`` and ``num_chains``.

    Side effects:
        Prints an MCMC summary and writes ``ode_plot.pdf``.
    """
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # use dense_mass for better mixing rate
    mcmc = MCMC(NUTS(model, dense_mass=True),
                num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    # The model is fit to log-populations.
    mcmc.run(PRNGKey(1), N=data.shape[0], y=jnp.log(data))
    mcmc.print_summary()

    # predict populations, mapping predictions back from the log scale
    y_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    pop_pred = jnp.exp(y_pred)
    mu, pi = jnp.mean(pop_pred, 0), jnp.percentile(pop_pred, (10, 90), 0)
    plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()

    # Lay out before saving; tight_layout() after savefig() does not affect
    # the file that was already written.
    plt.tight_layout()
    plt.savefig("ode_plot.pdf")
示例13: model
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def model(returns):
    """Stochastic-volatility model for a series of returns.

    Samples a step size 'sigma', a Gaussian random walk 's' with one step per
    observation, and degrees of freedom 'nu'. The observed 'r' is Student-t
    with location 0 and scale ``exp(s)``.
    """
    num_steps = jnp.shape(returns)[0]
    walk_scale = numpyro.sample('sigma', dist.Exponential(50.))
    log_volatility = numpyro.sample('s', dist.GaussianRandomWalk(scale=walk_scale, num_steps=num_steps))
    dof = numpyro.sample('nu', dist.Exponential(.1))
    likelihood = dist.StudentT(df=dof, loc=0., scale=jnp.exp(log_volatility))
    return numpyro.sample('r', likelihood, obs=returns)
示例14: main
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def main(args):
    """Fit the stochastic-volatility model to S&P500 returns and plot it.

    Args:
        args: parsed command-line options; reads ``rng_seed``,
            ``num_warmup`` and ``num_samples``.

    Side effects:
        Prints posterior quantiles and writes ``stochastic_volatility_plot.pdf``.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns,))
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info, args.num_warmup, rng_key=sample_rng_key)
    hmc_states = fori_collect(args.num_warmup, args.num_warmup + args.num_samples, sample_kernel, hmc_state,
                              transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),
                              progbar="NUMPYRO_SPHINXBUILD" not in os.environ)
    print_results(hmc_states, dates)

    fig, ax = plt.subplots(1, 1)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())

    # One faint line per posterior sample of exp(s).
    ax.plot(dates, jnp.exp(hmc_states['s'].T), 'r', alpha=0.01)
    legend = ax.legend(['returns', 'volatility'], loc='upper right')
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and later
    # removed the old name; support both.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    handles[1].set_alpha(0.6)
    ax.set(xlabel='time', ylabel='returns', title='Volatility of S&P500 over time')

    # Lay out before saving so the written PDF uses the adjusted layout.
    plt.tight_layout()
    plt.savefig("stochastic_volatility_plot.pdf")
示例15: guide
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def guide(data):
    """Guide with a Normal over 'loc' using a learnable location and log-scale.

    ``data`` is accepted for signature compatibility but not used here.
    """
    location = numpyro.param("guide_loc", 0.)
    log_scale = numpyro.param("guide_scale_log", 0.)
    # The scale is parameterised in log space so that exp() keeps it positive.
    numpyro.sample("loc", dist.Normal(location, jnp.exp(log_scale)))