当前位置: 首页>>代码示例>>Python>>正文


Python numpy.arange方法代码示例

本文整理汇总了 Python 中 jax.numpy.arange 方法的典型用法代码示例。如果您正苦于以下问题：jax.numpy.arange 方法的具体用法？jax.numpy.arange 怎么用？jax.numpy.arange 有哪些使用例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 jax.numpy 的用法示例。


在下文中一共展示了 jax.numpy.arange 方法的 15 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: chosen_probabs

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def chosen_probabs(probab_observations, actions):
  """Gathers the log-probability of each chosen action.

  Args:
    probab_observations: ndarray of shape `[B, T+1, A]`; entry `[b, t, i]` is
      the log-probability of action `i` at time-step `t` of trajectory `b`.
    actions: integer ndarray of shape `[B, T]` whose entries lie in `[0, A)`
      and name the action taken at each time-step of each trajectory.

  Returns:
    ndarray of shape `[B, T]` holding, for every `(b, t)`, the value
    `probab_observations[b, t, actions[b, t]]`.
  """
  batch_size, num_steps = actions.shape
  assert probab_observations.shape[:2] == (batch_size, num_steps + 1)
  # A column of batch indices broadcast against a row of time indices selects
  # every (b, t) pair; the last time-step of the observations is never read.
  batch_index = np.arange(batch_size)[:, None]
  time_index = np.arange(num_steps)
  return probab_observations[batch_index, time_index, actions]
开发者ID:yyht,项目名称:BERT,代码行数:18,代码来源:ppo.py

示例2: get_data

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
    """Build a synthetic 1-d regression data set on polynomial features.

    Inputs are `N` evenly spaced points in [-1, 1], expanded to the powers
    0 .. D_X - 1. Targets are a random linear function of those features plus
    a fixed non-linear term and Gaussian noise of scale `sigma_obs`, then
    centred and scaled to unit standard deviation. The NumPy global RNG is
    re-seeded with 0 on every call, so the result is deterministic.

    :param int N: number of training points
    :param int D_X: number of polynomial features (the code reads feature
        column 1, so at least 2 are required)
    :param float sigma_obs: scale of the observation noise
    :param int N_test: number of test points, spaced evenly in [-1.3, 1.3]
    :return: tuple `(X, Y, X_test)` with shapes `(N, D_X)`, `(N, 1)` and
        `(N_test, D_X)`
    """
    D_Y = 1  # outputs are one-dimensional
    np.random.seed(0)

    def _poly_features(grid):
        # column j holds grid ** j
        return jnp.power(grid[:, np.newaxis], jnp.arange(D_X))

    X = _poly_features(jnp.linspace(-1, 1, N))
    # keep the draw order of the global RNG: weights first, then noise
    W = 0.5 * np.random.randn(D_X)
    noise = sigma_obs * np.random.randn(N)

    linear_feature = X[:, 1]
    Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + linear_feature, 2.0) * jnp.sin(4.0 * linear_feature)
    Y = (Y + noise)[:, np.newaxis]
    Y = Y - jnp.mean(Y)
    Y = Y / jnp.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    X_test = _poly_features(jnp.linspace(-1.3, 1.3, N_test))

    return X, Y, X_test
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:bnn.py

示例3: model

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def model(N, y=None):
    """
    Probabilistic model of two populations evolving under ``dz_dt``.

    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations with shape (N, 2)
    """
    # prior over the two initial populations
    init_prior = dist.LogNormal(jnp.log(10), 1)
    z_init = numpyro.sample("z_init", init_prior, sample_shape=(2,))
    # measurement times 0., 1., ..., N - 1
    ts = jnp.arange(float(N))
    # parameters alpha, beta, gamma, delta of dz_dt, truncated below at 0
    theta_loc = jnp.array([0.5, 0.05, 1.5, 0.05])
    theta_scale = jnp.array([0.5, 0.05, 0.5, 0.05])
    theta_prior = dist.TruncatedNormal(low=0., loc=theta_loc, scale=theta_scale)
    theta = numpyro.sample("theta", theta_prior)
    # integrate dz/dt, the result will have shape N x 2
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)
    # measurement errors, we expect that measured hare has larger error than measured lynx
    sigma_prior = dist.Exponential(jnp.array([1, 2]))
    sigma = numpyro.sample("sigma", sigma_prior)
    # measured populations (in log scale)
    likelihood = dist.Normal(jnp.log(z), sigma)
    numpyro.sample("y", likelihood, obs=y)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:22,代码来源:ode.py

示例4: _multinomial

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def _multinomial(key, p, n, n_max, shape=()):
    """Draw multinomial count vectors by tallying categorical draws.

    ``n_max`` categorical indices are drawn per batch element and counted per
    category. When ``n`` is an array, draws beyond each element's own ``n``
    are masked and their contribution is removed again at the end.

    :param key: random key forwarded to ``categorical``
    :param p: category probabilities; the last axis indexes categories
    :param n: number of trials, scalar or array broadcastable against
        ``p.shape[:-1]``
    :param n_max: number of categorical draws made per batch element
        (presumably an upper bound on ``n`` - TODO confirm with callers)
    :param shape: batch shape of the result; defaults to ``p.shape[:-1]``
    :return: counts reshaped to ``shape + p.shape[-1:]``
    """
    # bring n and the batch dimensions of p to a common shape
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        # mask[..., j] is True for the first n draws (j < n) of each batch element
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        # move the draw axis to the front so the mask lines up with `indices`
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # masked draws are turned into index 0 below, so `n_max - n` is later
        # subtracted from category 0 and zero from every other category
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        # scalar n: every one of the n_max draws is counted and nothing is subtracted
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
    # one row per flattened batch element: start from zeros and apply
    # `_scatter_add_one` with that row's indices and a vector of ones
    # (presumably this adds 1 to the count at each drawn index - verify against the helper)
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
                                                             dtype=indices.dtype),
                                                   jnp.expand_dims(indices_2D, axis=-1),
                                                   jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:25,代码来源:util.py

示例5: test_chain

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def test_chain(use_init_params, chain_method):
    """Run NUTS with two chains on a logistic-regression model and check the
    sample shapes and the recovered coefficients.

    :param use_init_params: when truthy, start every chain from all-ones
        coefficients; otherwise pass ``init_params=None``
    :param chain_method: value assigned to ``mcmc.chain_method`` before running
    """
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    observed = dist.Bernoulli(logits=true_logits).sample(random.PRNGKey(1))

    def model(labels):
        coef_prior = dist.Normal(jnp.zeros(dim), jnp.ones(dim))
        coefs = numpyro.sample('coefs', coef_prior)
        likelihood = dist.Bernoulli(logits=jnp.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', likelihood, obs=labels)

    mcmc = MCMC(NUTS(model=model), num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method
    if use_init_params:
        # one row of ones per chain
        init_params = {'coefs': jnp.tile(jnp.ones(dim), num_chains).reshape(num_chains, dim)}
    else:
        init_params = None
    mcmc.run(random.PRNGKey(2), observed, init_params=init_params)

    flat_samples = mcmc.get_samples()
    assert flat_samples['coefs'].shape[0] == num_chains * num_samples
    grouped_samples = mcmc.get_samples(group_by_chain=True)
    assert grouped_samples['coefs'].shape[:2] == (num_chains, num_samples)
    assert_allclose(jnp.mean(flat_samples['coefs'], 0), true_coefs, atol=0.21)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:27,代码来源:test_mcmc.py

示例6: test_gaussian_subposterior

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def test_gaussian_subposterior(method, diagonal):
    """Check that ``method`` combines Gaussian subposterior samples into draws
    with the expected mean and (co)variance.

    :param method: callable invoked as
        ``method(subposteriors, n_draws, diagonal=diagonal)``
    :param diagonal: forwarded to ``method``; when truthy only the per-dimension
        variances are compared, otherwise the full covariance matrix
    """
    D, n_subs = 10, 8
    n_samples, n_draws = 10000, 9000

    mean = jnp.arange(D)
    cov = jnp.ones((D, D)) * 0.9 + jnp.identity(D) * 0.1
    # each subposterior uses the target covariance scaled by the number of subsets
    subcov = n_subs * cov
    sub_dist = dist.MultivariateNormal(mean, subcov)
    subposteriors = list(sub_dist.sample(random.PRNGKey(1), (n_subs, n_samples)))

    draws = method(subposteriors, n_draws, diagonal=diagonal)
    assert draws.shape == (n_draws, D)
    assert_allclose(jnp.mean(draws, axis=0), mean, atol=0.03)
    if not diagonal:
        assert_allclose(jnp.cov(draws.T), cov, atol=0.05)
    else:
        assert_allclose(jnp.var(draws, axis=0), jnp.diag(cov), atol=0.05)
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:test_hmc_util.py

示例7: test_categorical_log_prob_grad

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def test_categorical_log_prob_grad():
    """Check that ``Categorical`` gives the same summed log-probability and the
    same gradient whether it is built from softmax probabilities or from logits.
    """
    data = jnp.repeat(jnp.arange(3), 10)

    def log_prob_from_probs(x):
        probs = jax.nn.softmax(x * jnp.arange(1, 4))
        return dist.Categorical(probs).log_prob(data).sum()

    def log_prob_from_logits(x):
        logits = x * jnp.arange(1, 4)
        return dist.Categorical(logits=logits).log_prob(data).sum()

    x = 0.5
    value_probs, grad_probs = jax.value_and_grad(log_prob_from_probs)(x)
    value_logits, grad_logits = jax.value_and_grad(log_prob_from_logits)(x)
    assert_allclose(value_probs, value_logits)
    assert_allclose(grad_probs, grad_logits, atol=1e-4)


########################################
# Tests for constraints and transforms #
######################################## 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:21,代码来源:test_distributions.py

示例8: solve_implicit

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None):
    """Mask the tridiagonal coefficients by ``ks`` and solve the system.

    Along the last axis of ``a``, positions with index >= ``ks`` (where
    ``ks >= 0``) form the "water" mask and the position equal to ``ks`` forms
    the "edge" mask. Outside the water mask the system is reduced to an
    identity row (diagonal 1, everything else 0).

    :param ks: integer index array, indexed here as ``ks[:, :, newaxis]``
        (assumes ks is 2-d and a, b, c, d are 3-d - TODO confirm with callers)
    :param a: lower-diagonal coefficients
    :param b: main-diagonal coefficients
    :param c: upper-diagonal coefficients
    :param d: right-hand side
    :param b_edge: optional replacement for ``b`` on the edge mask
    :param d_edge: optional replacement for ``d`` on the edge mask
    :return: tuple of the ``solve_tridiag`` result and the water mask
    """
    ks_expanded = ks[:, :, np.newaxis]
    level_index = np.arange(a.shape[2])[np.newaxis, np.newaxis, :]

    land_mask = ks_expanded >= 0
    edge_mask = land_mask & (level_index == ks_expanded)
    water_mask = land_mask & (level_index >= ks_expanded)

    # the lower diagonal is additionally zeroed on the edge itself
    a_tri = water_mask * a * np.logical_not(edge_mask)
    c_tri = water_mask * c

    b_tri = where(water_mask, b, 1.)
    if b_edge is not None:
        b_tri = where(edge_mask, b_edge, b_tri)

    d_tri = water_mask * d
    if d_edge is not None:
        d_tri = where(edge_mask, d_edge, d_tri)

    solution = solve_tridiag(a_tri, b_tri, c_tri, d_tri)
    return solution, water_mask
开发者ID:dionhaefner,项目名称:pyhpc-benchmarks,代码行数:19,代码来源:tke_jax.py

示例9: _new_arange

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def _new_arange(x, start, stop, step):
    """Return ``np.arange(start, stop, step)``.

    ``x`` is accepted as the first positional argument but is not used by the
    body.
    """
    values = np.arange(start, stop, step)
    return values
开发者ID:pyro-ppl,项目名称:funsor,代码行数:4,代码来源:ops.py

示例10: significance_map

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def significance_map(self):
    """Return the flattened position-within-group index of every symbol.

    The sequence ``0 .. self._precision - 1`` is repeated once per element of
    ``self._space.shape`` and the result is flattened to one dimension.
    """
    positions = np.arange(self._precision)
    tiled_shape = self._space.shape + (self._precision,)
    tiled = np.broadcast_to(positions, tiled_shape)
    return np.reshape(tiled, -1)
开发者ID:google,项目名称:trax,代码行数:5,代码来源:space_serializer.py

示例11: main

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def main():
    """Run a few Momentum-optimizer steps of ResNet50 on random synthetic data.

    Builds the network, defines a loss (and an accuracy function that is not
    called here), initializes the optimizer state from one synthetic batch and
    then performs ``num_steps`` updates, printing progress as it goes.
    """
    key = PRNGKey(0)

    num_classes = 1001
    batch_size = 8
    input_shape = (224, 224, 3, batch_size)
    num_steps = 10
    step_size = 0.1

    resnet = ResNet50(num_classes)

    @parametrized
    def loss(inputs, targets):
        return np.sum(resnet(inputs) * targets)

    @parametrized
    def accuracy(inputs, targets):
        target_class = np.argmax(targets, axis=-1)
        predicted_class = np.argmax(resnet(inputs), axis=-1)
        return np.mean(predicted_class == target_class)

    def synth_batches():
        # endless stream of random images with random one-hot labels
        rng = npr.RandomState(0)
        class_ids = np.arange(num_classes)
        while True:
            images = rng.rand(*input_shape).astype('float32')
            labels = rng.randint(num_classes, size=(batch_size, 1))
            yield images, labels == class_ids

    batches = synth_batches()
    opt = optimizers.Momentum(step_size, mass=0.9)

    print("\nInitializing parameters.")
    first_batch = next(batches)
    state = opt.init(loss.init_parameters(*first_batch, key=key))
    for i in range(num_steps):
        print(f'Training on batch {i}.')
        batch = next(batches)
        state = opt.update(loss.apply, state, *batch)
    trained_params = opt.get_parameters(state)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:41,代码来源:resnet50.py

示例12: _one_hot

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def _one_hot(x, k, dtype=np.float32):
    """Return the one-hot encoding of the 1-d label array `x` with `k` classes.

    Row `r` of the result has a 1 in column `x[r]` and 0 elsewhere; labels
    outside `[0, k)` produce an all-zero row. The result has dtype `dtype`.
    """
    class_ids = np.arange(k)
    is_class = x[:, None] == class_ids
    return np.array(is_class, dtype)
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:5,代码来源:mnist_classifier.py

示例13: _extract_signal_patches

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def _extract_signal_patches(signal, window_length, hop=1, data_format="NCW"):
    """Extract sliding windows along the last axis of a 3-d signal.

    :param signal: array with exactly three dimensions, laid out as
        ``data_format`` describes
    :param window_length: scalar length of each window
    :param hop: step between the starts of consecutive windows
    :param data_format: layout of ``signal``; only ``"NCW"`` is supported
    :return: array of shape ``signal.shape[:2] + (N, window_length)`` where
        ``N = (signal.shape[2] - window_length) // hop + 1``
    :raises ValueError: if ``data_format`` is not ``"NCW"``
    """
    assert not hasattr(window_length, "__len__")
    assert signal.ndim == 3
    if data_format != "NCW":
        # the original fell through to a bare, undefined name here, which
        # surfaced as an uninformative NameError
        raise ValueError(
            "unsupported data_format {!r}; only 'NCW' is supported".format(data_format)
        )

    # number of complete windows that fit along the last axis
    N = (signal.shape[2] - window_length) // hop + 1
    # row r holds the positions r * hop, ..., r * hop + window_length - 1
    indices = jnp.arange(window_length) + jnp.expand_dims(jnp.arange(N) * hop, 1)
    # flatten to (1, 1, N * window_length) so it broadcasts over the two leading axes
    indices = jnp.reshape(indices, [1, 1, N * window_length])
    patches = jnp.take_along_axis(signal, indices, 2)
    return jnp.reshape(patches, signal.shape[:2] + (N, window_length))
开发者ID:SymJAX,项目名称:SymJAX,代码行数:13,代码来源:ops_special.py

示例14: _extract_image_patches

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def _extract_image_patches(
    image, window_shape, hop=1, data_format="NCHW", mode="valid"
):
    """Extract sliding 2-d windows from the two trailing axes of an image batch.

    :param image: 4-d array laid out as ``data_format`` describes
    :param window_shape: pair ``(height, width)`` of each window
    :param hop: step between windows; a scalar is applied to both axes,
        otherwise a pair ``(vertical, horizontal)``
    :param data_format: layout of ``image``; only ``"NCHW"`` is supported
    :param mode: ``"same"`` zero-pads the two spatial axes by a total of
        ``window_shape - 1`` before extraction; any other value extracts from
        the image as given
    :return: array of shape ``image.shape[:2] + (N_h, N_w) + window_shape``,
        with ``N = (size - window) // hop + 1`` per spatial axis (sizes taken
        after padding)
    :raises ValueError: if ``data_format`` is not ``"NCHW"``
    """
    if mode == "same":
        p1 = window_shape[0] - 1
        p2 = window_shape[1] - 1
        image = jnp.pad(
            image, [(0, 0), (0, 0), (p1 // 2, p1 - p1 // 2), (p2 // 2, p2 - p2 // 2)]
        )
    if not hasattr(hop, "__len__"):
        hop = (hop, hop)
    if data_format != "NCHW":
        # the original fell through to a bare, undefined name here, which
        # surfaced as an uninformative NameError
        raise ValueError(
            "unsupported data_format {!r}; only 'NCHW' is supported".format(data_format)
        )

    # compute the number of windows in both dimensions
    N = (
        (image.shape[2] - window_shape[0]) // hop[0] + 1,
        (image.shape[3] - window_shape[1]) // hop[1] + 1,
    )

    # flat (row-major) indices of the top-left patch: entry (i, j) becomes
    # i * image_width + j
    patch = jnp.arange(numpy.prod(window_shape)).reshape(window_shape)
    offset = jnp.expand_dims(jnp.arange(window_shape[0]), 1)
    patch_indices = patch + offset * (image.shape[3] - window_shape[1])

    # create all the shifted versions of it; the shifts broadcast to
    # (N[0], N[1], window_shape[0], window_shape[1])
    ver_shifts = jnp.reshape(
        jnp.arange(N[0]) * hop[0] * image.shape[3], (-1, 1, 1, 1)
    )
    hor_shifts = jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
    indices = patch_indices + ver_shifts + hor_shifts

    # flatten the indices to shape (1, 1, N[0] * N[1] * h * w)
    flat_indices = jnp.reshape(indices, [1, 1, -1])
    # flatten the spatial axes of the image: shape (batch, channels, H * W)
    flat_image = jnp.reshape(image, (image.shape[0], image.shape[1], -1))
    # gather: shape (batch, channels, N[0] * N[1] * h * w)
    patches = jnp.take_along_axis(flat_image, flat_indices, 2)
    return jnp.reshape(patches, image.shape[:2] + N + tuple(window_shape))
开发者ID:SymJAX,项目名称:SymJAX,代码行数:43,代码来源:ops_special.py

示例15: one_hot

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import arange [as 别名]
def one_hot(i, N, dtype="float32"):
    """Create a one-hot encoding of `i` with `N` classes.

    :param i: either an object with a ``shape`` attribute, indexed here as
        ``i[:, None]`` (so a 1-d array of class indices), or a single index
    :param N: number of classes, i.e. the length of each one-hot vector
    :param dtype: dtype of the result
    :return: for array input, an array of shape ``(len(i), N)`` with a 1 in
        column ``i[r]`` of row ``r``; otherwise a length-``N`` vector of zeros
        with 1 added at position ``i``
    """
    if hasattr(i, "shape"):
        # the original referenced the undefined names `x` and `k` here, so
        # this branch always raised NameError
        return (i[:, None] == arange(N)).astype(dtype)
    else:
        z = T.zeros(N, dtype)
        return index_add(z, i, 1)
开发者ID:SymJAX,项目名称:SymJAX,代码行数:9,代码来源:ops_special.py


注:本文中的jax.numpy.arange方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。