当前位置: 首页>>代码示例>>Python>>正文


Python jax.numpy.float32方法代码示例

本文整理汇总了Python中jax.numpy.float32方法的典型用法代码示例。如果您正苦于以下问题:Python numpy.float32方法的具体用法?Python numpy.float32怎么用?Python numpy.float32使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.numpy的用法示例。


在下文中一共展示了numpy.float32方法的11个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_padding_value

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import float32 [as 别名]
def get_padding_value(dtype):
  """Picks the zero used for padding arrays of the given dtype.

  Args:
    dtype: The dtype to look up; compared with `==` against the `np` types.

  Returns:
    `np.uint8(0)` or `np.uint16(0)` for those unsigned types, the Python
    float `0.0` for float32/float64, and the Python int `0` otherwise.
  """
  # Guard clauses: the first matching dtype decides the result.
  if dtype == np.uint8:
    return np.uint8(0)
  if dtype == np.uint16:
    return np.uint16(0)
  if dtype == np.float32 or dtype == np.float64:
    return 0.0
  # Every remaining dtype falls back to a plain integer zero.
  return 0


# TODO(afrozm): Use np.pad instead and make jittable? 
开发者ID:yyht,项目名称:BERT,代码行数:18,代码来源:ppo.py

示例2: _one_hot

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import float32 [as 别名]
def _one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k."""
    return np.array(x[:, None] == np.arange(k), dtype) 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:5,代码来源:mnist_classifier.py

示例3: mnist

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import float32 [as 别名]
def mnist():
    """Load MNIST through tensorflow_datasets as flat images and one-hot labels.

    Returns a 4-tuple `(train_images, train_labels, test_images, test_labels)`.
    Images are cast to float32, divided by 256 and reshaped to rows of 784
    values; labels go through `_one_hot` with 10 classes.
    """
    # Hide GPUs from TensorFlow, see
    # https://github.com/google/jax/blob/master/docs/gpu_memory_allocation.rst
    import tensorflow as tf
    tf.config.experimental.set_visible_devices([], "GPU")

    import tensorflow_datasets as tfds
    splits = tfds.load("mnist:1.0.0")

    def to_pixels(batch):
        scaled = np.float32(batch['image']) / 256
        return np.reshape(scaled, (-1, 784))

    def to_targets(batch):
        return _one_hot(batch['label'], 10)

    # Each split is pulled in as a single batch.
    train_batch = next(tfds.as_numpy(splits['train'].shuffle(50000).batch(50000)))
    test_batch = next(tfds.as_numpy(splits['test'].batch(10000)))
    return (
        to_pixels(train_batch),
        to_targets(train_batch),
        to_pixels(test_batch),
        to_targets(test_batch),
    )
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:14,代码来源:mnist_classifier.py

示例4: random_inputs

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import float32 [as 别名]
def random_inputs(input_shape, key=PRNGKey(0)):
    """Draw uniform random float32 inputs matching `input_shape`.

    Args:
        input_shape: Either a tuple (one array shape) or a list whose
            entries are themselves tuples or lists of shapes.
        key: PRNG key handed to `random.uniform`. Defaults to `PRNGKey(0)`.

    Returns:
        A float32 array for a tuple, or a list with one result per entry
        for a list (nested lists give nested results).

    Raises:
        TypeError: If `input_shape` is neither a tuple nor a list.
    """
    if type(input_shape) is tuple:
        return random.uniform(key, input_shape, np.float32)
    if type(input_shape) is list:
        if not input_shape:
            return []
        # The recursive call used to pass (key, shape) in the wrong order,
        # so this branch always raised TypeError. Each entry also gets its
        # own subkey so equally-shaped inputs are not identical.
        subkeys = random.split(key, len(input_shape))
        return [random_inputs(shape, subkey)
                for shape, subkey in zip(input_shape, subkeys)]
    raise TypeError(type(input_shape))
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:9,代码来源:util.py

示例5: float32

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import float32 [as 别名]
def float32(self):
        """Return the `float32` type exposed by the `np` module alias."""
        single_precision = np.float32
        return single_precision
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例6: __init__

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import float32 [as 别名]
def __init__(self, **kwargs):
        """Set up the backend name, precision and dtype map.

        Reads the optional `precision` keyword (default `'64b'`). The value
        `'64b'` selects 64-bit float/int types and enables `jax_enable_x64`;
        any other value selects the 32-bit types and disables it.
        """
        self.name = 'jax'
        self.precision = kwargs.get('precision', '64b')
        use_64bit = self.precision == '64b'
        if use_64bit:
            float_type, int_type = np.float64, np.int64
        else:
            float_type, int_type = np.float32, np.int32
        self.dtypemap = {
            'float': float_type,
            'int': int_type,
            'bool': np.bool_,
        }
        # Keep JAX's global x64 switch in sync with the chosen precision.
        config.update('jax_enable_x64', use_64bit)
开发者ID:scikit-hep,项目名称:pyhf,代码行数:11,代码来源:jax_backend.py

示例7: default_agent

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import float32 [as 别名]
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Creates an actor-critic agent with default hyperparameters.

  Args:
    obs_spec: Observation spec, forwarded unchanged to `ActorCriticRNN`.
    action_spec: Discrete action spec; its `num_values` sets the output
      width of the policy head and it is also forwarded to the agent.
    seed: Seed for the `hk.PRNGSequence` handed to the agent.

  Returns:
    An `ActorCriticRNN` configured with the fixed hyperparameters below.
  """

  hidden_size = 256
  # Zero-initialised float32 LSTM state with a leading dimension of 1.
  initial_rnn_state = hk.LSTMState(
      hidden=jnp.zeros((1, hidden_size), dtype=jnp.float32),
      cell=jnp.zeros((1, hidden_size), dtype=jnp.float32))

  def network(inputs: jnp.ndarray,
              state) -> Tuple[Tuple[Logits, Value], LSTMState]:
    """Maps (inputs, recurrent state) to ((logits, value), new state)."""
    # NOTE(review): Haiku presumably derives module names from construction
    # order, so keep these constructors in this order -- confirm before
    # reordering.
    flat_inputs = hk.Flatten()(inputs)
    torso = hk.nets.MLP([hidden_size, hidden_size])
    lstm = hk.LSTM(hidden_size)
    policy_head = hk.Linear(action_spec.num_values)
    value_head = hk.Linear(1)

    # Shared embedding: MLP torso followed by the LSTM core.
    embedding = torso(flat_inputs)
    embedding, state = lstm(embedding, state)
    logits = policy_head(embedding)
    value = value_head(embedding)
    # The value head has width 1; drop that trailing axis.
    return (logits, jnp.squeeze(value, axis=-1)), state

  return ActorCriticRNN(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      initial_rnn_state=initial_rnn_state,
      optimizer=optix.adam(3e-3),
      rng=hk.PRNGSequence(seed),
      sequence_length=32,
      discount=0.99,
      td_lambda=0.9,
  )
开发者ID:deepmind,项目名称:bsuite,代码行数:37,代码来源:agent.py

示例8: test_mnist_data_load

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import float32 [as 别名]
def test_mnist_data_load():
    """Check that the mean pixel value over the MNIST train batches is < 0.15."""
    init, fetch = load_dataset(MNIST, batch_size=128, split='train')
    num_batches, idx = init()

    def add_batch_mean(i, running_sum):
        # Only the first element of the fetched pair is used.
        images, _ = fetch(i, idx)
        return running_sum + jnp.sum(images) / images.size

    summed_means = fori_loop(0, num_batches, add_batch_mean, jnp.float32(0.))
    assert summed_means / num_batches < 0.15
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:10,代码来源:test_example_utils.py

示例9: one_hot

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import float32 [as 别名]
def one_hot(x, k, dtype=np.float32):
  """One-hot encodes the labels in `x` into `k` columns of type `dtype`."""
  class_ids = np.arange(k)
  # Comparing a column of labels with the class ids gives one row per label.
  is_hit = x[:, None] == class_ids
  return np.array(is_hit, dtype)
开发者ID:tensorflow,项目名称:cleverhans,代码行数:5,代码来源:utils.py

示例10: update

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import float32 [as 别名]
def update(
      self,
      timestep: dm_env.TimeStep,
      action: base.Action,
      new_timestep: dm_env.TimeStep,
  ):
    """Update the agent: add transition to replay and periodically do SGD.

    Args:
      timestep: The step before the action; only its `observation` is stored.
      action: The action taken, stored as-is in the transition.
      new_timestep: The resulting step; its `reward`, `discount` and
        `observation` are stored, and `last()` triggers a new active head.
    """

    # Thompson sampling: every episode pick a new Q-network as the policy.
    if new_timestep.last():
      k = np.random.randint(self._num_ensemble)
      self._active_head = self._ensemble[k]

    # Generate bootstrapping mask & reward noise.
    # One 0/1 draw (binomial with n=1) and one `randn` draw per ensemble member.
    mask = np.random.binomial(1, self._mask_prob, self._num_ensemble)
    noise = np.random.randn(self._num_ensemble)

    # Make transition and add to replay.
    transition = [
        timestep.observation,
        action,
        np.float32(new_timestep.reward),
        np.float32(new_timestep.discount),
        new_timestep.observation,
        mask,
        noise,
    ]
    self._replay.add(transition)

    # Until the replay holds `_min_replay_size` items, return early; this
    # skips both the SGD step and the target-parameter update below.
    if self._replay.size < self._min_replay_size:
      return

    # Periodically sample from replay and do SGD for the whole ensemble.
    if self._total_steps % self._sgd_period == 0:
      transitions = self._replay.sample(self._batch_size)
      o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
      for k, state in enumerate(self._ensemble):
        # Member k trains with column k of the sampled mask and noise.
        transitions = [o_tm1, a_tm1, r_t, d_t, o_t, m_t[:, k], z_t[:, k]]
        self._ensemble[k] = self._sgd_step(state, transitions)

    # Periodically update target parameters.
    for k, state in enumerate(self._ensemble):
      if state.step % self._target_update_period == 0:
        self._ensemble[k] = state._replace(target_params=state.params)
开发者ID:deepmind,项目名称:bsuite,代码行数:46,代码来源:agent.py

示例11: __init__

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import float32 [as 别名]
def __init__(
      self,
      obs_spec: specs.Array,
      action_spec: specs.DiscreteArray,
      network: PolicyValueNet,
      optimizer: optix.InitUpdate,
      rng: hk.PRNGSequence,
      sequence_length: int,
      discount: float,
      td_lambda: float,
  ):
    """Builds the loss, the jitted SGD step and the initial training state.

    Args:
      obs_spec: Observation spec; its `shape` sizes the dummy observation
        used for parameter initialisation, and it is passed to the buffer.
      action_spec: Action spec, passed to the sequence buffer.
      network: Callable returning `(logits, values)` for observations; it is
        wrapped with `hk.transform`.
      optimizer: Object whose `init` and `update` are used for the optimiser
        state and the gradient updates.
      rng: Key sequence; one key is drawn for initialisation and the
        sequence is kept on `self._rng`.
      sequence_length: Passed to `sequence.Buffer`.
      discount: Multiplies `trajectory.discounts` inside the loss.
      td_lambda: Passed as `lambda_` to `rlax.td_lambda`.
    """

    # Define loss function.
    def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
      """Actor-critic loss."""
      logits, values = network(trajectory.observations)
      # TD(lambda) errors: `values[:-1]` are the predictions and `values[1:]`
      # the bootstrap targets.
      td_errors = rlax.td_lambda(
          v_tm1=values[:-1],
          r_t=trajectory.rewards,
          discount_t=trajectory.discounts * discount,
          v_t=values[1:],
          lambda_=jnp.array(td_lambda),
      )
      critic_loss = jnp.mean(td_errors**2)
      # The TD errors double as advantages, with all-ones weights.
      actor_loss = rlax.policy_gradient_loss(
          logits_t=logits[:-1],
          a_t=trajectory.actions,
          adv_t=td_errors,
          w_t=jnp.ones_like(td_errors))

      return actor_loss + critic_loss

    # Transform the loss into a pure function.
    loss_fn = hk.transform(loss).apply

    # Define update function.
    @jax.jit
    def sgd_step(state: TrainingState,
                 trajectory: sequence.Trajectory) -> TrainingState:
      """Does a step of SGD over a trajectory."""
      gradients = jax.grad(loss_fn)(state.params, trajectory)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optix.apply_updates(state.params, updates)
      return TrainingState(params=new_params, opt_state=new_opt_state)

    # Initialize network parameters and optimiser state.
    init, forward = hk.transform(network)
    # A float32 zero observation with a leading dimension of 1 is enough to
    # trace the network for initialisation.
    dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
    initial_params = init(next(rng), dummy_observation)
    initial_opt_state = optimizer.init(initial_params)

    # Internalize state.
    self._state = TrainingState(initial_params, initial_opt_state)
    self._forward = jax.jit(forward)
    self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
    self._sgd_step = sgd_step
    self._rng = rng
开发者ID:deepmind,项目名称:bsuite,代码行数:59,代码来源:agent.py


注:本文中的jax.numpy.float32方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。