当前位置: 首页>>代码示例>>Python>>正文


Python numpy.int32方法代码示例

本文整理汇总了Python中jax.numpy.int32方法的典型用法代码示例。如果您正苦于以下问题:Python numpy.int32方法的具体用法?Python numpy.int32怎么用?Python numpy.int32使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.numpy的用法示例。


在下文中一共展示了numpy.int32方法的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: serialize

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def serialize(self, data):
    """Serializes a batch of space elements into discrete sequences.

    Should be defined in subclasses.

    Args:
      data: A batch of batch_size elements of the Gym space to be serialized.

    Returns:
      int32 array of shape (batch_size, self.representation_length).
    """
    raise NotImplementedError 
开发者ID:google,项目名称:trax,代码行数:14,代码来源:space_serializer.py

示例2: deserialize

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def deserialize(self, representation):
    """Deserializes a batch of discrete sequences into space elements.

    Should be defined in subclasses.

    Args:
      representation: int32 Numpy array of shape
        (batch_size, self.representation_length) to be deserialized.

    Returns:
      A batch of batch_size deserialized elements of the Gym space.
    """
    raise NotImplementedError 
开发者ID:google,项目名称:trax,代码行数:15,代码来源:space_serializer.py

示例3: significance_map

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def significance_map(self):
    return np.zeros(1, dtype=np.int32) 
开发者ID:google,项目名称:trax,代码行数:4,代码来源:space_serializer.py

示例4: int32

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def int32(self):
        return np.int32 
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例5: multigammaln

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def multigammaln(self, a, p):
        p = self.to_float(p)
        p_ = self.cast(p, 'int32')
        a = a[..., None]
        i = self.to_float(self.range(1, p_ + 1))
        term1 = p * (p - 1) / 4. * self.log(np.pi)
        term2 = self.gammaln(a - (i - 1) / 2.)
        return term1 + self.sum(term2, axis=-1) 
开发者ID:sharadmv,项目名称:deepx,代码行数:10,代码来源:jax.py

示例6: __init__

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def __init__(self, **kwargs):
        self.name = 'jax'
        self.precision = kwargs.get('precision', '64b')
        self.dtypemap = {
            'float': np.float64 if self.precision == '64b' else np.float32,
            'int': np.int64 if self.precision == '64b' else np.int32,
            'bool': np.bool_,
        }
        config.update('jax_enable_x64', self.precision == '64b') 
开发者ID:scikit-hep,项目名称:pyhf,代码行数:11,代码来源:jax_backend.py

示例7: _get_num_steps

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def _get_num_steps(step_size, trajectory_length):
    num_steps = jnp.clip(trajectory_length / step_size, a_min=1)
    # NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
    # if jax_enable_x64 is False
    return num_steps.astype(canonicalize_dtype(jnp.int64)) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:7,代码来源:mcmc.py

示例8: test_change_point_x64

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def test_change_point_x64():
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / jnp.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        lambda12 = jnp.where(jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = jnp.array([
        13,  24,   8,  24,   7,  35,  14,  11,  15,  11,  22,  22,  11,  57,
        11,  19,  29,   6,  19,  12,  22,  12,  18,  72,  32,   9,   7,  13,
        19,  23,  27,  20,   6,  17,  13,  10,  14,   6,  16,  15,   7,   2,
        15,  15,  19,  70,  49,   7,  53,  22,  21,  31,  19,  11,  18,  20,
        12,  35,  17,  23,  17,   4,   2,  31,  30,  13,  27,   0,  39,  37,
        5,  14,  13,  22,
    ])
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(4), count_data)
    samples = mcmc.get_samples()
    tau_posterior = (samples['tau'] * len(count_data)).astype(jnp.int32)
    tau_values, counts = np.unique(tau_posterior, return_counts=True)
    mode_ind = jnp.argmax(counts)
    mode = tau_values[mode_ind]
    assert mode == 44

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['lambda1'].dtype == jnp.float64
        assert samples['lambda2'].dtype == jnp.float64
        assert samples['tau'].dtype == jnp.float64 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:36,代码来源:test_mcmc.py

示例9: z

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def z(*shape):
    return jnp.zeros(shape, dtype=jnp.int32) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:4,代码来源:test_indexing.py

示例10: test_log_prob_gradient

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def test_log_prob_gradient(jax_dist, sp_dist, params):
    if jax_dist in [dist.LKJ, dist.LKJCholesky]:
        pytest.skip('we have separated tests for LKJCholesky distribution')
    if jax_dist is _ImproperWrapper:
        pytest.skip('no param for ImproperUniform to test for log_prob gradient')

    rng_key = random.PRNGKey(0)
    value = jax_dist(*params).sample(rng_key)

    def fn(*args):
        return jnp.sum(jax_dist(*args).log_prob(value))

    eps = 1e-3
    for i in range(len(params)):
        if params[i] is None or jnp.result_type(params[i]) in (jnp.int32, jnp.int64):
            continue
        actual_grad = jax.grad(fn, i)(*params)
        args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
        args_rhs = [p if j != i else p + eps for j, p in enumerate(params)]
        fn_lhs = fn(*args_lhs)
        fn_rhs = fn(*args_rhs)
        # finite diff approximation
        expected_grad = (fn_rhs - fn_lhs) / (2. * eps)
        assert jnp.shape(actual_grad) == jnp.shape(params[i])
        if i == 0 and jax_dist is dist.Delta:
            # grad w.r.t. `value` of Delta distribution will be 0
            # but numerical value will give nan (= inf - inf)
            expected_grad = 0.
        assert_allclose(jnp.sum(actual_grad), expected_grad, rtol=0.01, atol=0.01) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:31,代码来源:test_distributions.py


注:本文中的jax.numpy.int32方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。