当前位置: 首页>>代码示例>>Python>>正文


Python random.PRNGKey方法代码示例

本文整理汇总了Python中jax.random.PRNGKey方法的典型用法代码示例。如果您正苦于以下问题:Python random.PRNGKey方法的具体用法?Python random.PRNGKey怎么用?Python random.PRNGKey使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.random的用法示例。


在下文中一共展示了random.PRNGKey方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def __init__(self, obs_dim, *, seed=None):
        """Internal setup for Jax-based reward models.

        Initialises reward model using given seed & input size (`obs_dim`).

        Args:
            obs_dim (int): dimensionality of observation space.
            seed (int or None): random seed for generating initial params. If
                None, a seed is drawn from NumPy's global RNG, as in
                LinearRewardModel.

        Raises:
            AssertionError: if the network's output shape is not just the
                batch dimension ``(-1,)``.
        """
        # TODO: apply jax.jit() to everything in sight
        net_init, self._net_apply = self.make_stax_model()
        if seed is None:
            # Draw from the non-negative int32 range. `np.random.randint` uses
            # the platform's default C long, which is 32 bits on Windows with
            # NumPy < 2, so an upper bound of 2**63 - 1 raises ValueError
            # there. A 32-bit seed is also accepted by `jrandom.PRNGKey`
            # without relying on JAX's 64-bit mode.
            seed = int(np.random.randint(np.iinfo(np.int32).max))
        rng = jrandom.PRNGKey(seed)
        # `net_init` is called as (rng, input_shape) -> (output_shape, params);
        # -1 stands in for the (unknown) batch dimension.
        out_shape, self._net_params = net_init(rng, (-1, obs_dim))
        self._net_grads = jax.grad(self._net_apply)
        # output shape should just be batch dim, nothing else
        assert out_shape == (-1,), "got a weird output shape %s" % (out_shape,)
开发者ID:HumanCompatibleAI,项目名称:imitation,代码行数:23,代码来源:tabular_irl.py

示例2: jax_randint

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def jax_randint(key, shape, minval, maxval, dtype=np.int32):
  """Draw integers uniformly at random from the half-open range [minval, maxval).

  Thin adapter that forwards every argument to ``jax_random.randint``.

  Args:
    key: PRNG key to sample with.
    shape: tuple of non-negative ints giving the output shape.
    minval: lower bound (inclusive); an int or an int array that broadcasts
      against ``shape``.
    maxval: upper bound (exclusive); an int or an int array that broadcasts
      against ``shape``.
    dtype: integer dtype of the result; defaults to int32.

  Returns:
    An array of the requested ``shape`` and ``dtype``.
  """
  bounds = {'minval': minval, 'maxval': maxval}
  return jax_random.randint(key, shape, dtype=dtype, **bounds)
开发者ID:google,项目名称:trax,代码行数:19,代码来源:jax.py

示例3: _flat_reuse_dicts

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def _flat_reuse_dicts(reuse, *example_inputs):
        """Flatten a ``reuse`` mapping into one flat parameter dict.

        Each key of ``reuse`` is either a ``parametrized`` module or a
        ``ShapedParametrized`` wrapper, which carries its own example inputs.
        Each value holds the parameters to reuse for that module. The module
        is initialised on its example inputs with a fixed ``PRNGKey(0)``. The
        resulting example dict is handed to ``parametrized._parameters_dict``
        together with the given parameters, presumably to obtain the dict
        structure. The result is then flattened and merged.

        Raises:
            ValueError: if a key is neither ``parametrized`` nor
                ``ShapedParametrized``.
        """
        flattened = {}

        for reused, reused_params in reuse.items():
            if isinstance(reused, ShapedParametrized):
                shaped = reused
                reused = shaped.parametrized
                module_inputs = shaped.example_inputs
            else:
                module_inputs = example_inputs

            if not isinstance(reused, parametrized):
                raise ValueError('Keys for reuse must be parametrized or ShapedParametrized.')

            example_dict, _ = reused._init_and_apply_parameters_dict(*module_inputs, key=PRNGKey(0))
            filled = parametrized._parameters_dict(reused_params, example_dict)
            flattened.update(reused._flatten_dict(filled))

        return flattened
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:18,代码来源:core.py

示例4: main

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def main(batch_size=32, nr_filters=8, epochs=10, step_size=.001, decay_rate=.999995,
         model_path=Path('./pixelcnn.params')):
    """Train a PixelCNN++ model with Adam, saving its parameters after every epoch.

    Randomness starts from a fixed ``PRNGKey(0)``. A fresh subkey is split off
    for every training step and for every evaluation.

    Args:
        batch_size: batch size passed to ``dataset``.
        nr_filters: number of filters passed to ``PixelCNNPP``.
        epochs: number of passes over the training batches.
        step_size: initial step size given to ``exponential_decay``.
        decay_rate: decay rate given to ``exponential_decay`` (decay_steps=1).
        model_path: where ``save`` writes the parameters after each epoch.
    """
    loss, _ = PixelCNNPP(nr_filters=nr_filters)
    get_train_batches, test_batches = dataset(batch_size)
    rng, init_rng = random.split(PRNGKey(0))
    schedule = exponential_decay(step_size, 1, decay_rate)
    opt = Adam(schedule)
    initial_params = loss.init_parameters(next(test_batches), key=init_rng)
    state = opt.init(initial_params)

    for epoch in range(epochs):
        for batch in get_train_batches():
            rng, step_rng = random.split(rng)
            # Step index of the state *before* this update.
            step = opt.get_step(state)

            state, train_loss = opt.update_and_get_loss(
                loss.apply, state, batch, key=step_rng, jit=True)

            # Report densely during the first steps, then every 100 steps.
            should_report = step % 100 == 0 or step < 10
            if should_report:
                rng, eval_rng = random.split(rng)
                eval_params = opt.get_parameters(state)
                test_loss = loss.apply(eval_params, next(test_batches), key=eval_rng,
                                       jit=True)
                print(f"Epoch {epoch}, iteration {step}, "
                      f"train loss {train_loss:.3f}, "
                      f"test loss {test_loss:.3f} ")

        save(opt.get_parameters(state), model_path)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:27,代码来源:pixelcnn.py

示例5: test_readme

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def test_readme():
    """README example: init an MLP loss with a fixed key, then apply it eagerly and jitted."""
    net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), log_softmax)

    @parametrized
    def loss(inputs, targets):
        return -jnp.mean(net(inputs) * targets)

    def next_batch():
        return jnp.zeros((3, 784)), jnp.zeros((3, 4))

    params = loss.init_parameters(*next_batch(), key=PRNGKey(0))
    final_bias = params.sequential.dense2.bias

    print(final_bias)  # [-0.01101029, -0.00749435, -0.00952365,  0.00493979]

    expected_bias = [-0.01101029, -0.00749435, -0.00952365, 0.00493979]
    assert jnp.allclose(expected_bias, final_bias)

    eager_out = loss.apply(params, *next_batch())
    assert () == eager_out.shape

    jitted_out = loss.apply(params, *next_batch(), jit=True)
    assert eager_out.shape == jitted_out.shape
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:23,代码来源:test_examples.py

示例6: test_Parameter_dense

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def test_Parameter_dense():
    """A Dense layer hand-built from `parameter` initialises correctly shaped params."""
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            in_dim = inputs.shape[-1]
            kernel = parameter((in_dim, out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias

        return dense

    layer = Dense(2)
    x = jnp.zeros((1, 3))
    params = layer.init_parameters(x, key=PRNGKey(0))

    # Parameters are numbered in creation order: kernel first, then bias.
    assert (3, 2) == params.parameter0.shape
    assert (2,) == params.parameter1.shape

    y = layer.apply(params, x, jit=True)
    assert (1, 2) == y.shape
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_examples.py

示例7: test_mnist_classifier

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def test_mnist_classifier():
    """Smoke-test the MNIST classifier example.

    Runs ten jitted Momentum updates within a time budget. It then checks the
    shapes of `accuracy` (scalar) and `predict` (batch x 10), using parameters
    transferred from the trained `loss`.
    """
    from examples.mnist_classifier import predict, loss, accuracy

    next_batch = lambda: (jnp.zeros((3, 784)), jnp.zeros((3, 10)))
    opt = optimizers.Momentum(0.001, mass=0.9)
    state = opt.init(loss.init_parameters(*next_batch(), key=PRNGKey(0)))

    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock, which can jump (NTP sync, manual changes) and make this
    # budget check flaky.
    # NOTE: the measured span includes JIT compilation of the first update.
    start = time.perf_counter()
    for _ in range(10):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)

    elapsed = time.perf_counter() - start
    assert 5 > elapsed

    params = opt.get_parameters(state)
    train_acc = accuracy.apply_from({loss: params}, *next_batch(), jit=True)
    assert () == train_acc.shape

    predict_params = predict.parameters_from({loss.shaped(*next_batch()): params}, next_batch()[0])
    predictions = predict.apply(predict_params, next_batch()[0], jit=True)
    assert (3, 10) == predictions.shape
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:23,代码来源:test_examples.py

示例8: test_mnist_vae

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def test_mnist_vae():
    """A small VAE ELBO (encoder + decoder) initialises with the expected kernel shape."""
    @parametrized
    def encode(input):
        hidden = Sequential(Dense(5), relu, Dense(5), relu)(input)
        mean = Dense(10)(hidden)
        variance = Sequential(Dense(10), softplus)(hidden)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        z = gaussian_sample(key, mu_z, sigmasq_z)
        logits_x = decode(z)
        reconstruction = bernoulli_logpdf(logits_x, images)
        return reconstruction - gaussian_kl(mu_z, sigmasq_z)

    images = jnp.zeros((32, 5 * 5))
    # The same PRNGKey(0) is used both as the sampling key argument and as the
    # init key; only shapes are checked below.
    params = elbo.init_parameters(PRNGKey(0), images, key=PRNGKey(0))
    assert (5, 10) == params.encode.sequential1.dense.kernel.shape
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_examples.py

示例9: test_submodule_order

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def test_submodule_order():
    """Parameter order is fixed by each submodule's first call, not by later reuse.

    `p` is called first and called again at the end. The test checks that
    its value is still stored as `parameter0`.
    """
    @parametrized
    def net():
        p = Parameter(lambda key: jnp.zeros((1,)))
        a = p()
        b = parameter((2,), zeros)
        c = parameter((3,), zeros)
        d = parameter((4,), zeros)
        e = parameter((5,), zeros)
        f = parameter((6,), zeros)

        # must not mess up order (decided by first submodule call):
        k = p()

        # Each concatenation has length 7 (1+6, 2+5, 3+4); `k` has shape (1,)
        # and broadcasts over them.
        return jnp.concatenate([a, f]) + jnp.concatenate([b, e]) + jnp.concatenate([c, d]) + k

    params = net.init_parameters(key=PRNGKey(0))

    # `p` was the first submodule called, so its zeros land in parameter0.
    assert jnp.zeros((1,)) == params.parameter0
    out = net.apply(params)
    assert (7,) == out.shape 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:23,代码来源:test_core.py

示例10: test_external_submodule

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def test_external_submodule():
    """A layer created outside the parametrized function is used as its submodule."""
    layer = Dense(3)

    @parametrized
    def net(inputs):
        return 2 * layer(inputs)

    x = random_inputs((2,))
    params = net.init_parameters(x, key=PRNGKey(0))

    eager = net.apply(params, x)
    assert eager.shape == (3,)

    # Re-applying eagerly must give exactly the same result.
    repeat = net.apply(params, x)
    assert jnp.array_equal(eager, repeat)

    # The jitted path is compared with allclose, presumably to tolerate float rounding.
    jitted = net.apply(params, x, jit=True)
    assert jnp.allclose(eager, jitted)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:19,代码来源:test_core.py

示例11: test_inline_submodule

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def test_inline_submodule():
    """A layer created inside the parametrized function body is used as its submodule."""
    @parametrized
    def net(inputs):
        layer = Dense(3)
        return 2 * layer(inputs)

    x = random_inputs((2,))
    params = net.init_parameters(x, key=PRNGKey(0))

    eager = net.apply(params, x)
    assert eager.shape == (3,)

    # Re-applying eagerly must give exactly the same result.
    repeat = net.apply(params, x)
    assert jnp.array_equal(eager, repeat)

    # The jitted path is compared with allclose, presumably to tolerate float rounding.
    jitted = net.apply(params, x, jit=True)
    assert jnp.allclose(eager, jitted)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:18,代码来源:test_core.py

示例12: test_external_submodule2

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def test_external_submodule2():
    """An external Dense layer with zero initialisers yields zero params and zero outputs."""
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    x = jnp.zeros((1, 2))

    params = net.init_parameters(x, key=PRNGKey(0))
    expected_params = ((jnp.zeros((2, 2)), jnp.zeros(2)),)
    assert_parameters_equal(expected_params, params)

    eager = net.apply(params, x)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager)

    jitted = net.apply(params, x, jit=True)
    assert jnp.array_equal(eager, jitted)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:19,代码来源:test_core.py

示例13: test_submodule_reuse

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def test_submodule_reuse():
    """A layer shared by two networks can have its parameters injected and swapped.

    `reuse=` injects them at init time, and `parameters_from` swaps them in
    afterwards.
    """
    inputs = jnp.zeros((1, 2))

    shared = Dense(5)
    net1 = Sequential(shared, Dense(2))
    net2 = Sequential(shared, Dense(3))

    shared_params = shared.init_parameters(inputs, key=PRNGKey(0))
    net1_params = net1.init_parameters(inputs, key=PRNGKey(1), reuse={shared: shared_params})
    net2_params = net2.init_parameters(inputs, key=PRNGKey(2), reuse={shared: shared_params})

    assert net1.apply(net1_params, inputs).shape == (1, 2)
    assert net2.apply(net2_params, inputs).shape == (1, 3)

    # Both networks must carry the injected parameters for their first layer.
    for net_params in (net1_params, net2_params):
        assert_dense_parameters_equal(shared_params, net_params[0])

    # Swapping in freshly initialised shared-layer params must replace only that layer.
    replacement = shared.init_parameters(inputs, key=PRNGKey(3))
    combined = net1.parameters_from({net1: net1_params, shared: replacement}, inputs)
    assert_dense_parameters_equal(replacement, combined.dense0)
    assert_dense_parameters_equal(net1_params.dense1, combined.dense1)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:26,代码来源:test_core.py

示例14: test_scan_parametrized_cell_without_params

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def test_scan_parametrized_cell_without_params():
    """lax.scan over a parameter-free parametrized cell gives empty params and stacked outputs."""
    @parametrized
    def cell(carry, x):
        doubled = jnp.array([2]) * carry * x
        # The same value serves as the new carry and the per-step output.
        return doubled, doubled

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, jnp.zeros((2,)), inputs)
        return outs

    sequence = jnp.zeros((3,))

    params = rnn.init_parameters(sequence, key=PRNGKey(0))
    assert_parameters_equal(((),), params)

    # One (2,)-shaped output per scanned element.
    stacked = rnn.apply(params, sequence)
    assert (3, 2) == stacked.shape
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_core.py

示例15: test_scan_parametrized_cell

# 需要导入模块: from jax import random [as 别名]
# 或者: from jax.random import PRNGKey [as 别名]
def test_scan_parametrized_cell():
    """lax.scan over a cell with one (2,)-shaped parameter exposes it as `cell.parameter`."""
    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        scaled = scale * jnp.array([2]) * carry * x
        # The same value serves as the new carry and the per-step output.
        return scaled, scaled

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, jnp.zeros((2,)), inputs)
        return outs

    sequence = jnp.zeros((3,))

    rnn_params = rnn.init_parameters(sequence, key=PRNGKey(0))
    assert (2,) == rnn_params.cell.parameter.shape

    # One (2,)-shaped output per scanned element.
    stacked = rnn.apply(rnn_params, sequence)
    assert (3, 2) == stacked.shape
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:test_core.py


注:本文中的jax.random.PRNGKey方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。