当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.exp方法代码示例

本文整理汇总了Python中jax.numpy.exp方法的典型用法代码示例。如果您正苦于以下问题:Python numpy.exp方法的具体用法?Python numpy.exp怎么用?Python numpy.exp使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.numpy的用法示例。


在下文中一共展示了jax.numpy.exp方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: masked_entropy

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def masked_entropy(log_probs, mask):
  """Computes the mask-weighted entropy term for the given log-probs.

  The input arrays are never modified.

  Args:
    log_probs: (B, T+1, A) log probs; the last time-step is dropped.
    mask: (B, T) mask.

  Returns:
    -sum(lp * exp(lp) * mask) / sum(mask), where lp is the masked log-probs.
  """
  # Make the mask (B, T, 1) so it broadcasts over the action axis.
  mask_3d = mask[:, :, np.newaxis]
  # Cut the last time-step out and zero the masked-out entries. A new array is
  # built on purpose: an in-place `*=` on the slice would write through to the
  # caller's `log_probs` when a plain NumPy array is passed in.
  lp = log_probs[:, :-1] * mask_3d  # (B, T, A)
  # exp(0) == 1 for the zeroed entries, so the probabilities are masked too.
  p = np.exp(lp) * mask_3d  # (B, T, A)
  # Average on non-masked part and take negative.
  return -(np.sum(lp * p) / np.sum(mask))
开发者ID:yyht,项目名称:BERT,代码行数:19,代码来源:ppo.py

示例2: print_results

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def print_results(posterior, dates):
    """Print quantile tables for the 'sigma', 'nu' and 's' posterior entries.

    Each table row shows the 20/40/50/60/80% quantiles (taken along axis 0).
    For 's', one row is printed for every 180th index of ``dates``, using the
    exponential of column ``i`` of ``posterior['s']`` and ``dates[i]`` as the
    row label.
    """
    n_columns = 5
    label_fmt = '{:>8}'
    header_fmt = label_fmt + '{:>12}' * n_columns
    values_fmt = label_fmt + '{:>12.3f}' * n_columns

    def _print_row(values, row_name=''):
        # Header with the quantile labels, then the values, then a blank gap.
        quantiles = jnp.array([0.2, 0.4, 0.5, 0.6, 0.8])
        labels = []
        for q in quantiles:
            labels.append('(p{})'.format(q * 100))
        q_values = jnp.quantile(values, quantiles, axis=0)
        print(header_fmt.format('', *labels))
        print(values_fmt.format(row_name, *q_values))
        print('\n')

    def _banner(title):
        bar = '=' * 20
        print(bar, title, bar)

    for site in ('sigma', 'nu'):
        _banner(site)
        _print_row(posterior[site])

    _banner('volatility')
    for idx in range(0, len(dates), 180):
        _print_row(jnp.exp(posterior['s'][:, idx]), dates[idx])
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:stochastic_volatility.py

示例3: _build_basetree

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right,
                    energy_current, max_delta_energy):
    """Take a single ``vv_update`` step and wrap the result in a depth-0 ``TreeInfo``.

    The step size is negated when ``going_right`` is false. The new energy is
    the potential energy returned by ``vv_update`` plus
    ``kinetic_fn(inverse_mass_matrix, r_new)``. A NaN energy change is treated
    as ``+inf``.

    Returns:
        A ``TreeInfo`` built from the new state with ``depth=0``,
        ``weight=-delta_energy``, ``r_sum=r_new``, ``turning=False``,
        ``diverging=delta_energy > max_delta_energy``,
        ``sum_accept_probs=min(exp(-delta_energy), 1)`` and ``num_proposals=1``.
    """
    step_size = jnp.where(going_right, step_size, -step_size)
    # NOTE(review): the third slot of the state tuple carries `energy_current`,
    # while the matching returned slot is used as a potential energy — confirm
    # against `vv_update`'s contract.
    z_new, r_new, potential_energy_new, z_new_grad = vv_update(
        step_size,
        inverse_mass_matrix,
        (z, r, energy_current, z_grad),
    )

    energy_new = potential_energy_new + kinetic_fn(inverse_mass_matrix, r_new)
    delta_energy = energy_new - energy_current
    # Handles the NaN case.
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    tree_weight = -delta_energy

    diverging = delta_energy > max_delta_energy
    # Upper-bound the acceptance probability at 1. `jnp.minimum` is used rather
    # than `jnp.clip(..., a_max=1.0)` because the `a_max` keyword is deprecated
    # in newer JAX releases; the result is the same.
    accept_prob = jnp.minimum(jnp.exp(-delta_energy), 1.0)
    return TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad,
                    z_new, potential_energy_new, z_new_grad, energy_new,
                    depth=0, weight=tree_weight, r_sum=r_new, turning=False,
                    diverging=diverging, sum_accept_probs=accept_prob, num_proposals=1)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:23,代码来源:hmc_util.py

示例4: test_correlated_mvn

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def test_correlated_mvn():
    """Check NUTS samples against a correlated zero-mean Gaussian target.

    The target covariance is ``a @ a.T`` for a fixed-seed lower-triangular
    ``a``; the sample mean and the sample covariance are both asserted to be
    within 0.02 of the true values (the covariance error averaged over all
    D**2 entries).
    """
    # This requires dense mass matrix estimation.
    dim = 5
    n_warmup = 5000
    n_samples = 8000
    true_mean = 0.

    noise = random.normal(random.PRNGKey(0), shape=(dim, dim))
    anti_diagonal = 0.5 * jnp.fliplr(jnp.eye(dim))
    a = jnp.tril(anti_diagonal + 0.1 * jnp.exp(noise))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # Quadratic form 0.5 * z^T P z with the true precision matrix.
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, n_warmup, n_samples)
    mcmc.run(random.PRNGKey(0), init_params=jnp.zeros(dim))
    samples = mcmc.get_samples()

    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    mean_abs_cov_error = np.sum(np.abs(np.cov(samples.T) - true_cov)) / dim**2
    assert mean_abs_cov_error < 0.02
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:23,代码来源:test_mcmc.py

示例5: compute_probab_ratios

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def compute_probab_ratios(p_new, p_old, actions, reward_mask):
  """Computes the masked new/old probability ratio of the taken actions.

  Args:
    p_new: ndarray of shape [B, T+1, A] of the log-probabilities that the policy
      network assigns to all the actions at each time-step in each batch using
      the new parameters.
    p_old: ndarray of shape [B, T+1, A], same as above, but using old policy
      network parameters.
    actions: ndarray of shape [B, T] where each element is from [0, A).
    reward_mask: ndarray of shape [B, T] masking over probabilities.

  Returns:
    probab_ratios: ndarray of shape [B, T], where
    probab_ratios_{b,t} = p_new_{b,t,action_{b,t}} / p_old_{b,t,action_{b,t}},
    multiplied element-wise by reward_mask.
  """
  batch, steps = actions.shape
  padded_shape = (batch, steps + 1)
  assert p_old.shape[:2] == padded_shape
  assert p_new.shape[:2] == padded_shape

  logp_old = chosen_probabs(p_old, actions)
  logp_new = chosen_probabs(p_new, actions)

  expected_shape = (batch, steps)
  assert logp_old.shape == expected_shape
  assert logp_new.shape == expected_shape

  # Since these are log-probabilities, the ratio is exp of their difference.
  log_ratio = logp_new - logp_old
  probab_ratios = np.exp(log_ratio) * reward_mask
  assert probab_ratios.shape == expected_shape
  return probab_ratios
开发者ID:yyht,项目名称:BERT,代码行数:33,代码来源:ppo.py

示例6: tanh

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def tanh(self, x):
        """Return the hyperbolic tangent of ``x``, computed from one ``exp``.

        Uses ``2 / (1 + exp(-2x)) - 1``, which is algebraically equal to
        ``(1 - exp(-2x)) / (1 + exp(-2x))``.
        """
        import jax.numpy as np
        y = np.exp(-2.0 * x)
        # When y overflows to inf (very negative x) this form evaluates to
        # 2 / inf - 1 == -1, whereas (1 - y) / (1 + y) would be inf / inf == nan.
        return 2.0 / (1.0 + y) - 1.0
开发者ID:Kaggle,项目名称:docker-python,代码行数:6,代码来源:test_jax.py

示例7: sigmoid

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def sigmoid(self, x):
        """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``."""
        denominator = 1 + np.exp(-x)
        return 1 / denominator
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例8: softmax

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def softmax(self, x, T=1.0):
        """Return the softmax of ``x`` over its last axis at temperature ``T``.

        Args:
            x: array of scores; normalisation is over the last axis.
            T: temperature the scores are divided by before exponentiation.
                The default of 1.0 gives the plain softmax.

        Returns:
            Array with the same shape as ``x`` whose last axis sums to 1.
        """
        # Apply the temperature (previously accepted but silently ignored).
        scaled = x / T
        # Subtract the per-row maximum so the largest exponent is exp(0).
        shifted = scaled - scaled.max(-1, keepdims=True)
        unnormalized = np.exp(shifted)
        return unnormalized / unnormalized.sum(-1, keepdims=True)
开发者ID:sharadmv,项目名称:deepx,代码行数:5,代码来源:jax.py

示例9: exp

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def exp(self, x):
        """Return ``np.exp(x)``, the exponential of ``x``."""
        result = np.exp(x)
        return result
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例10: exp

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def exp(self, tensor_in):
        """Return ``np.exp(tensor_in)``, the exponential of ``tensor_in``."""
        exponentiated = np.exp(tensor_in)
        return exponentiated
开发者ID:scikit-hep,项目名称:pyhf,代码行数:4,代码来源:jax_backend.py

示例11: poisson

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def poisson(self, n, lam):
        r"""
        The continuous approximation, using :math:`n! = \Gamma\left(n+1\right)`,
        to the probability mass function of the Poisson distribution evaluated
        at :code:`n` given the parameter :code:`lam`.

        The term :math:`n \log(\lambda)` is evaluated with ``xlogy`` so that
        :code:`n = 0, lam = 0` gives 1 rather than NaN.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.poisson(5., 6.)
            DeviceArray(0.16062314, dtype=float64)
            >>> values = pyhf.tensorlib.astensor([5., 9.])
            >>> rates = pyhf.tensorlib.astensor([6., 8.])
            >>> pyhf.tensorlib.poisson(values, rates)
            DeviceArray([0.16062314, 0.12407692], dtype=float64)

        Args:
            n (`tensor` or `float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
                                  (the observed number of events)
            lam (`tensor` or `float`): The mean of the Poisson distribution p.m.f.
                                    (the expected number of events)

        Returns:
            JAX ndarray: Value of the continuous approximation to Poisson(n|lam)
        """
        # Imported here because the module-level imports are not visible from
        # this block.
        from jax.scipy.special import xlogy

        n = np.asarray(n)
        lam = np.asarray(lam)
        # xlogy(0, 0) == 0, whereas 0 * log(0) == 0 * -inf == nan.
        return np.exp(xlogy(n, lam) - lam - gammaln(n + 1.0))
开发者ID:scikit-hep,项目名称:pyhf,代码行数:31,代码来源:jax_backend.py

示例12: main

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def main(args):
    """Fit ``model`` to the LYNXHARE data with NUTS and save a predictive plot.

    Reads ``args.num_warmup``, ``args.num_samples`` and ``args.num_chains``,
    prints the MCMC summary, and writes the figure to ``ode_plot.pdf``.
    """
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # The progress bar is turned off when NUMPYRO_SPHINXBUILD is set.
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ

    # use dense_mass for better mixing rate
    mcmc = MCMC(NUTS(model, dense_mass=True),
                args.num_warmup, args.num_samples, num_chains=args.num_chains,
                progress_bar=show_progress)
    # The model is fitted to the log of the data.
    mcmc.run(PRNGKey(1), N=data.shape[0], y=jnp.log(data))
    mcmc.print_summary()

    # predict populations
    y_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    # Undo the log transform applied to the observations above.
    pop_pred = jnp.exp(y_pred)
    # Mean and 10th/90th percentiles over the sample axis.
    mu, pi = jnp.mean(pop_pred, 0), jnp.percentile(pop_pred, (10, 90), 0)
    plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()

    # Adjust the layout before saving; calling tight_layout() after savefig()
    # leaves the written file without the adjustment.
    plt.tight_layout()
    plt.savefig("ode_plot.pdf")
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:29,代码来源:ode.py

示例13: model

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def model(returns):
    """Model ``returns`` as Student-T with a random-walk log scale.

    Sample sites, in order: 'sigma' (Exponential(50.), the random-walk scale),
    's' (GaussianRandomWalk with one step per entry of ``returns``),
    'nu' (Exponential(.1), the degrees of freedom) and 'r' (StudentT with
    scale ``exp(s)``, observed at ``returns``).
    """
    num_steps = jnp.shape(returns)[0]
    step_size = numpyro.sample('sigma', dist.Exponential(50.))
    walk = dist.GaussianRandomWalk(scale=step_size, num_steps=num_steps)
    s = numpyro.sample('s', walk)
    nu = numpyro.sample('nu', dist.Exponential(.1))
    likelihood = dist.StudentT(df=nu, loc=0., scale=jnp.exp(s))
    return numpyro.sample('r', likelihood, obs=returns)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:8,代码来源:stochastic_volatility.py

示例14: main

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def main(args):
    """Run NUTS on ``model`` for the SP500 data, print and plot the results.

    Reads ``args.rng_seed``, ``args.num_warmup`` and ``args.num_samples``,
    prints quantile tables via ``print_results``, and writes the figure to
    ``stochastic_volatility_plot.pdf``.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns,))
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info, args.num_warmup, rng_key=sample_rng_key)
    # The progress bar is turned off when NUMPYRO_SPHINXBUILD is set.
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    hmc_states = fori_collect(args.num_warmup, args.num_warmup + args.num_samples, sample_kernel, hmc_state,
                              transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),
                              progbar=show_progress)
    print_results(hmc_states, dates)

    fig, ax = plt.subplots(1, 1)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())

    # One faint red line per collected sample of exp(s).
    ax.plot(dates, jnp.exp(hmc_states['s'].T), 'r', alpha=0.01)
    legend = ax.legend(['returns', 'volatility'], loc='upper right')
    # Newer Matplotlib exposes `legend_handles`; older releases only have
    # `legendHandles`. Prefer the new name and fall back to the old one.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    # Make the volatility legend entry more opaque than the plotted lines.
    handles[1].set_alpha(0.6)
    ax.set(xlabel='time', ylabel='returns', title='Volatility of S&P500 over time')

    # Adjust the layout before saving; calling tight_layout() after savefig()
    # leaves the written file without the adjustment.
    plt.tight_layout()
    plt.savefig("stochastic_volatility_plot.pdf")
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:29,代码来源:stochastic_volatility.py

示例15: guide

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import exp [as 别名]
def guide(data):
    """Sample 'loc' from a Normal whose parameters are registered via ``numpyro.param``.

    The location is the 'guide_loc' parameter and the scale is the exponential
    of the 'guide_scale_log' parameter; both are initialised to 0. ``data`` is
    not used in the body.
    """
    loc_param = numpyro.param("guide_loc", 0.)
    log_scale_param = numpyro.param("guide_scale_log", 0.)
    proposal = dist.Normal(loc_param, jnp.exp(log_scale_param))
    numpyro.sample("loc", proposal)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:6,代码来源:minipyro.py


注:本文中的jax.numpy.exp方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。