本文整理汇总了Python中jax.numpy.int32的典型用法代码示例。如果您正苦于以下问题:Python中jax.numpy.int32的具体用法是什么?jax.numpy.int32怎么用?有哪些jax.numpy.int32的使用示例?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解其所在模块jax.numpy的用法示例。
在下文中一共展示了jax.numpy.int32的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: serialize
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def serialize(self, data):
    """Serialize a batch of space elements into discrete sequences.

    This implementation only raises; it should be defined in subclasses.

    Args:
        data: A batch of batch_size elements of the Gym space to be
            serialized.

    Returns:
        int32 array of shape (batch_size, self.representation_length).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
示例2: deserialize
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def deserialize(self, representation):
    """Deserialize a batch of discrete sequences into space elements.

    This implementation only raises; it should be defined in subclasses.

    Args:
        representation: int32 Numpy array of shape
            (batch_size, self.representation_length) to be deserialized.

    Returns:
        A batch of batch_size deserialized elements of the Gym space.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
示例3: significance_map
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def significance_map(self):
    """Return a one-element int32 array containing zero."""
    return np.zeros(1, dtype=np.int32)
示例4: int32
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def int32(self):
    """Return the ``np.int32`` scalar type object."""
    return np.int32
示例5: multigammaln
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def multigammaln(self, a, p):
    """Compute the log of the multivariate gamma function.

    Evaluates ``p * (p - 1) / 4 * log(pi) + sum_{i=1..p} gammaln(a - (i - 1) / 2)``
    using the ``to_float``, ``cast``, ``range``, ``log``, ``gammaln`` and
    ``sum`` operations provided by ``self``.

    Args:
        a: Array-like argument; a trailing axis is appended to it so it
            broadcasts against the ``p`` summation indices.
        p: Dimension of the multivariate gamma function.

    Returns:
        The value with the same leading shape as ``a`` (the appended
        trailing axis is summed out).
    """
    p = self.to_float(p)
    # An integer copy of p is needed to build the index range 1..p.
    p_ = self.cast(p, 'int32')
    # Add a trailing axis so `a` broadcasts against the indices `i`.
    a = a[..., None]
    i = self.to_float(self.range(1, p_ + 1))
    term1 = p * (p - 1) / 4. * self.log(np.pi)
    term2 = self.gammaln(a - (i - 1) / 2.)
    return term1 + self.sum(term2, axis=-1)
示例6: __init__
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def __init__(self, **kwargs):
    """Set up the backend name, precision, dtype map and JAX x64 flag.

    Args:
        **kwargs: Optional settings. Only ``precision`` is read here; it
            defaults to ``'64b'``. Any value other than ``'64b'`` selects
            the 32-bit dtypes and disables ``jax_enable_x64``.
    """
    self.name = 'jax'
    self.precision = kwargs.get('precision', '64b')
    # Evaluate the precision switch once; it drives both the dtype map
    # and the global JAX x64 flag below.
    use_64b = self.precision == '64b'
    self.dtypemap = {
        'float': np.float64 if use_64b else np.float32,
        'int': np.int64 if use_64b else np.int32,
        'bool': np.bool_,
    }
    # NOTE: this updates global JAX configuration, not just this instance.
    config.update('jax_enable_x64', use_64b)
示例7: _get_num_steps
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def _get_num_steps(step_size, trajectory_length):
    """Return the number of steps for a trajectory, with a floor of one.

    Args:
        step_size: Size of a single step.
        trajectory_length: Total length of the trajectory.

    Returns:
        ``trajectory_length / step_size`` bounded below by 1 and cast to the
        canonical integer dtype.
    """
    # A lower bound only, so `maximum` is equivalent to `clip(..., a_min=1)`
    # and avoids the `a_min` keyword that newer JAX releases deprecate.
    num_steps = jnp.maximum(trajectory_length / step_size, 1)
    # NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
    # if jax_enable_x64 is False
    return num_steps.astype(canonicalize_dtype(jnp.int64))
示例8: test_change_point_x64
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def test_change_point_x64():
    """Run NUTS on a change-point Poisson model and check the posterior mode.

    Asserts that the most frequent integer change-point index is 44 and,
    when ``JAX_ENABLE_X64`` is present in the environment, that the sampled
    sites are float64.
    """
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / jnp.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        # Rate is lambda1 before the change point and lambda2 from it onward.
        lambda12 = jnp.where(jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = jnp.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57,
        11, 19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13,
        19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2,
        15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20,
        12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37,
        5, 14, 13, 22,
    ])
    kernel = NUTS(model=model)
    # Keyword arguments: newer NumPyro releases make these keyword-only.
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(4), count_data)
    samples = mcmc.get_samples()

    # Convert the fractional change point into an integer index and take
    # the most frequent value as the posterior mode.
    tau_posterior = (samples['tau'] * len(count_data)).astype(jnp.int32)
    tau_values, counts = np.unique(tau_posterior, return_counts=True)
    mode_ind = jnp.argmax(counts)
    mode = tau_values[mode_ind]
    assert mode == 44

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['lambda1'].dtype == jnp.float64
        assert samples['lambda2'].dtype == jnp.float64
        assert samples['tau'].dtype == jnp.float64
示例9: z
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def z(*shape):
    """Return an int32 array of zeros whose shape is the given dimensions.

    Args:
        *shape: Dimension sizes; calling with no arguments gives shape ``()``.
    """
    return jnp.zeros(shape, dtype=jnp.int32)
示例10: test_log_prob_gradient
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import int32 [as 别名]
def test_log_prob_gradient(jax_dist, sp_dist, params):
    """Compare ``jax.grad`` of the summed log-prob with a finite difference.

    For each parameter that is neither ``None`` nor int32/int64, the summed
    autodiff gradient is checked against a central finite-difference estimate.

    Args:
        jax_dist: Distribution class under test.
        sp_dist: Not used in this function body.
        params: Positional parameters used to construct ``jax_dist``.
    """
    if jax_dist in [dist.LKJ, dist.LKJCholesky]:
        pytest.skip('we have separated tests for LKJCholesky distribution')
    if jax_dist is _ImproperWrapper:
        pytest.skip('no param for ImproperUniform to test for log_prob gradient')

    rng_key = random.PRNGKey(0)
    value = jax_dist(*params).sample(rng_key)

    def fn(*args):
        return jnp.sum(jax_dist(*args).log_prob(value))

    eps = 1e-3
    for i, param in enumerate(params):
        # Parameters that are None or integer-typed are skipped.
        if param is None or jnp.result_type(param) in (jnp.int32, jnp.int64):
            continue
        actual_grad = jax.grad(fn, i)(*params)
        args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
        args_rhs = [p if j != i else p + eps for j, p in enumerate(params)]
        fn_lhs = fn(*args_lhs)
        fn_rhs = fn(*args_rhs)
        # finite diff approximation
        expected_grad = (fn_rhs - fn_lhs) / (2. * eps)
        assert jnp.shape(actual_grad) == jnp.shape(param)
        if i == 0 and jax_dist is dist.Delta:
            # grad w.r.t. `value` of Delta distribution will be 0
            # but numerical value will give nan (= inf - inf)
            expected_grad = 0.
        assert_allclose(jnp.sum(actual_grad), expected_grad, rtol=0.01, atol=0.01)