當前位置: 首頁>>代碼示例>>Python>>正文


Python lax.stop_gradient方法代碼示例

本文整理匯總了Python中jax.lax.stop_gradient方法的典型用法代碼示例。如果您正苦於以下問題:Python lax.stop_gradient方法的具體用法?Python lax.stop_gradient怎麼用?Python lax.stop_gradient使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊jax.lax的用法示例。


在下文中一共展示了lax.stop_gradient方法的6個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: _detach

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import stop_gradient [as 別名]
def _detach(x):
    """Return ``x`` with the same value but excluded from differentiation.

    This is a thin wrapper over ``lax.stop_gradient``. Gradients propagated
    back through the returned value are zero.
    """
    detached = lax.stop_gradient(x)
    return detached
開發者ID:pyro-ppl,項目名稱:funsor,代碼行數:4,代碼來源:ops.py

示例2: default_agent

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import stop_gradient [as 別名]
def default_agent(
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    seed: int = 0,
    num_ensemble: int = 20,
) -> BootstrappedDqn:
  """Build a Bootstrapped DQN agent with this module's default settings.

  Args:
    obs_spec: observation spec, forwarded unchanged to the agent.
    action_spec: discrete action spec. Its ``num_values`` sets the width of
      the Q-network's output layer.
    seed: seed forwarded to the agent.
    num_ensemble: ensemble size forwarded to the agent.

  Returns:
    A ``BootstrappedDqn`` built with an Adam optimizer (learning rate 1e-3)
    and a Q-network made of a trainable MLP plus a scaled prior MLP whose
    output is wrapped in ``lax.stop_gradient``.
  """
  hidden_sizes = [50, 50]
  prior_scale = 3.0

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Q-values: trainable MLP output plus a gradient-blocked, scaled prior."""
    # Construction order (q_mlp, then prior_mlp) fixes Haiku module names.
    q_mlp = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    prior_mlp = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    flat = hk.Flatten()(inputs)
    # Call order (q_mlp first) matches parameter-initialisation order.
    q_values = q_mlp(flat)
    prior_values = lax.stop_gradient(prior_mlp(flat))
    return q_values + prior_scale * prior_values

  optimizer = optix.adam(learning_rate=1e-3)
  agent_config = dict(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      batch_size=128,
      discount=.99,
      num_ensemble=num_ensemble,
      replay_capacity=10000,
      min_replay_size=128,
      sgd_period=1,
      target_update_period=4,
      optimizer=optimizer,
      mask_prob=0.5,
      noise_scale=0.,
      epsilon_fn=lambda _: 0.,
      seed=seed,
  )
  return BootstrappedDqn(**agent_config)
開發者ID:deepmind,項目名稱:bsuite,代碼行數:39,代碼來源:agent.py

示例3: loss

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import stop_gradient [as 別名]
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
        """
        Evaluates the Renyi ELBO with an estimator that uses num_particles many samples/particles.

        The returned scalar has the value ``-renyi_elbo``. Its gradient is the
        gradient of ``-weighted_elbo``, a surrogate in which the per-particle
        weights are held constant via ``stop_gradient``. Reads
        ``self.num_particles`` and ``self.alpha``.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param dict param_map: dictionary of current parameter values keyed by site
            name.
        :param model: Python callable with NumPyro primitives for the model.
        :param guide: Python callable with NumPyro primitives for the guide.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :returns: negative of the Renyi Evidence Lower Bound (ELBO) to be minimized.
        """
        def single_particle_elbo(rng_key):
            # Separate seeds for the model and the guide, split from this
            # particle's key.
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
            # NB: we only want to substitute params not available in guide_trace
            model_param_map = {k: v for k, v in param_map.items() if k not in guide_trace}
            # Run the model against the guide's trace so both densities are
            # evaluated at the same sampled values.
            seeded_model = replay(seeded_model, guide_trace)
            model_log_density, _ = log_density(seeded_model, args, kwargs, model_param_map)

            # log p(z) - log q(z)
            elbo = model_log_density - guide_log_density
            return elbo

        # One key per particle; vmap evaluates all particles in one call.
        rng_keys = random.split(rng_key, self.num_particles)
        elbos = vmap(single_particle_elbo)(rng_keys)
        scaled_elbos = (1. - self.alpha) * elbos
        # log(mean(exp(scaled_elbos))), computed stably through logsumexp.
        avg_log_exp = logsumexp(scaled_elbos) - jnp.log(self.num_particles)
        # These weights sum to num_particles, i.e.
        # weights / num_particles == softmax(scaled_elbos).
        weights = jnp.exp(scaled_elbos - avg_log_exp)
        # NOTE(review): alpha == 1 would divide by zero here; presumably
        # rejected elsewhere (e.g. in the constructor) -- confirm.
        renyi_elbo = avg_log_exp / (1. - self.alpha)
        # Surrogate with the weights frozen. Its gradient equals the gradient
        # of renyi_elbo, since d(renyi_elbo) = sum_k softmax_k * d(elbo_k).
        weighted_elbo = jnp.dot(stop_gradient(weights), elbos) / self.num_particles
        # Value: -renyi_elbo (the stopped difference cancels weighted_elbo).
        # Gradient: -grad(weighted_elbo) (the stopped term contributes nothing).
        return - (stop_gradient(renyi_elbo - weighted_elbo) + weighted_elbo)
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:39,代碼來源:elbo.py

示例4: logmatmulexp

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import stop_gradient [as 別名]
def logmatmulexp(x, y):
    """
    Numerically stable version of ``(x.exp() @ y.exp()).log()``.

    Before exponentiating, the maximum of ``x`` over its last axis and the
    maximum of ``y`` over its second-to-last axis are subtracted. They are
    added back afterwards, so large magnitudes do not overflow or underflow.
    Leading batch dimensions broadcast as in ``jnp.matmul``.

    :param x: log-space array of shape ``(..., n, k)``.
    :param y: log-space array of shape ``(..., k, p)``.
    :returns: log-space array of shape ``(..., n, p)``. Rows of ``x`` or
        columns of ``y`` that are entirely ``-inf`` yield ``-inf`` (not NaN).
    """
    x_max = jnp.amax(x, -1, keepdims=True)
    y_max = jnp.amax(y, -2, keepdims=True)
    # An all -inf row/column has a -inf max. Subtracting it would give
    # (-inf) - (-inf) = NaN, so shift by 0 there instead, as
    # jax.nn.logsumexp does. The shifts cancel exactly in value, so their
    # gradient is stopped rather than differentiated through amax.
    x_shift = lax.stop_gradient(jnp.where(jnp.isfinite(x_max), x_max, 0.))
    y_shift = lax.stop_gradient(jnp.where(jnp.isfinite(y_max), y_max, 0.))
    xy = jnp.log(jnp.matmul(jnp.exp(x - x_shift), jnp.exp(y - y_shift)))
    return xy + x_shift + y_shift
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:10,代碼來源:util.py

示例5: _clamp_preserve_gradients

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import stop_gradient [as 別名]
def _clamp_preserve_gradients(x, min, max):
    """Clamp ``x`` to ``[min, max]`` in value while passing gradients straight through.

    The result equals ``jnp.clip(x, min, max)``. The clipping correction is
    wrapped in ``lax.stop_gradient``, so the derivative of the result w.r.t.
    ``x`` is the identity everywhere (a straight-through estimator), even
    where ``x`` lies outside the bounds.

    :param x: array to clamp.
    :param min: lower bound (the name shadows the builtin; kept for callers).
    :param max: upper bound (the name shadows the builtin; kept for callers).
    :returns: clamped array with identity gradient w.r.t. ``x``.
    """
    # Bounds are passed positionally. The ``a_min``/``a_max`` keywords are
    # deprecated in recent JAX, and positional bounds work with both the old
    # and new ``jnp.clip`` signatures.
    clipped = jnp.clip(x, min, max)
    return x + lax.stop_gradient(clipped - x)


# adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/transforms/iaf.py 
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:7,代碼來源:flows.py

示例6: run

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import stop_gradient [as 別名]
def run(bsuite_id: str) -> str:
  """Train a Bootstrapped DQN agent on a single bsuite environment.

  The environment is loaded with ``bsuite.load_and_record`` using the
  ``save_path``, ``logging_mode`` and ``overwrite`` flags, so results are
  recorded as those flags dictate (CSV presumably when ``logging_mode`` is
  'csv' -- confirm).

  Args:
    bsuite_id: identifier of the bsuite environment to load.

  Returns:
    The same ``bsuite_id``, after the experiment has finished.
  """
  env = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )
  action_spec = env.action_spec()

  hidden_sizes = [64, 64]
  prior_scale = 3.0

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Q-values: trainable MLP output plus a gradient-blocked, scaled prior."""
    # Construction order (q_mlp, then prior_mlp) fixes Haiku module names.
    q_mlp = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    prior_mlp = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    flat = hk.Flatten()(inputs)
    # Call order (q_mlp first) matches parameter-initialisation order.
    q_values = q_mlp(flat)
    prior_values = lax.stop_gradient(prior_mlp(flat))
    return q_values + prior_scale * prior_values

  optimizer = optix.adam(learning_rate=1e-3)

  agent = boot_dqn.BootstrappedDqn(
      obs_spec=env.observation_spec(),
      action_spec=action_spec,
      network=network,
      optimizer=optimizer,
      num_ensemble=FLAGS.num_ensemble,
      batch_size=128,
      discount=.99,
      replay_capacity=10000,
      min_replay_size=128,
      sgd_period=1,
      target_update_period=4,
      mask_prob=0.5,
      noise_scale=0.,
  )

  # Fall back to the environment's recommended episode count when the flag
  # is unset or falsy.
  num_episodes = FLAGS.num_episodes
  if not num_episodes:
    num_episodes = env.bsuite_num_episodes

  experiment.run(
      agent=agent,
      environment=env,
      num_episodes=num_episodes,
      verbose=FLAGS.verbose)

  return bsuite_id
開發者ID:deepmind,項目名稱:bsuite,代碼行數:49,代碼來源:run.py


注:本文中的jax.lax.stop_gradient方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。