This article collects typical usage examples of the jax.numpy.float32 method in Python. If you have been wondering what numpy.float32 does, how to use it, or what real usage looks like, the curated examples below may help. You can also read further about the containing module, jax.numpy.
A total of 11 code examples of the numpy.float32 method are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
Example 1: get_padding_value

# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import float32 [as alias]
def get_padding_value(dtype):
  """Returns the padding value given a dtype."""
  padding_value = None
  if dtype == np.uint8:
    padding_value = np.uint8(0)
  elif dtype == np.uint16:
    padding_value = np.uint16(0)
  elif dtype == np.float32 or dtype == np.float64:
    padding_value = 0.0
  else:
    padding_value = 0
  assert padding_value is not None
  return padding_value

# TODO(afrozm): Use np.pad instead and make jittable?
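
As a quick check of the branch logic, here is a minimal usage sketch; it only assumes the get_padding_value definition above and jax.numpy imported as np:

import jax.numpy as np

print(get_padding_value(np.uint8))    # uint8 zero
print(get_padding_value(np.uint16))   # uint16 zero
print(get_padding_value(np.float32))  # 0.0 (same branch as np.float64)
print(get_padding_value(np.int32))    # any other dtype falls through to 0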
Example 2: _one_hot

# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import float32 [as alias]
def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)
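
A minimal usage sketch, assuming jax.numpy is imported as np to match the alias used above:

import jax.numpy as np

labels = np.array([0, 2, 1])
print(_one_hot(labels, 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
print(_one_hot(labels, 3).dtype)  # float32, the default dtype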
Example 3: mnist

# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import float32 [as alias]
def mnist():
  # https://github.com/google/jax/blob/master/docs/gpu_memory_allocation.rst
  import tensorflow as tf
  tf.config.experimental.set_visible_devices([], "GPU")

  import tensorflow_datasets as tfds
  dataset = tfds.load("mnist:1.0.0")
  images = lambda d: np.reshape(np.float32(d['image']) / 256, (-1, 784))
  labels = lambda d: _one_hot(d['label'], 10)
  train = next(tfds.as_numpy(dataset['train'].shuffle(50000).batch(50000)))
  test = next(tfds.as_numpy(dataset['test'].batch(10000)))
  return images(train), labels(train), images(test), labels(test)
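
The float32 cast is the interesting part here: raw uint8 pixels are converted and rescaled into [0, 1) before reshaping. The sketch below reproduces the same transforms on a synthetic batch, so it runs without tensorflow_datasets or the MNIST download; it assumes jax.numpy as np and the _one_hot helper from Example 2, and the fake batch is invented for illustration:

import jax.numpy as np

fake_batch = {
    'image': np.zeros((32, 28, 28, 1), dtype=np.uint8),  # stand-in for a TFDS batch
    'label': np.arange(32) % 10,
}
images = np.reshape(np.float32(fake_batch['image']) / 256, (-1, 784))
labels = _one_hot(fake_batch['label'], 10)
print(images.shape, images.dtype)  # (32, 784) float32
print(labels.shape)                # (32, 10)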
Example 4: random_inputs

# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import float32 [as alias]
def random_inputs(input_shape, key=PRNGKey(0)):
  if type(input_shape) is tuple:
    return random.uniform(key, input_shape, np.float32)
  elif type(input_shape) is list:
    # Recurse on each shape in the list, reusing the same key.
    return [random_inputs(shape, key) for shape in input_shape]
  else:
    raise TypeError(type(input_shape))
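
A short usage sketch, assuming the imports the snippet implies (jax.numpy as np, from jax import random, from jax.random import PRNGKey):

import jax.numpy as np
from jax import random
from jax.random import PRNGKey

x = random_inputs((2, 3))                       # single float32 array of shape (2, 3)
xs = random_inputs([(2, 3), (5,)], PRNGKey(1))  # list of arrays, one per shape
print(x.shape, x.dtype)                         # (2, 3) float32
print([a.shape for a in xs])                    # [(2, 3), (5,)]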
Example 5: float32

# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import float32 [as alias]
def float32(self):
  return np.float32
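
This is a one-line accessor, presumably defined on a backend or dtype-registry class. A minimal sketch of how such a method might be used; the BackendDtypes class name is hypothetical, invented for illustration:

import jax.numpy as np

class BackendDtypes:
  """Hypothetical container exposing backend dtypes through methods."""
  def float32(self):
    return np.float32

dtypes = BackendDtypes()
x = np.zeros(4, dtype=dtypes.float32())
print(x.dtype)  # float32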
Example 6: __init__

# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import float32 [as alias]
def __init__(self, **kwargs):
  self.name = 'jax'
  self.precision = kwargs.get('precision', '64b')
  self.dtypemap = {
      'float': np.float64 if self.precision == '64b' else np.float32,
      'int': np.int64 if self.precision == '64b' else np.int32,
      'bool': np.bool_,
  }
  config.update('jax_enable_x64', self.precision == '64b')
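
This constructor maps abstract dtype names to concrete jax.numpy types and toggles 64-bit mode to match. A standalone sketch of the same idea, assuming the flag object comes from "from jax.config import config" (older JAX versions expose it this way; newer versions use jax.config.update directly):

import jax.numpy as np
from jax.config import config

precision = '32b'  # the constructor above defaults to '64b'
config.update('jax_enable_x64', precision == '64b')

dtypemap = {
    'float': np.float64 if precision == '64b' else np.float32,
    'int': np.int64 if precision == '64b' else np.int32,
    'bool': np.bool_,
}
print(np.zeros(3, dtype=dtypemap['float']).dtype)  # float32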
Example 7: default_agent

# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import float32 [as alias]
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Creates an actor-critic agent with default hyperparameters."""
  hidden_size = 256
  initial_rnn_state = hk.LSTMState(
      hidden=jnp.zeros((1, hidden_size), dtype=jnp.float32),
      cell=jnp.zeros((1, hidden_size), dtype=jnp.float32))

  def network(inputs: jnp.ndarray,
              state) -> Tuple[Tuple[Logits, Value], LSTMState]:
    flat_inputs = hk.Flatten()(inputs)
    torso = hk.nets.MLP([hidden_size, hidden_size])
    lstm = hk.LSTM(hidden_size)
    policy_head = hk.Linear(action_spec.num_values)
    value_head = hk.Linear(1)

    embedding = torso(flat_inputs)
    embedding, state = lstm(embedding, state)
    logits = policy_head(embedding)
    value = value_head(embedding)
    return (logits, jnp.squeeze(value, axis=-1)), state

  return ActorCriticRNN(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      initial_rnn_state=initial_rnn_state,
      optimizer=optix.adam(3e-3),
      rng=hk.PRNGSequence(seed),
      sequence_length=32,
      discount=0.99,
      td_lambda=0.9,
  )
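
A brief usage sketch for this factory; it assumes the surrounding bsuite actor-critic module (ActorCriticRNN, optix, hk) plus dm_env are importable, and the concrete spec shapes below are invented for illustration:

import numpy as np
from dm_env import specs

obs_spec = specs.Array(shape=(4,), dtype=np.float32, name='observation')
action_spec = specs.DiscreteArray(num_values=2, name='action')

agent = default_agent(obs_spec, action_spec, seed=0)
# The returned agent follows the bsuite base.Agent interface:
# agent.select_action(timestep) and agent.update(timestep, action, new_timestep).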
Example 8: test_mnist_data_load

# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import float32 [as alias]
def test_mnist_data_load():
  def mean_pixels(i, mean_pix):
    batch, _ = fetch(i, idx)
    return mean_pix + jnp.sum(batch) / batch.size

  init, fetch = load_dataset(MNIST, batch_size=128, split='train')
  num_batches, idx = init()
  assert fori_loop(0, num_batches, mean_pixels, jnp.float32(0.)) / num_batches < 0.15
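
The pattern worth noting is passing jnp.float32(0.) as the fori_loop carry so the accumulator has a fixed dtype. The sketch below reproduces that pattern on synthetic batches, so it runs without numpyro or the MNIST download; the fake batches are invented for illustration:

import jax.numpy as jnp
from jax.lax import fori_loop

batches = jnp.linspace(0., 1., 10 * 128).reshape(10, 128)  # 10 fake "pixel" batches

def mean_pixels(i, mean_pix):
  batch = batches[i]
  return mean_pix + jnp.sum(batch) / batch.size

num_batches = batches.shape[0]
mean = fori_loop(0, num_batches, mean_pixels, jnp.float32(0.)) / num_batches
print(mean)  # ~0.5 for this synthetic data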
Example 9: one_hot

# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import float32 [as alias]
def one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)
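
This is the same helper as Example 2. For reference, jax ships an equivalent built-in; the following sketch (assuming jax.numpy imported as np) should produce matching float32 output:

import jax
import jax.numpy as np

labels = np.array([0, 2, 1])
assert np.array_equal(one_hot(labels, 3),
                      jax.nn.one_hot(labels, 3, dtype=np.float32))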
Example 10: update

# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import float32 [as alias]
# Note: np here refers to standard NumPy; jax.numpy has no np.random submodule.
def update(
    self,
    timestep: dm_env.TimeStep,
    action: base.Action,
    new_timestep: dm_env.TimeStep,
):
  """Update the agent: add transition to replay and periodically do SGD."""
  # Thompson sampling: every episode pick a new Q-network as the policy.
  if new_timestep.last():
    k = np.random.randint(self._num_ensemble)
    self._active_head = self._ensemble[k]

  # Generate bootstrapping mask & reward noise.
  mask = np.random.binomial(1, self._mask_prob, self._num_ensemble)
  noise = np.random.randn(self._num_ensemble)

  # Make transition and add to replay.
  transition = [
      timestep.observation,
      action,
      np.float32(new_timestep.reward),
      np.float32(new_timestep.discount),
      new_timestep.observation,
      mask,
      noise,
  ]
  self._replay.add(transition)

  if self._replay.size < self._min_replay_size:
    return

  # Periodically sample from replay and do SGD for the whole ensemble.
  if self._total_steps % self._sgd_period == 0:
    transitions = self._replay.sample(self._batch_size)
    o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
    for k, state in enumerate(self._ensemble):
      transitions = [o_tm1, a_tm1, r_t, d_t, o_t, m_t[:, k], z_t[:, k]]
      self._ensemble[k] = self._sgd_step(state, transitions)

  # Periodically update target parameters.
  for k, state in enumerate(self._ensemble):
    if state.step % self._target_update_period == 0:
      self._ensemble[k] = state._replace(target_params=state.params)
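
A small standalone sketch of the bootstrap mask and noise bookkeeping used above, in plain NumPy; the ensemble size, mask probability, and batch size are made up for illustration:

import numpy as np

num_ensemble, mask_prob = 5, 0.5

# One bootstrap mask entry and one reward-noise sample per ensemble member.
mask = np.random.binomial(1, mask_prob, num_ensemble)
noise = np.random.randn(num_ensemble)

# After batching transitions, ensemble member k reads only its own column.
m_t = np.stack([np.random.binomial(1, mask_prob, num_ensemble) for _ in range(4)])
print(m_t.shape, m_t[:, 0].shape)  # (4, 5) and (4,)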
Example 11: __init__

# Required import: from jax import numpy [as alias]
# Or: from jax.numpy import float32 [as alias]
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    network: PolicyValueNet,
    optimizer: optix.InitUpdate,
    rng: hk.PRNGSequence,
    sequence_length: int,
    discount: float,
    td_lambda: float,
):
  # Define loss function.
  def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
    """Actor-critic loss."""
    logits, values = network(trajectory.observations)
    td_errors = rlax.td_lambda(
        v_tm1=values[:-1],
        r_t=trajectory.rewards,
        discount_t=trajectory.discounts * discount,
        v_t=values[1:],
        lambda_=jnp.array(td_lambda),
    )
    critic_loss = jnp.mean(td_errors**2)
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1],
        a_t=trajectory.actions,
        adv_t=td_errors,
        w_t=jnp.ones_like(td_errors))
    return actor_loss + critic_loss

  # Transform the loss into a pure function.
  loss_fn = hk.transform(loss).apply

  # Define update function.
  @jax.jit
  def sgd_step(state: TrainingState,
               trajectory: sequence.Trajectory) -> TrainingState:
    """Does a step of SGD over a trajectory."""
    gradients = jax.grad(loss_fn)(state.params, trajectory)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return TrainingState(params=new_params, opt_state=new_opt_state)

  # Initialize network parameters and optimiser state.
  init, forward = hk.transform(network)
  dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
  initial_params = init(next(rng), dummy_observation)
  initial_opt_state = optimizer.init(initial_params)

  # Internalize state.
  self._state = TrainingState(initial_params, initial_opt_state)
  self._forward = jax.jit(forward)
  self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
  self._sgd_step = sgd_step
  self._rng = rng
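
The jnp.float32 usage here is the dummy observation used to initialize the Haiku parameters, while the loss combines rlax.td_lambda critic targets with a policy-gradient term. A standalone sketch of that loss combination on dummy arrays, assuming rlax is installed; every shape and value below is invented for illustration:

import jax.numpy as jnp
import rlax

T, num_actions = 5, 3
values = jnp.linspace(0., 1., T + 1)                  # V(s_0) ... V(s_T)
rewards = jnp.ones((T,), dtype=jnp.float32)
discounts = jnp.full((T,), 0.99, dtype=jnp.float32)
logits = jnp.zeros((T, num_actions), dtype=jnp.float32)
actions = jnp.zeros((T,), dtype=jnp.int32)

td_errors = rlax.td_lambda(
    v_tm1=values[:-1], r_t=rewards, discount_t=discounts,
    v_t=values[1:], lambda_=jnp.array(0.9))
critic_loss = jnp.mean(td_errors ** 2)
actor_loss = rlax.policy_gradient_loss(
    logits_t=logits, a_t=actions, adv_t=td_errors, w_t=jnp.ones_like(td_errors))
print(critic_loss + actor_loss)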