当前位置: 首页>>代码示例>>Python>>正文


Python lax.stop_gradient方法代码示例

本文整理汇总了Python中jax.lax.stop_gradient方法的典型用法代码示例。如果您正苦于以下问题：Python lax.stop_gradient方法的具体用法是什么？Python lax.stop_gradient怎么用？有哪些Python lax.stop_gradient的使用例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.lax的用法示例。


在下文中一共展示了lax.stop_gradient方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _detach

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import stop_gradient [as 别名]
def _detach(x):
    """Return ``x`` passed through ``lax.stop_gradient``.

    The value is forwarded as-is; ``stop_gradient`` marks it so that
    differentiation treats it as a constant.

    :param x: value (array or pytree accepted by ``lax.stop_gradient``).
    :returns: the result of ``lax.stop_gradient(x)``.
    """
    detached = lax.stop_gradient(x)
    return detached
开发者ID:pyro-ppl,项目名称:funsor,代码行数:4,代码来源:ops.py

示例2: default_agent

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import stop_gradient [as 别名]
def default_agent(
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    seed: int = 0,
    num_ensemble: int = 20,
) -> BootstrappedDqn:
  """Build a Bootstrapped DQN agent configured with the default settings.

  Args:
    obs_spec: observation spec, forwarded unchanged to the agent.
    action_spec: discrete action spec; its ``num_values`` sets the width of
      the network's output layer.
    seed: forwarded unchanged to the agent.
    num_ensemble: forwarded unchanged to the agent.

  Returns:
    A ``BootstrappedDqn`` instance built from the hyperparameters below.
  """

  # Network hyperparameters.
  prior_scale = 3.
  hidden_sizes = [50, 50]

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Q-network: a trainable MLP plus a scaled, gradient-blocked prior MLP."""
    # Both MLPs share the same layout; the trainable one is created first.
    trainable_net = hk.nets.MLP(hidden_sizes + [action_spec.num_values])
    frozen_prior = hk.nets.MLP(hidden_sizes + [action_spec.num_values])
    flat_inputs = hk.Flatten()(inputs)
    q_values = trainable_net(flat_inputs)
    # stop_gradient keeps the prior's contribution out of backpropagation.
    prior_values = lax.stop_gradient(frozen_prior(flat_inputs))
    return q_values + prior_scale * prior_values

  adam_optimizer = optix.adam(learning_rate=1e-3)
  agent = BootstrappedDqn(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      optimizer=adam_optimizer,
      seed=seed,
      num_ensemble=num_ensemble,
      batch_size=128,
      min_replay_size=128,
      replay_capacity=10000,
      discount=.99,
      sgd_period=1,
      target_update_period=4,
      mask_prob=0.5,
      noise_scale=0.,
      # Constant schedule: always returns 0 regardless of its argument.
      epsilon_fn=lambda _: 0.,
  )
  return agent
开发者ID:deepmind,项目名称:bsuite,代码行数:39,代码来源:agent.py

示例3: loss

# 需要导入模块: from jax.lax import stop_gradient [as 别名]
# 注意: 本示例直接调用 stop_gradient(...)；若改用 from jax import lax 导入，则需写成 lax.stop_gradient(...)
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
    """
    Compute the negated Renyi ELBO estimate using ``self.num_particles``
    samples/particles.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param dict param_map: dictionary of current parameter values keyed by site
        name.
    :param model: Python callable with NumPyro primitives for the model.
    :param guide: Python callable with NumPyro primitives for the guide.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :returns: negative of the Renyi Evidence Lower Bound (ELBO) to be minimized.
    """
    def single_particle_elbo(particle_key):
        key_for_model, key_for_guide = random.split(particle_key)
        model_with_rng = seed(model, key_for_model)
        guide_with_rng = seed(guide, key_for_guide)
        log_q, guide_trace = log_density(guide_with_rng, args, kwargs, param_map)
        # Only hand the model those params that guide_trace does not already
        # provide; everything else comes from replaying the guide trace.
        params_for_model = {
            name: value
            for name, value in param_map.items()
            if name not in guide_trace
        }
        replayed_model = replay(model_with_rng, guide_trace)
        log_p, _ = log_density(replayed_model, args, kwargs, params_for_model)

        # log p(z) - log q(z)
        return log_p - log_q

    particle_keys = random.split(rng_key, self.num_particles)
    elbos = vmap(single_particle_elbo)(particle_keys)

    one_minus_alpha = 1. - self.alpha
    scaled_elbos = one_minus_alpha * elbos
    # log of the mean of exp(scaled_elbos), computed via logsumexp.
    avg_log_exp = logsumexp(scaled_elbos) - jnp.log(self.num_particles)
    weights = jnp.exp(scaled_elbos - avg_log_exp)
    renyi_elbo = avg_log_exp / one_minus_alpha

    # The weights are treated as constants, so gradients flow only through
    # ``elbos`` in this term.
    weighted_elbo = jnp.dot(stop_gradient(weights), elbos) / self.num_particles
    # Numerically the sum below equals ``renyi_elbo``; the gradient-blocked
    # correction means only ``weighted_elbo`` contributes to differentiation.
    correction = stop_gradient(renyi_elbo - weighted_elbo)
    return -(correction + weighted_elbo)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:39,代码来源:elbo.py

示例4: logmatmulexp

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import stop_gradient [as 别名]
def logmatmulexp(x, y):
    """
    Numerically stable version of ``log(exp(x) @ exp(y))``.

    The row-wise maximum of ``x`` (last axis) and the column-wise maximum of
    ``y`` (second-to-last axis) are subtracted before exponentiating and added
    back afterwards, so the intermediate ``exp`` calls do not overflow.

    :param x: left operand in log space; contracted over its last axis.
    :param y: right operand in log space; contracted over its second-to-last
        axis.
    :returns: ``log(exp(x) @ exp(y))`` with the shape ``jnp.matmul`` produces
        for ``x`` and ``y``.
    """
    x_max = jnp.amax(x, -1, keepdims=True)
    y_max = jnp.amax(y, -2, keepdims=True)
    # A non-finite shift (e.g. a row that is entirely -inf) would make
    # ``x - shift`` evaluate ``inf - inf = nan``; fall back to a zero shift in
    # that case so an all -inf row yields -inf instead of nan.
    x_shift = lax.stop_gradient(jnp.where(jnp.isfinite(x_max), x_max, 0))
    y_shift = lax.stop_gradient(jnp.where(jnp.isfinite(y_max), y_max, 0))
    xy = jnp.log(jnp.matmul(jnp.exp(x - x_shift), jnp.exp(y - y_shift)))
    return xy + x_shift + y_shift
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:10,代码来源:util.py

示例5: _clamp_preserve_gradients

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import stop_gradient [as 别名]
def _clamp_preserve_gradients(x, min, max):
    """Clamp ``x`` to ``[min, max]`` while keeping the gradient of ``x`` itself.

    The returned value equals ``jnp.clip(x, min, max)``, but the clipping
    offset is wrapped in ``lax.stop_gradient``, so differentiation sees the
    result as ``x`` plus a constant.

    :param x: array to clamp.
    :param min: lower bound passed to ``jnp.clip``.
    :param max: upper bound passed to ``jnp.clip``.
    :returns: ``x + stop_gradient(clip(x, min, max) - x)``.
    """
    # Bounds are passed positionally: the ``a_min``/``a_max`` keywords are
    # deprecated in newer JAX releases, while the positional form is accepted
    # by both the old and the new ``jnp.clip`` signatures.
    clipped = jnp.clip(x, min, max)
    return x + lax.stop_gradient(clipped - x)


# adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/transforms/iaf.py 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:7,代码来源:flows.py

示例6: run

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import stop_gradient [as 别名]
def run(bsuite_id: str) -> str:
  """Run a Bootstrapped DQN agent on one bsuite environment, logging results.

  Args:
    bsuite_id: identifier handed to ``bsuite.load_and_record``.

  Returns:
    The same ``bsuite_id`` that was passed in.
  """

  # Load the environment; recording options are taken from command-line flags.
  env = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )
  action_spec = env.action_spec()

  # Network hyperparameters.
  prior_scale = 3.
  hidden_sizes = [64, 64]

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Q-network: a trainable MLP plus a scaled, gradient-blocked prior MLP."""
    # Both MLPs share the same layout; the trainable one is created first.
    trainable_net = hk.nets.MLP(hidden_sizes + [action_spec.num_values])
    frozen_prior = hk.nets.MLP(hidden_sizes + [action_spec.num_values])
    flat_inputs = hk.Flatten()(inputs)
    q_values = trainable_net(flat_inputs)
    # stop_gradient keeps the prior's contribution out of backpropagation.
    prior_values = lax.stop_gradient(frozen_prior(flat_inputs))
    return q_values + prior_scale * prior_values

  adam_optimizer = optix.adam(learning_rate=1e-3)

  agent = boot_dqn.BootstrappedDqn(
      obs_spec=env.observation_spec(),
      action_spec=action_spec,
      network=network,
      optimizer=adam_optimizer,
      num_ensemble=FLAGS.num_ensemble,
      batch_size=128,
      min_replay_size=128,
      replay_capacity=10000,
      discount=.99,
      sgd_period=1,
      target_update_period=4,
      mask_prob=0.5,
      noise_scale=0.,
  )

  # A falsy flag value falls back to the episode count stored on the env.
  num_episodes = FLAGS.num_episodes
  if not num_episodes:
    num_episodes = getattr(env, 'bsuite_num_episodes')

  experiment.run(
      agent=agent,
      environment=env,
      num_episodes=num_episodes,
      verbose=FLAGS.verbose)

  return bsuite_id
开发者ID:deepmind,项目名称:bsuite,代码行数:49,代码来源:run.py


注：本文中的jax.lax.stop_gradient方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。