本文整理汇总了Python中jax.lax.stop_gradient方法的典型用法代码示例。如果您正苦于以下问题:Python lax.stop_gradient方法的具体用法?Python lax.stop_gradient怎么用?Python lax.stop_gradient使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类jax.lax
的用法示例。
在下文中一共展示了lax.stop_gradient方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _detach
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import stop_gradient [as 别名]
def _detach(x):
return lax.stop_gradient(x)
示例2: default_agent
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import stop_gradient [as 别名]
def default_agent(
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    seed: int = 0,
    num_ensemble: int = 20,
) -> BootstrappedDqn:
  """Initialize a Bootstrapped DQN agent with default parameters."""
  # Randomized-prior scaling and MLP layout shared by every ensemble member.
  prior_scale = 3.
  hidden_sizes = [50, 50]

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Simple Q-network with randomized prior function."""
    trainable_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    flat = hk.Flatten()(inputs)
    # stop_gradient freezes the prior network at its random initialization,
    # so only `trainable_net` receives learning updates.
    return trainable_net(flat) + prior_scale * lax.stop_gradient(prior_net(flat))

  return BootstrappedDqn(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      batch_size=128,
      discount=.99,
      num_ensemble=num_ensemble,
      replay_capacity=10000,
      min_replay_size=128,
      sgd_period=1,
      target_update_period=4,
      optimizer=optix.adam(learning_rate=1e-3),
      mask_prob=0.5,
      noise_scale=0.,
      epsilon_fn=lambda _: 0.,
      seed=seed,
  )
示例3: loss
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import stop_gradient [as 别名]
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
    """
    Evaluates the Renyi ELBO with an estimator that uses num_particles many
    samples/particles.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param dict param_map: dictionary of current parameter values keyed by site
        name.
    :param model: Python callable with NumPyro primitives for the model.
    :param guide: Python callable with NumPyro primitives for the guide.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :returns: negative of the Renyi Evidence Lower Bound (ELBO) to be minimized.
    """

    def particle_elbo(key):
        # Independent randomness for model and guide within this particle.
        model_key, guide_key = random.split(key)
        guide_log_density, guide_trace = log_density(
            seed(guide, guide_key), args, kwargs, param_map)
        # NB: we only want to substitute params not available in guide_trace
        leftover_params = {name: value for name, value in param_map.items()
                           if name not in guide_trace}
        replayed_model = replay(seed(model, model_key), guide_trace)
        model_log_density, _ = log_density(
            replayed_model, args, kwargs, leftover_params)
        # Single-particle ELBO: log p(z) - log q(z).
        return model_log_density - guide_log_density

    particle_keys = random.split(rng_key, self.num_particles)
    elbos = vmap(particle_elbo)(particle_keys)
    scaled = (1. - self.alpha) * elbos
    # log of the mean of exp(scaled) across particles.
    log_mean_exp = logsumexp(scaled) - jnp.log(self.num_particles)
    importance_weights = jnp.exp(scaled - log_mean_exp)
    renyi_elbo = log_mean_exp / (1. - self.alpha)
    weighted_elbo = jnp.dot(stop_gradient(importance_weights), elbos) / self.num_particles
    # Surrogate loss: its value equals -renyi_elbo while its gradient equals
    # that of -weighted_elbo (the stop_gradient term contributes value only).
    return - (stop_gradient(renyi_elbo - weighted_elbo) + weighted_elbo)
示例4: logmatmulexp
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import stop_gradient [as 别名]
def logmatmulexp(x, y):
    """
    Numerically stable version of ``(x.log() @ y.log()).exp()``.
    """
    # Shift each row of x and each column of y by its max before
    # exponentiating; stop_gradient keeps the shifts out of the backward pass.
    row_max = lax.stop_gradient(jnp.amax(x, -1, keepdims=True))
    col_max = lax.stop_gradient(jnp.amax(y, -2, keepdims=True))
    stable_product = jnp.matmul(jnp.exp(x - row_max), jnp.exp(y - col_max))
    # Adding the shifts back recovers log(exp(x) @ exp(y)).
    return jnp.log(stable_product) + row_max + col_max
示例5: _clamp_preserve_gradients
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import stop_gradient [as 别名]
def _clamp_preserve_gradients(x, min, max):
return x + lax.stop_gradient(jnp.clip(x, a_min=min, a_max=max) - x)
# adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/transforms/iaf.py
示例6: run
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import stop_gradient [as 别名]
def run(bsuite_id: str) -> str:
  """Runs a DQN agent on a given bsuite environment, logging to CSV."""
  env = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )
  action_spec = env.action_spec()

  # Network: trainable MLP plus a frozen randomized prior.
  prior_scale = 3.
  hidden_sizes = [64, 64]

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Simple Q-network with randomized prior function."""
    q_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    flat = hk.Flatten()(inputs)
    # The prior network never trains: stop_gradient blocks its updates.
    return q_net(flat) + prior_scale * lax.stop_gradient(prior_net(flat))

  agent = boot_dqn.BootstrappedDqn(
      obs_spec=env.observation_spec(),
      action_spec=action_spec,
      network=network,
      optimizer=optix.adam(learning_rate=1e-3),
      num_ensemble=FLAGS.num_ensemble,
      batch_size=128,
      discount=.99,
      replay_capacity=10000,
      min_replay_size=128,
      sgd_period=1,
      target_update_period=4,
      mask_prob=0.5,
      noise_scale=0.,
  )

  # Fall back to the environment's suggested episode budget when the flag
  # is unset (0 / None is falsy).
  episode_budget = FLAGS.num_episodes or getattr(env, 'bsuite_num_episodes')
  experiment.run(
      agent=agent,
      environment=env,
      num_episodes=episode_budget,
      verbose=FLAGS.verbose)
  return bsuite_id