当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.sum方法代码示例

本文整理汇总了Python中jax.numpy.sum方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.sum方法的具体用法？Python jax.numpy.sum怎么用？Python jax.numpy.sum使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。


在下文中一共展示了jax.numpy.sum方法的15个代码示例，这些例子默认根据受欢迎程度排序（注意：部分示例实际调用的是Python内置的sum()或数组的.sum()方法）。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: approximate_kl

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def approximate_kl(log_prob_new, log_prob_old, mask):
  """Computes the approximate KL divergence between the old and new log-probs.

  The difference `log_prob_old - log_prob_new` is taken, its final time-step is
  dropped so it lines up with `mask`, masked positions are zeroed, and the
  total is divided by the sum of the mask.

  Args:
    log_prob_new: (B, T+1, A) log probs new
    log_prob_old: (B, T+1, A) log probs old
    mask: (B, T) mask; presumably 1 for valid steps and 0 for padding.

  Returns:
    Approximate KL (sum of masked differences divided by `np.sum(mask)`).
  """
  # (B, T) -> (B, T, 1) so the mask broadcasts across the last axis.
  step_mask = mask[:, :, np.newaxis]
  # Subtract first, then cut the last time-step, then mask (out of place).
  masked_diff = (log_prob_old - log_prob_new)[:, :-1] * step_mask
  return np.sum(masked_diff) / np.sum(mask)
开发者ID:yyht,项目名称:BERT,代码行数:20,代码来源:ppo.py

示例2: masked_entropy

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def masked_entropy(log_probs, mask):
  """Computes the entropy for the given log-probs, averaged over unmasked steps.

  Args:
    log_probs: (B, T+1, A) log probs
    mask: (B, T) mask.

  Returns:
    Entropy: -sum(p * log p) over the last axis and all unmasked time-steps,
    divided by `np.sum(mask)`.
  """
  # (B, T) -> (B, T, 1) so the mask broadcasts across the action axis.
  step_mask = mask[:, :, np.newaxis]
  # Cut the last time-step out and zero the masked positions. This multiply is
  # deliberately out of place: if `log_probs` is a NumPy array, the slice
  # `log_probs[:, :-1]` is a view, and `*=` would overwrite the caller's data.
  lp = log_probs[:, :-1] * step_mask
  # exp(0) == 1 at masked positions, so the mask is applied again. (B, T, A)
  p = np.exp(lp) * step_mask
  # Average on non-masked part and take negative.
  return -(np.sum(lp * p) / np.sum(mask))
开发者ID:yyht,项目名称:BERT,代码行数:19,代码来源:ppo.py

示例3: test_parameters_from_subsubmodule

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def test_parameters_from_subsubmodule():
    # Nesting: net = Sequential(Sequential(Dense, relu), jnp.sum). Parameters
    # supplied for the innermost Dense must be picked up by the outer network.
    dense = Dense(2)
    inner = Sequential(dense, relu)
    net = Sequential(inner, jnp.sum)
    x = jnp.zeros((1, 3))
    reference = net.apply(net.init_parameters(x, key=PRNGKey(0)), x)

    dense_params = dense.init_parameters(x, key=PRNGKey(0))

    def check_shape(result):
        assert result.shape == reference.shape

    merged = net.parameters_from({dense: dense_params}, x)
    # merged[0][0] is expected to hold exactly the supplied Dense parameters.
    assert_dense_parameters_equal(dense_params, merged[0][0])
    check_shape(net.apply(merged, x))
    check_shape(net.apply_from({dense: dense_params}, x))
    check_shape(net.apply_from({dense: dense_params}, x, jit=True))
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:22,代码来源:test_core.py

示例4: test_parameters_from_sharing_between_multiple_parents

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def test_parameters_from_sharing_between_multiple_parents():
    # One Dense module used both directly and inside a Sequential.
    shared_dense = Dense(2)
    summed = Sequential(shared_dense, jnp.sum)

    @parametrized
    def net(inputs):
        return shared_dense(inputs), summed(inputs)

    x = jnp.zeros((1, 3))
    dense_params = shared_dense.init_parameters(x, key=PRNGKey(0))
    expected = shared_dense.apply(dense_params, x)

    net_params = net.parameters_from({shared_dense: dense_params}, x)
    assert_dense_parameters_equal(dense_params, net_params.dense)
    # The Sequential holds no parameters of its own beyond the shared Dense.
    assert_parameters_equal((), net_params.sequential)
    assert len(net_params) == 2
    actual, _ = net.apply(net_params, x)
    assert jnp.array_equal(expected, actual)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_core.py

示例5: get_data

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def get_data(N=20, S=2, P=10, sigma_obs=0.05, seed=0):
    """Generate a synthetic sparse-regression dataset.

    X has P standard-normal features. Only the first S features get
    non-negligible coefficients, and Y also contains the single pairwise
    interaction X[:, 0] * X[:, 1] plus Gaussian noise. Y is centered and
    divided by its standard deviation; the returned coefficients are rescaled
    by the same factor.

    Args:
        N: number of data points.
        S: number of features with non-negligible coefficients (0 < S < P).
        P: total number of features (P > 1, needed for the interaction term).
        sigma_obs: standard deviation of the observation noise.
        seed: seed for a private NumPy RandomState. The default 0 yields the
            same data as seeding the global generator with np.random.seed(0).

    Returns:
        (X, Y / Y_std, W / Y_std, 1.0 / Y_std)

    Raises:
        ValueError: if the S / P constraints are violated.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (S < P and P > 1 and S > 0):
        raise ValueError(f"expected 0 < S < P and P > 1, got S={S}, P={P}")
    # A local generator avoids resetting the process-wide np.random stream.
    # RandomState(seed) draws exactly what np.random.seed(seed) + np.random.* did.
    rng = np.random.RandomState(seed)

    X = rng.randn(N, P)
    # generate S coefficients with non-negligible magnitude
    W = 0.5 + 2.5 * rng.rand(S)
    # generate data using the S coefficients and a single pairwise interaction
    Y = np.sum(X[:, 0:S] * W, axis=-1) + X[:, 0] * X[:, 1] + sigma_obs * rng.randn(N)
    Y -= jnp.mean(Y)
    Y_std = jnp.std(Y)

    assert X.shape == (N, P)
    assert Y.shape == (N,)

    return X, Y / Y_std, W / Y_std, 1.0 / Y_std


# Helper function for analyzing the posterior statistics for coefficient theta_i 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:sparse_regression.py

示例6: test_beta_bernoulli

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def test_beta_bernoulli(auto_class):
    # Two columns of coin flips: 8/10 ones and 4/10 ones.
    data = jnp.array([[1.0] * 8 + [0.0] * 2,
                      [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        probs = numpyro.sample('beta', dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(probs), obs=data)

    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, optim.Adam(0.01), ELBO())
    state = svi.init(random.PRNGKey(1), data)

    def step(_i, current_state):
        next_state, _ = svi.update(current_state, data)
        return next_state

    state = fori_loop(0, 3000, step, state)
    params = svi.get_params(state)
    # Beta(1, 1) prior => posterior mean is (number of ones + 1) / (n + 2).
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    draws = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(jnp.mean(draws['beta'], 0), true_coefs, atol=0.05)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:25,代码来源:test_autoguide.py

示例7: test_unnormalized_normal_x64

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    true_mean, true_std = 1., 0.5

    def potential_fn(z):
        # Negative log-density (up to a constant) of Normal(true_mean, true_std).
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    if kernel_cls is SA:
        # SA mixes much more slowly, so it gets far more steps.
        warmup_steps, num_samples = 100000, 100000
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        warmup_steps, num_samples = 1000, 8000
        kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass)

    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=jnp.array(0.))
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, rtol=0.07)
    assert_allclose(jnp.std(samples), true_std, rtol=0.07)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples.dtype == jnp.float64
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:23,代码来源:test_mcmc.py

示例8: test_correlated_mvn

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    dim = 5
    warmup_steps, num_samples = 5000, 8000
    true_mean = 0.

    noise = random.normal(random.PRNGKey(0), shape=(dim, dim))
    scale_tril = jnp.tril(0.5 * jnp.fliplr(jnp.eye(dim)) + 0.1 * jnp.exp(noise))
    true_cov = jnp.dot(scale_tril, scale_tril.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # Quadratic form 0.5 * z^T P z for precision matrix P.
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    start = jnp.zeros(dim)
    mcmc = MCMC(NUTS(potential_fn=potential_fn, dense_mass=True), warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=start)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    mean_abs_cov_error = np.sum(np.abs(np.cov(samples.T) - true_cov)) / dim**2
    assert mean_abs_cov_error < 0.02
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:23,代码来源:test_mcmc.py

示例9: test_diverging

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def test_diverging(kernel_cls, adapt_step_size):
    data = random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        loc = numpyro.sample('loc', dist.Normal(0., 1.))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)

    num_warmup = num_samples = 1000
    # step_size=10 is deliberately huge; without adaptation the test expects
    # every warmup and sampling transition to diverge.
    kernel = kernel_cls(model, step_size=10., adapt_step_size=adapt_step_size,
                        adapt_mass_matrix=False)
    mcmc = MCMC(kernel, num_warmup, num_samples)

    def latest_divergence_count():
        return mcmc.get_extra_fields()['diverging'].sum()

    mcmc.warmup(random.PRNGKey(1), data, extra_fields=['diverging'], collect_warmup=True)
    num_divergences = latest_divergence_count()
    mcmc.run(random.PRNGKey(2), data, extra_fields=['diverging'])
    num_divergences = num_divergences + latest_divergence_count()

    if not adapt_step_size:
        assert_allclose(num_divergences, num_warmup + num_samples)
    else:
        assert num_divergences <= num_warmup
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:20,代码来源:test_mcmc.py

示例10: test_functional_map

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def test_functional_map(algo, map_fn):
    if map_fn is pmap and xla_bridge.device_count() == 1:
        pytest.skip('pmap test requires device_count greater than 1.')

    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = jnp.array([0., -1.])
    rng_keys = random.split(random.PRNGKey(0), 2)

    # Per-chain functions, mapped over the two chains by `map_fn`.
    def init_one_chain(init_param, rng_key):
        return init_kernel(init_param, trajectory_length=9, num_warmup=warmup_steps,
                           rng_key=rng_key)

    def collect_one_chain(hmc_state):
        return fori_collect(0, num_samples, sample_kernel, hmc_state,
                            transform=lambda state: state.z, progbar=False)

    init_states = map_fn(init_one_chain)(init_params, rng_keys)
    chain_samples = map_fn(collect_one_chain)(init_states)

    # Statistics are taken over the sample axis (axis 1), one value per chain.
    assert_allclose(jnp.mean(chain_samples, axis=1), jnp.repeat(true_mean, 2), rtol=0.06)
    assert_allclose(jnp.std(chain_samples, axis=1), jnp.repeat(true_std, 2), rtol=0.06)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:26,代码来源:test_mcmc.py

示例11: test_categorical_log_prob_grad

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def test_categorical_log_prob_grad():
    # Parameterizing by softmax(probs) or by logits must agree in value and gradient.
    data = jnp.repeat(jnp.arange(3), 10)

    def total_log_prob_from_probs(x):
        probs = jax.nn.softmax(x * jnp.arange(1, 4))
        return dist.Categorical(probs).log_prob(data).sum()

    def total_log_prob_from_logits(x):
        logits = x * jnp.arange(1, 4)
        return dist.Categorical(logits=logits).log_prob(data).sum()

    x = 0.5
    value_p, grad_p = jax.value_and_grad(total_log_prob_from_probs)(x)
    value_l, grad_l = jax.value_and_grad(total_log_prob_from_logits)(x)
    assert_allclose(value_p, value_l)
    assert_allclose(grad_p, grad_l, atol=1e-4)


########################################
# Tests for constraints and transforms #
######################################## 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:test_distributions.py

示例12: clip_eta

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def clip_eta(eta, norm, eps):
  """
  Helper function to clip the perturbation to epsilon norm ball.
  :param eta: A tensor with the current perturbation. Axis 0 is kept; for the
              L2 case the norm is taken over all remaining axes.
  :param norm: Order of the norm (mimics Numpy).
              Possible values: np.inf or 2.
  :param eps: Epsilon, bound of the perturbation.
  :return: eta clipped element-wise to [-eps, eps] (np.inf), or scaled so each
           slice along axis 0 has L2 norm at most eps (2).
  :raises ValueError: if norm is not np.inf or 2.
  """

  # Clipping perturbation eta to self.norm norm ball
  if norm not in [np.inf, 2]:
    raise ValueError('norm must be np.inf or 2.')

  # Every axis except axis 0. A tuple, not a list: jax.numpy reductions treat
  # `axis` as a static (hashable) argument, and newer JAX rejects lists.
  axis = tuple(range(1, len(eta.shape)))
  avoid_zero_div = 1e-12
  if norm == np.inf:
    # Positional bounds work in NumPy and in JAX (newer JAX deprecates the
    # a_min / a_max keywords).
    eta = np.clip(eta, -eps, eps)
  else:
    # avoid_zero_div must go inside sqrt to avoid a divide by zero in the gradient through this operation
    l2 = np.sqrt(np.maximum(avoid_zero_div, np.sum(np.square(eta), axis=axis, keepdims=True)))
    # We must *clip* to within the norm ball, not *normalize* onto the surface of the ball
    factor = np.minimum(1., np.divide(eps, l2))
    eta = eta * factor
  return eta
开发者ID:tensorflow,项目名称:cleverhans,代码行数:26,代码来源:utils.py

示例13: evaluate_policy

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def evaluate_policy(eval_env,
                    get_predictions,
                    temperatures,
                    max_timestep=20000,
                    n_evals=1,
                    len_history_for_policy=32,
                    rng=None):
  """Evaluate the policy.

  For each of `n_evals` keys split from `rng` and each temperature, plays
  `eval_env.batch_size` trajectories and records per-trajectory reward totals.

  Args:
    eval_env: environment problem to play in.
    get_predictions: prediction function forwarded to
      `env_problem_utils.play_env_problem_with_policy`.
    temperatures: iterable of sampling temperatures to evaluate.
    max_timestep: maximum number of time-steps per trajectory.
    n_evals: number of evaluation rounds (number of keys split from `rng`).
    len_history_for_policy: history length forwarded to the policy.
    rng: PRNG key. It is passed to `jax_random.split` unconditionally, so
      presumably a real key must be supplied despite the None default.

  Returns:
    {"processed": stats, "raw": stats}. Each `stats` maps a temperature to
    {"mean": ..., "std": ...} over the sums of `traj[2]` (processed) or
    `traj[3]` (raw) for every trajectory played at that temperature.
  """
  processed_totals = collections.defaultdict(list)
  raw_totals = collections.defaultdict(list)
  eval_rngs = jax_random.split(rng, num=n_evals)
  for eval_rng in eval_rngs:
    for temperature in temperatures:
      trajs, _, _ = env_problem_utils.play_env_problem_with_policy(
          eval_env,
          get_predictions,
          num_trajectories=eval_env.batch_size,
          max_timestep=max_timestep,
          reset=True,
          policy_sampling=env_problem_utils.GUMBEL_SAMPLING,
          temperature=temperature,
          rng=eval_rng,
          len_history_for_policy=len_history_for_policy)
      processed_totals[temperature].extend([sum(traj[2]) for traj in trajs])
      raw_totals[temperature].extend([sum(traj[3]) for traj in trajs])

  def summarize(totals_by_temperature):
    # Mean and standard deviation of the reward totals for each temperature.
    summary = {}
    for temperature, totals in totals_by_temperature.items():
      summary[temperature] = {"mean": onp.mean(totals), "std": onp.std(totals)}
    return summary

  return {
      "processed": summarize(processed_totals),
      "raw": summarize(raw_totals),
  }
开发者ID:yyht,项目名称:BERT,代码行数:38,代码来源:ppo.py

示例14: _sum

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def _sum(x, dim):
    """Sum ``x`` over the axis (or axes) given by ``dim``; ``dim`` is passed as ``axis``."""
    reduced = np.sum(x, axis=dim)
    return reduced
开发者ID:pyro-ppl,项目名称:funsor,代码行数:4,代码来源:ops.py

示例15: loss

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import sum [as 别名]
def loss(y, y_hat):
  """Return the negated sum of the element-wise product of ``y`` and ``y_hat``."""
  total = np.sum(y * y_hat)
  return -total
开发者ID:google,项目名称:spectral-density,代码行数:4,代码来源:spectral_density_test.py


注:本文中的jax.numpy.sum方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。