本文整理汇总了Python中jax.numpy.ndarray类型的典型用法代码示例。如果您正苦于以下问题:jax.numpy.ndarray的具体用法?jax.numpy.ndarray怎么用?jax.numpy.ndarray使用的例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类型所在模块jax.numpy的用法示例。
在下文中一共展示了jax.numpy.ndarray的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: deltas
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def deltas(predicted_values, rewards, mask, gamma=0.99):
    r"""Computes one-step TD-residuals from V(s) and rewards.

    A `delta`, i.e. a TD-residual, is defined as:
      delta_{b,t} = r_{b,t} + \gamma * v_{b,t+1} - v_{b,t}.

    Args:
      predicted_values: ndarray of shape (B, T+1). NOTE: Expects axis 2 was
        squeezed. These represent V(s_bt) for b < B and t < T+1.
      rewards: ndarray of shape (B, T) of rewards.
      mask: ndarray of shape (B, T) of mask for rewards.
      gamma: float, discount factor.

    Returns:
      ndarray of shape (B, T) of one-step TD-residuals, multiplied
      element-wise by `mask`.
    """
    # V(s_t): drop the last time-step so the shape becomes (B, T).
    values_now = predicted_values[:, :-1]
    # V(s_{t+1}): drop the first time-step so the shape becomes (B, T).
    values_next = predicted_values[:, 1:]
    # The residual as defined in the docstring; the mask is applied last so
    # that masked positions contribute exactly zero.
    td_residuals = rewards + (gamma * values_next) - values_now
    return td_residuals * mask
示例2: gae_advantages
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def gae_advantages(td_deltas, mask, lambda_=0.95, gamma=0.99):
    r"""Computes the GAE advantages given the one-step TD-residuals.

    The formula for a GAE advantage estimator is as follows:
      A_{bt} = \sum_{l=0}^{\infty}(\gamma * \lambda)^{l}(\delta_{b,t+l}).

    Internally we just call `rewards_to_go`, since it is the same computation
    with `lambda_ * gamma` passed as its third argument.

    Args:
      td_deltas: np.ndarray of shape (B, T) of one-step TD-residuals.
      mask: np.ndarray of shape (B, T) of mask for the residuals. It may be the
        case that the `td_deltas` are already masked correctly since they are
        produced by `deltas(...)`.
      lambda_: float, lambda parameter for GAE estimators.
      gamma: float, discount factor.

    Returns:
      GAE advantage estimates, as returned by `rewards_to_go`.
    """
    return rewards_to_go(td_deltas, mask, lambda_ * gamma)
示例3: chosen_probabs
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def chosen_probabs(probab_observations, actions):
    """Picks out the log-probabilities of the chosen actions.

    Args:
      probab_observations: ndarray of shape `[B, T+1, A]`, where
        probab_observations[b, t, i] contains the log-probability of action = i
        at the t^th time-step in the b^th trajectory.
      actions: ndarray of shape `[B, T]`, with each entry in [0, A) denoting
        which action was chosen in the b^th trajectory's t^th time-step.

    Returns:
      `[B, T]` ndarray with the log-probabilities of the chosen actions. Only
      the first T time-steps of `probab_observations` are read.
    """
    B, T = actions.shape  # pylint: disable=invalid-name
    assert (B, T + 1) == probab_observations.shape[:2]
    # Integer-array indexing: a (B, 1) batch index, a (T,) time index and the
    # (B, T) action index broadcast together to a (B, T) result.
    batch_index = np.arange(B)[:, None]
    time_index = np.arange(T)
    return probab_observations[batch_index, time_index, actions]
示例4: assert_parameters_equal
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def assert_parameters_equal(p, p_):
    """Recursively asserts that two parameter structures are equal.

    Args:
      p: an `np.ndarray`, or a tuple/list/dict whose elements are themselves
        such structures.
      p_: the structure to compare against `p`.

    Raises:
      AssertionError: if the container kinds, lengths, dict keys or any of the
        array leaves differ, or if `p` is neither an array nor a
        tuple/list/dict.
    """
    # Leaf case: arrays are compared element-wise.
    if isinstance(p, np.ndarray):
        assert np.array_equal(p, p_)
        return

    assert isinstance(p, (tuple, list, dict))
    # Both sides must be the same kind of container.
    for container_type in (tuple, list, dict):
        assert isinstance(p, container_type) == isinstance(p_, container_type)
    assert len(p) == len(p_)

    if isinstance(p, dict):
        for key, value in p.items():
            # Report differing key sets as an assertion failure rather than
            # letting the lookup below raise KeyError.
            assert key in p_
            assert_parameters_equal(value, p_[key])
    else:
        for value, other in zip(p, p_):
            assert_parameters_equal(value, other)
示例5: default_agent
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
    """Initialize a DQN agent with default parameters.

    Args:
      obs_spec: specification of the observations, forwarded to `DQN`.
      action_spec: specification of the discrete actions; its `num_values`
        sets the width of the network's output layer.
      seed: integer used to seed the `hk.PRNGSequence` handed to the agent.

    Returns:
      A `DQN` instance built with the hyperparameters listed below.
    """

    def network(inputs: jnp.ndarray) -> jnp.ndarray:
        """Flattens the inputs and maps them to one value per action."""
        flat_inputs = hk.Flatten()(inputs)
        # Two hidden layers of 64 units, then one output per discrete action.
        mlp = hk.nets.MLP([64, 64, action_spec.num_values])
        return mlp(flat_inputs)

    return DQN(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=network,
        optimizer=optix.adam(1e-3),
        batch_size=32,
        discount=0.99,
        replay_capacity=10000,
        min_replay_size=100,
        sgd_period=1,
        target_update_period=4,
        epsilon=0.05,
        rng=hk.PRNGSequence(seed),
    )
示例6: default_agent
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
    """Creates an actor-critic agent with default hyperparameters.

    Args:
      obs_spec: specification of the observations, forwarded to `ActorCritic`.
      action_spec: specification of the discrete actions; its `num_values`
        sets the width of the policy head.
      seed: integer used to seed the `hk.PRNGSequence` handed to the agent.

    Returns:
      An `ActorCritic` instance built with the hyperparameters listed below.
    """

    def network(inputs: jnp.ndarray) -> Tuple[Logits, Value]:
        """Shared MLP torso feeding a policy head and a scalar value head."""
        flat_inputs = hk.Flatten()(inputs)
        # Modules are constructed in the same order as before: torso, policy
        # head, value head.
        torso = hk.nets.MLP([64, 64])
        policy_head = hk.Linear(action_spec.num_values)
        value_head = hk.Linear(1)

        embedding = torso(flat_inputs)
        logits = policy_head(embedding)
        value = value_head(embedding)
        # The value head has a single output unit; drop that trailing axis.
        return logits, jnp.squeeze(value, axis=-1)

    return ActorCritic(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=network,
        optimizer=optix.adam(3e-3),
        rng=hk.PRNGSequence(seed),
        sequence_length=32,
        discount=0.99,
        td_lambda=0.9,
    )
示例7: make_dataset
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Make a simulated dataset where potential customers who get a sales call
    have a roughly 2 percentage point higher chance of making another purchase
    (purchase probability 0.084 with a call versus 0.061 without).

    :param rng_key: random key that is split three ways to draw the samples.
    :return: tuple ``(design_matrix, made_purchase)``. ``design_matrix`` has
        three columns (intercept, got-called indicator, is-female indicator);
        ``made_purchase`` holds the sampled outcomes, called customers first.
    """
    key1, key2, key3 = random.split(rng_key, 3)

    num_calls = 51342
    num_no_calls = 48658
    num_customers = num_calls + num_no_calls

    # Outcomes are drawn separately per group and stacked called-first, which
    # matches the row order of `got_called` below.
    made_purchase_got_called = dist.Bernoulli(0.084).sample(key1, sample_shape=(num_calls,))
    made_purchase_no_calls = dist.Bernoulli(0.061).sample(key2, sample_shape=(num_no_calls,))
    made_purchase = jnp.concatenate([made_purchase_got_called, made_purchase_no_calls])

    # Gender is sampled independently of the call assignment.
    is_female = dist.Bernoulli(0.5).sample(key3, sample_shape=(num_customers,))
    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])

    design_matrix = jnp.hstack([jnp.ones((num_customers, 1)),
                                got_called.reshape(-1, 1),
                                is_female.reshape(-1, 1)])
    return design_matrix, made_purchase
示例8: model
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def model(design_matrix: jnp.ndarray, outcome: jnp.ndarray = None) -> None:
    """
    Model definition: Log odds of making a purchase is a linear combination
    of covariates. Specify a Normal prior over regression coefficients.

    :param design_matrix: Covariates. All categorical variables have been one-hot
        encoded.
    :param outcome: Binary response variable. In this case, whether or not the
        customer made a purchase. Defaults to ``None``, in which case no
        observations are passed to the ``'obs'`` sample site.
    """
    # One coefficient per design-matrix column, with an identity covariance.
    beta = numpyro.sample('coefficients', dist.MultivariateNormal(
        loc=0.,
        covariance_matrix=jnp.eye(design_matrix.shape[1])))
    logits = design_matrix.dot(beta)

    # One Bernoulli observation site per row of the design matrix.
    with numpyro.plate('data', design_matrix.shape[0]):
        numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=outcome)
示例9: print_results
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def print_results(coef: jnp.ndarray, interval_size: float = 0.95) -> None:
    """
    Print the highest posterior density interval (HPDI) for the effect sizes
    with ``interval_size`` probability mass.

    :param coef: 2-D array of coefficient samples, read column-wise:
        column 0 is used as the baseline term, column 1 as the effect of
        getting called and column 2 as the effect of gender.
    :param interval_size: probability mass of the reported intervals.
    """
    baseline_response = expit(coef[:, 0])
    response_with_calls = expit(coef[:, 0] + coef[:, 1])

    impact_on_probability = hpdi(response_with_calls - baseline_response, prob=interval_size)
    effect_of_gender = hpdi(coef[:, 2], prob=interval_size)

    print(f"There is a {interval_size * 100}% probability that calling customers "
          "increases the chance they'll make a purchase by "
          f"{(100 * impact_on_probability[0]):.2} to {(100 * impact_on_probability[1]):.2} percentage points."
          )

    gender_lower, gender_upper = effect_of_gender[0], effect_of_gender[1]
    # Only claim "no impact" when the interval really straddles zero; the
    # conclusion used to be printed regardless of the computed interval.
    if gender_lower <= 0 <= gender_upper:
        conclusion = " Since this interval contains 0, we can conclude gender does not impact the conversion rate."
    else:
        conclusion = " Since this interval does not contain 0, gender appears to impact the conversion rate."

    # Both endpoints use the same fixed-point format (previously `.2` vs `.2f`).
    print(f"There is a {interval_size * 100}% probability the effect of gender on the log odds of conversion "
          f"lies in the interval ({gender_lower:.2f}, {gender_upper:.2f})."
          f"{conclusion}")
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def run_inference(design_matrix: jnp.ndarray, outcome: jnp.ndarray,
                  rng_key: jnp.ndarray,
                  num_warmup: int,
                  num_samples: int, num_chains: int,
                  interval_size: float = 0.95) -> None:
    """
    Estimate the effect size by sampling `model` with NUTS and printing the
    resulting intervals.

    :param design_matrix: covariates passed through to `model`.
    :param outcome: observed responses passed through to `model`.
    :param rng_key: random key handed to ``mcmc.run``.
    :param num_warmup: number of warmup steps.
    :param num_samples: number of samples to draw.
    :param num_chains: number of MCMC chains.
    :param interval_size: probability mass of the intervals that are printed.
    """
    kernel = NUTS(model)
    # Sizes are passed by keyword so the call does not depend on the
    # positional order of the MCMC constructor's parameters.
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                num_chains=num_chains,
                # The progress bar is turned off when NUMPYRO_SPHINXBUILD is set.
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, design_matrix, outcome)

    # 0th column is intercept (not getting called)
    # 1st column is effect of getting called
    # 2nd column is effect of gender (should be none since assigned at random)
    coef = mcmc.get_samples()['coefficients']
    print_results(coef, interval_size)
示例11: compute_probab_ratios
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def compute_probab_ratios(p_new, p_old, actions, reward_mask):
    """Computes the probability ratios for each time-step in a trajectory.

    Args:
      p_new: ndarray of shape [B, T+1, A] of the log-probabilities that the
        policy network assigns to all the actions at each time-step in each
        batch using the new parameters.
      p_old: ndarray of shape [B, T+1, A], same as above, but using old policy
        network parameters.
      actions: ndarray of shape [B, T] where each element is from [0, A).
      reward_mask: ndarray of shape [B, T] masking over probabilities.

    Returns:
      probab_ratios: ndarray of shape [B, T], where
      probab_ratios_{b,t} = p_new_{b,t,action_{b,t}} / p_old_{b,t,action_{b,t}},
      multiplied element-wise by `reward_mask`.
    """
    B, T = actions.shape  # pylint: disable=invalid-name
    assert (B, T + 1) == p_old.shape[:2]
    assert (B, T + 1) == p_new.shape[:2]

    # Log-probabilities of the actions that were actually taken.
    logp_old = chosen_probabs(p_old, actions)
    logp_new = chosen_probabs(p_new, actions)
    assert (B, T) == logp_old.shape
    assert (B, T) == logp_new.shape

    # Since these are log-probabilities, the ratio is the exponential of their
    # difference.
    probab_ratios = np.exp(logp_new - logp_old) * reward_mask
    assert (B, T) == probab_ratios.shape
    return probab_ratios
示例12: default_agent
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def default_agent(
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    seed: int = 0,
    num_ensemble: int = 20,
) -> BootstrappedDqn:
    """Initialize a Bootstrapped DQN agent with default parameters.

    Args:
      obs_spec: specification of the observations, forwarded to the agent.
      action_spec: specification of the discrete actions; its `num_values`
        sets the width of both networks' output layers.
      seed: seed forwarded to `BootstrappedDqn`.
      num_ensemble: ensemble size forwarded to `BootstrappedDqn`.

    Returns:
      A `BootstrappedDqn` instance built with the hyperparameters listed below.
    """
    # Define network.
    prior_scale = 3.
    hidden_sizes = [50, 50]

    def network(inputs: jnp.ndarray) -> jnp.ndarray:
        """Simple Q-network with randomized prior function."""
        # Modules are constructed in the same order as before: trainable
        # network first, then the prior network.
        net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        x = hk.Flatten()(inputs)
        # No gradient flows through the prior network's output.
        return net(x) + prior_scale * lax.stop_gradient(prior_net(x))

    optimizer = optix.adam(learning_rate=1e-3)
    return BootstrappedDqn(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=network,
        batch_size=128,
        discount=.99,
        num_ensemble=num_ensemble,
        replay_capacity=10000,
        min_replay_size=128,
        sgd_period=1,
        target_update_period=4,
        optimizer=optimizer,
        mask_prob=0.5,
        noise_scale=0.,
        epsilon_fn=lambda _: 0.,
        seed=seed,
    )
示例13: get_backend
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def get_backend(tensor) -> 'AbstractBackend':
    """
    Returns the backend that matches the tensor's type (e.g. the numpy backend
    if tensor is numpy.ndarray).

    Already-created backends in `_backends` are tried first. Otherwise every
    direct or indirect subclass of `AbstractBackend` whose framework is already
    present in `sys.modules` is instantiated, cached in `_backends` and tried.

    Raises:
        RuntimeError: if no backend accepts the tensor's type.
    """
    # Fast path: a backend that was created earlier.
    for backend in _backends.values():
        if backend.is_appropriate_type(tensor):
            return backend

    # Find backend subclasses recursively
    backend_subclasses = []
    backends = AbstractBackend.__subclasses__()
    while backends:
        backend = backends.pop()
        backends += backend.__subclasses__()
        backend_subclasses.append(backend)

    for BackendSubclass in backend_subclasses:
        if _debug_importing:
            print('Testing for subclass of ', BackendSubclass)
        framework = BackendSubclass.framework_name
        # Skip backends that are already cached (they were tried above) and
        # frameworks that are not already in sys.modules.
        if framework in _backends or framework not in sys.modules:
            continue
        if _debug_importing:
            print('Imported backend for ', BackendSubclass.framework_name)
        backend = BackendSubclass()
        _backends[backend.framework_name] = backend
        if backend.is_appropriate_type(tensor):
            return backend

    raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))
示例14: is_appropriate_type
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def is_appropriate_type(self, tensor):
    """Returns True if `tensor` is an instance of this backend's `ndarray`.

    The array type is looked up as `self.np.ndarray`.
    """
    return isinstance(tensor, self.np.ndarray)
示例15: _hashable
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import ndarray [as 别名]
def _hashable(x):
    """Returns a stand-in for `x` that is intended to be usable as a hash key.

    Tracers are returned unchanged, arrays are converted to their raw bytes,
    and every other value is returned as-is.
    """
    # When the arguments are JITed, ShapedArray is hashable.
    if isinstance(x, Tracer):
        return x
    # DeviceArray is checked before jnp.ndarray and is copied before being
    # serialized.
    # NOTE(review): the reason for the extra copy is not visible here - confirm.
    if isinstance(x, DeviceArray):
        return x.copy().tobytes()
    if isinstance(x, jnp.ndarray):
        # NOTE(review): raw bytes carry no shape or dtype, so arrays that
        # differ only in shape/dtype presumably map to the same key - confirm
        # this is acceptable for callers.
        return x.tobytes()
    return x