本文整理汇总了Python中jax.numpy.zeros_like方法的典型用法代码示例。如果您正苦于以下问题:Python numpy.zeros_like方法的具体用法?Python numpy.zeros_like怎么用?Python numpy.zeros_like使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类jax.numpy
的用法示例。
在下文中一共展示了numpy.zeros_like方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_log_normal
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros_like [as 别名]
def test_log_normal(shape):
loc = np.random.rand(*shape) * 2 - 1
scale = np.random.rand(*shape) + 0.5
def model():
with numpyro.plate_stack("plates", shape):
with numpyro.plate("particles", 100000):
return numpyro.sample("x",
dist.TransformedDistribution(
dist.Normal(jnp.zeros_like(loc),
jnp.ones_like(scale)),
[AffineTransform(loc, scale),
ExpTransform()]).expand_by([100000]))
with handlers.trace() as tr:
value = handlers.seed(model, 0)()
expected_moments = get_moments(value)
with numpyro.handlers.reparam(config={"x": TransformReparam()}):
with handlers.trace() as tr:
value = handlers.seed(model, 0)()
assert tr["x"]["type"] == "deterministic"
actual_moments = get_moments(value)
assert_allclose(actual_moments, expected_moments, atol=0.05)
示例2: _tree_zeros_like
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros_like [as 别名]
def _tree_zeros_like(tree):
def f(x):
return np.zeros_like(x)
return tu.tree_map(f, tree)
示例3: zeros_like
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros_like [as 别名]
def zeros_like(self, x, dtype=None, name=None):
return np.zeros_like(x, dtype=dtype)
示例4: adv_flux_superbee_wgrid
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros_like [as 别名]
def adv_flux_superbee_wgrid(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt, dzw, cost, cosu, dt_tracer):
"""
Calculates advection of a tracer defined on Wgrid
"""
maskUtr = np.zeros_like(maskW)
maskUtr = jax.ops.index_update(
maskUtr, jax.ops.index[:-1, :, :],
maskW[1:, :, :] * maskW[:-1, :, :]
)
adv_fe = np.zeros_like(maskW)
adv_fe = jax.ops.index_update(
adv_fe, jax.ops.index[1:-2, 2:-2, :],
_adv_superbee(u_wgrid, var, maskUtr, dxt, 0, cost, cosu, dt_tracer)
)
maskVtr = np.zeros_like(maskW)
maskVtr = jax.ops.index_update(
maskVtr, jax.ops.index[:, :-1, :],
maskW[:, 1:, :] * maskW[:, :-1, :]
)
adv_fn = np.zeros_like(maskW)
adv_fn = jax.ops.index_update(
adv_fn, jax.ops.index[2:-2, 1:-2, :],
_adv_superbee(v_wgrid, var, maskVtr, dyt, 1, cost, cosu, dt_tracer)
)
maskWtr = np.zeros_like(maskW)
maskWtr = jax.ops.index_update(
maskWtr, jax.ops.index[:, :, :-1],
maskW[:, :, 1:] * maskW[:, :, :-1]
)
adv_ft = np.zeros_like(maskW)
adv_ft = jax.ops.index_update(
adv_ft, jax.ops.index[2:-2, 2:-2, :-1],
_adv_superbee(w_wgrid, var, maskWtr, dzw, 2, cost, cosu, dt_tracer)
)
return adv_fe, adv_fn, adv_ft