当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.ndarray 类代码示例

本文整理汇总了 Python 中 jax.numpy.ndarray 类的典型用法代码示例。如果您正苦于以下问题:jax.numpy.ndarray 的具体用法是什么?jax.numpy.ndarray 怎么用?有哪些 jax.numpy.ndarray 的使用例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块 jax.numpy 的用法示例。


在下文中一共展示了 jax.numpy.ndarray 类的 15 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: deltas

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def deltas(predicted_values, rewards, mask, gamma=0.99):
  r"""Returns the masked one-step TD-residuals of a batch of trajectories.

  For trajectory b and time-step t the residual is

  delta_{b,t} = r_{b,t} + \gamma * v_{b,t+1} - v_{b,t}.

  Args:
    predicted_values: ndarray of shape (B, T+1) holding V(s_{b,t}); a trailing
      singleton axis is expected to have been squeezed away already.
    rewards: ndarray of shape (B, T) of rewards.
    mask: ndarray of shape (B, T); multiplied element-wise into the result.
    gamma: float, discount factor.

  Returns:
    ndarray of shape (B, T) of masked one-step TD-residuals.
  """
  # V(s_t): every column but the last, shape (B, T).
  current_values = predicted_values[:, :-1]
  # V(s_{t+1}): every column but the first, shape (B, T).
  next_values = predicted_values[:, 1:]
  # One-step bootstrapped target r_t + gamma * V(s_{t+1}).
  td_targets = rewards + (gamma * next_values)
  return (td_targets - current_values) * mask
开发者ID:yyht,项目名称:BERT,代码行数:27,代码来源:ppo.py

示例2: gae_advantages

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def gae_advantages(td_deltas, mask, lambda_=0.95, gamma=0.99):
  r"""Returns GAE advantage estimates built from one-step TD-residuals.

  The GAE advantage estimator is

  A_{bt} = \sum_{l=0}^{\infty}(\gamma * \lambda)^{l}(\delta_{b,t+l}).

  This is the same computation as `rewards_to_go` with the residuals in place
  of the rewards and `gamma * lambda_` as the discount, so the work is
  delegated to it.

  Args:
    td_deltas: np.ndarray of shape (B, T) of one-step TD-residuals.
    mask: np.ndarray of shape (B, T) of mask for the residuals. The residuals
      may already be masked if they were produced by `deltas(...)`.
    lambda_: float, lambda parameter of the GAE estimator.
    gamma: float, discount factor.

  Returns:
    GAE advantage estimates, as returned by `rewards_to_go`.
  """
  # The combined per-step decay applied to the residuals.
  effective_discount = lambda_ * gamma
  return rewards_to_go(td_deltas, mask, effective_discount)
开发者ID:yyht,项目名称:BERT,代码行数:24,代码来源:ppo.py

示例3: chosen_probabs

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def chosen_probabs(probab_observations, actions):
  """Gathers the log-probability of each action that was actually taken.

  Args:
    probab_observations: ndarray of shape `[B, T+1, A]`; entry `[b, t, i]` is
      the log-probability of action `i` at time-step `t` of trajectory `b`.
    actions: ndarray of shape `[B, T]` with entries in `[0, A)`, the action
      chosen at each time-step of each trajectory.

  Returns:
    `[B, T]` ndarray with the log-probabilities of the chosen actions. Only the
    first `T` time-steps of `probab_observations` are indexed.
  """
  B, T = actions.shape  # pylint: disable=invalid-name
  assert probab_observations.shape[:2] == (B, T + 1)
  # A (B, 1) column of batch indices and a (T,) row of time indices broadcast
  # against the (B, T) `actions` to pick one entry per (b, t) pair.
  batch_index = np.arange(B)[:, None]
  time_index = np.arange(T)
  return probab_observations[batch_index, time_index, actions]
开发者ID:yyht,项目名称:BERT,代码行数:18,代码来源:ppo.py

示例4: assert_parameters_equal

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def assert_parameters_equal(p, p_):
    """Assert that two (possibly nested) parameter structures are equal.

    Arrays are compared with `np.array_equal`. Tuples, lists and dicts must
    be the same kind of container with the same length and are compared
    element by element, recursively. Dict entries are matched by the keys
    of `p`.
    """
    if isinstance(p, np.ndarray):
        # Leaf node: compare the array contents.
        assert np.array_equal(p, p_)
        return

    assert isinstance(p, (tuple, list, dict))
    # Both sides must be the same kind of container.
    for container_type in (tuple, list, dict):
        assert isinstance(p, container_type) == isinstance(p_, container_type)

    assert len(p) == len(p_)

    if isinstance(p, dict):
        # Lazily pair each value of `p` with the value under the same key in `p_`.
        pairs = ((value, p_[key]) for key, value in p.items())
    else:
        pairs = zip(p, p_)
    for left, right in pairs:
        assert_parameters_equal(left, right)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:util.py

示例5: default_agent

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Builds a DQN agent configured with default hyperparameters."""

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    # Flatten the inputs, then feed them through an MLP whose final layer has
    # `action_spec.num_values` outputs.
    flattened = hk.Flatten()(inputs)
    q_network = hk.nets.MLP([64, 64, action_spec.num_values])
    return q_network(flattened)

  optimizer = optix.adam(1e-3)
  rng = hk.PRNGSequence(seed)
  return DQN(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      optimizer=optimizer,
      batch_size=32,
      discount=0.99,
      replay_capacity=10000,
      min_replay_size=100,
      sgd_period=1,
      target_update_period=4,
      epsilon=0.05,
      rng=rng,
  )
开发者ID:deepmind,项目名称:bsuite,代码行数:27,代码来源:agent.py

示例6: default_agent

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Builds an actor-critic agent configured with default hyperparameters."""

  def network(inputs: jnp.ndarray) -> Tuple[Logits, Value]:
    # A shared MLP torso feeds two linear heads: one producing
    # `action_spec.num_values` logits and one producing a single value.
    flattened = hk.Flatten()(inputs)
    shared_torso = hk.nets.MLP([64, 64])
    actor_head = hk.Linear(action_spec.num_values)
    critic_head = hk.Linear(1)
    features = shared_torso(flattened)
    policy_logits = actor_head(features)
    # Drop the trailing size-1 axis of the value head's output.
    state_value = jnp.squeeze(critic_head(features), axis=-1)
    return policy_logits, state_value

  optimizer = optix.adam(3e-3)
  rng = hk.PRNGSequence(seed)
  return ActorCritic(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      optimizer=optimizer,
      rng=rng,
      sequence_length=32,
      discount=0.99,
      td_lambda=0.9,
  )
开发者ID:deepmind,项目名称:bsuite,代码行数:27,代码来源:agent.py

示例7: make_dataset

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Simulate a dataset in which customers who received a sales call are
    slightly more likely to make another purchase (purchase probability
    0.084 with a call versus 0.061 without).

    :param rng_key: PRNG key; split into three keys, one per sampled quantity.
    :return: Tuple of the design matrix (columns: constant 1, got-called
        indicator, is-female indicator) and the purchase outcomes, with the
        called customers listed first.
    """
    called_key, not_called_key, gender_key = random.split(rng_key, 3)

    num_calls = 51342
    num_no_calls = 48658
    num_customers = num_calls + num_no_calls

    made_purchase_got_called = dist.Bernoulli(0.084).sample(called_key, sample_shape=(num_calls,))
    made_purchase_no_calls = dist.Bernoulli(0.061).sample(not_called_key, sample_shape=(num_no_calls,))
    made_purchase = jnp.concatenate([made_purchase_got_called, made_purchase_no_calls])

    # Gender is drawn independently of whether the customer was called.
    is_female = dist.Bernoulli(0.5).sample(gender_key, sample_shape=(num_customers,))
    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])

    columns = [
        jnp.ones((num_customers, 1)),
        got_called.reshape(-1, 1),
        is_female.reshape(-1, 1),
    ]
    design_matrix = jnp.hstack(columns)

    return design_matrix, made_purchase
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:proportion_test.py

示例8: model

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def model(design_matrix: jnp.ndarray, outcome: jnp.ndarray = None) -> None:
    """
    Model definition: the log odds of making a purchase are a linear
    combination of the covariates, with a Normal prior over the regression
    coefficients.
    :param design_matrix: Covariates. All categorical variables have been one-hot
        encoded.
    :param outcome: Binary response variable. In this case, whether or not the
        customer made a purchase.
    """

    num_features = design_matrix.shape[1]
    # Zero-mean prior with identity covariance, one coefficient per column.
    coefficient_prior = dist.MultivariateNormal(loc=0.,
                                                covariance_matrix=jnp.eye(num_features))
    beta = numpyro.sample('coefficients', coefficient_prior)
    logits = design_matrix.dot(beta)

    num_observations = design_matrix.shape[0]
    with numpyro.plate('data', num_observations):
        numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=outcome)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:18,代码来源:proportion_test.py

示例9: print_results

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def print_results(coef: jnp.ndarray, interval_size: float = 0.95) -> None:
    """
    Print the intervals, each holding ``interval_size`` probability mass,
    for the effect of sales calls and for the effect of gender.

    :param coef: Samples of the regression coefficients, indexed as
        ``coef[:, j]``. Column 0 is used as the intercept, column 1 as the
        effect of getting called and column 2 as the effect of gender.
    :param interval_size: Probability mass passed to ``hpdi`` for each interval.
    """

    intercept = coef[:, 0]
    baseline_response = expit(intercept)
    response_with_calls = expit(intercept + coef[:, 1])

    impact_on_probability = hpdi(response_with_calls - baseline_response, prob=interval_size)

    effect_of_gender = hpdi(coef[:, 2], prob=interval_size)
    gender_low = effect_of_gender[0]
    gender_high = effect_of_gender[1]

    print(f"There is a {interval_size * 100}% probability that calling customers "
          "increases the chance they'll make a purchase by "
          f"{(100 * impact_on_probability[0]):.2} to {(100 * impact_on_probability[1]):.2} percentage points."
          )

    # Only claim "no impact" when the interval really straddles zero; the
    # statement used to be printed unconditionally.
    if gender_low <= 0 <= gender_high:
        conclusion = (" Since this interval contains 0, we can conclude gender does not impact"
                      " the conversion rate.")
    else:
        conclusion = (" Since this interval does not contain 0, gender appears to impact"
                      " the conversion rate.")

    # Both bounds use the same fixed-point format so they are directly comparable.
    print(f"There is a {interval_size * 100}% probability the effect of gender on the log odds of conversion "
          f"lies in the interval ({gender_low:.2f}, {gender_high:.2f})."
          + conclusion)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:23,代码来源:proportion_test.py

示例10: run_inference

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def run_inference(design_matrix: jnp.ndarray, outcome: jnp.ndarray,
                  rng_key: jnp.ndarray,
                  num_warmup: int,
                  num_samples: int, num_chains: int,
                  interval_size: float = 0.95) -> None:
    """
    Estimate the effect size: run MCMC with a NUTS kernel on ``model`` and
    print a summary of the sampled coefficients.
    """

    # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set in the environment.
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    sampler = MCMC(NUTS(model), num_warmup, num_samples, num_chains,
                   progress_bar=show_progress)
    sampler.run(rng_key, design_matrix, outcome)

    # Columns of the sampled coefficients:
    #   0: intercept (not getting called)
    #   1: effect of getting called
    #   2: effect of gender (should be none since assigned at random)
    coefficient_samples = sampler.get_samples()['coefficients']
    print_results(coefficient_samples, interval_size)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:proportion_test.py

示例11: compute_probab_ratios

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def compute_probab_ratios(p_new, p_old, actions, reward_mask):
  """Returns the masked new/old probability ratio of each chosen action.

  Args:
    p_new: ndarray of shape [B, T+1, A] of the log-probabilities that the policy
      network assigns to all the actions at each time-step in each batch using
      the new parameters.
    p_old: ndarray of shape [B, T+1, A], same as above, but using the old policy
      network parameters.
    actions: ndarray of shape [B, T] where each element is from [0, A).
    reward_mask: ndarray of shape [B, T], multiplied element-wise into the
      ratios.

  Returns:
    probab_ratios: ndarray of shape [B, T], where
    probab_ratios_{b,t} = p_new_{b,t,action_{b,t}} / p_old_{b,t,action_{b,t}}
    before masking.
  """

  B, T = actions.shape  # pylint: disable=invalid-name
  full_prefix = (B, T + 1)
  assert p_old.shape[:2] == full_prefix
  assert p_new.shape[:2] == full_prefix

  logp_old = chosen_probabs(p_old, actions)
  logp_new = chosen_probabs(p_new, actions)

  assert logp_old.shape == (B, T)
  assert logp_new.shape == (B, T)

  # The inputs are log-probabilities, so the ratio is the exponential of their
  # difference.
  log_ratio = logp_new - logp_old
  probab_ratios = np.exp(log_ratio) * reward_mask
  assert probab_ratios.shape == (B, T)
  return probab_ratios
开发者ID:yyht,项目名称:BERT,代码行数:33,代码来源:ppo.py

示例12: default_agent

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def default_agent(
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    seed: int = 0,
    num_ensemble: int = 20,
) -> BootstrappedDqn:
  """Builds a Bootstrapped DQN agent configured with default hyperparameters."""

  # Network hyperparameters.
  prior_scale = 3.
  hidden_sizes = [50, 50]

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Q-network summed with a scaled, gradient-stopped prior network."""
    layer_sizes = hidden_sizes + [action_spec.num_values]
    # Two MLPs of identical shape; the second one acts as the prior.
    net = hk.nets.MLP(layer_sizes)
    prior_net = hk.nets.MLP(layer_sizes)
    flattened = hk.Flatten()(inputs)
    q_values = net(flattened)
    prior_values = lax.stop_gradient(prior_net(flattened))
    return q_values + prior_scale * prior_values

  return BootstrappedDqn(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      batch_size=128,
      discount=.99,
      num_ensemble=num_ensemble,
      replay_capacity=10000,
      min_replay_size=128,
      sgd_period=1,
      target_update_period=4,
      optimizer=optix.adam(learning_rate=1e-3),
      mask_prob=0.5,
      noise_scale=0.,
      epsilon_fn=lambda _: 0.,
      seed=seed,
  )
开发者ID:deepmind,项目名称:bsuite,代码行数:39,代码来源:agent.py

示例13: get_backend

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def get_backend(tensor) -> 'AbstractBackend':
    """
    Return the backend that handles `tensor` (e.g. the numpy backend if tensor
    is a numpy.ndarray), creating and registering the backend if needed.

    A new backend is only created when its `framework_name` is already present
    in `sys.modules`. Raises RuntimeError when no backend accepts the tensor.
    """
    # First try the backends that are already registered.
    for registered in _backends.values():
        if registered.is_appropriate_type(tensor):
            return registered

    # Collect all direct and indirect subclasses of AbstractBackend.
    discovered = []
    pending = AbstractBackend.__subclasses__()
    while pending:
        candidate = pending.pop()
        pending += candidate.__subclasses__()
        discovered.append(candidate)

    for BackendSubclass in discovered:
        if _debug_importing:
            print('Testing for subclass of ', BackendSubclass)
        name = BackendSubclass.framework_name
        # Skip backends that are already registered and frameworks whose
        # module has not been imported yet.
        if name in _backends or name not in sys.modules:
            continue
        if _debug_importing:
            print('Imported backend for ', name)
        instance = BackendSubclass()
        _backends[instance.framework_name] = instance
        if instance.is_appropriate_type(tensor):
            return instance

    raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))
开发者ID:arogozhnikov,项目名称:einops,代码行数:33,代码来源:_backends.py

示例14: is_appropriate_type

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def is_appropriate_type(self, tensor):
        """Return whether `tensor` is an instance of this backend's `np.ndarray`."""
        array_type = self.np.ndarray
        return isinstance(tensor, array_type)
开发者ID:arogozhnikov,项目名称:einops,代码行数:4,代码来源:_backends.py

示例15: _hashable

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def _hashable(x):
    """Convert array arguments to bytes; every other value passes through unchanged."""
    # When the arguments are JITed, ShapedArray is hashable.
    if isinstance(x, Tracer):
        return x
    if isinstance(x, DeviceArray):
        # Device arrays are copied before their bytes are taken.
        return x.copy().tobytes()
    if isinstance(x, jnp.ndarray):
        return x.tobytes()
    return x
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:11,代码来源:mcmc.py


注:本文中的jax.numpy.ndarray方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。