当前位置: 首页>>代码示例>>Python>>正文


Python numpy.zeros_like方法代码示例

本文整理汇总了Python中jax.numpy.zeros_like方法的典型用法代码示例。如果您正苦于以下问题:Python numpy.zeros_like方法的具体用法?Python numpy.zeros_like怎么用?Python numpy.zeros_like使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.numpy的用法示例。


在下文中一共展示了numpy.zeros_like方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_log_normal

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros_like [as 别名]
def test_log_normal(shape):
    loc = np.random.rand(*shape) * 2 - 1
    scale = np.random.rand(*shape) + 0.5

    def model():
        with numpyro.plate_stack("plates", shape):
            with numpyro.plate("particles", 100000):
                return numpyro.sample("x",
                                      dist.TransformedDistribution(
                                          dist.Normal(jnp.zeros_like(loc),
                                                      jnp.ones_like(scale)),
                                          [AffineTransform(loc, scale),
                                           ExpTransform()]).expand_by([100000]))

    with handlers.trace() as tr:
        value = handlers.seed(model, 0)()
    expected_moments = get_moments(value)

    with numpyro.handlers.reparam(config={"x": TransformReparam()}):
        with handlers.trace() as tr:
            value = handlers.seed(model, 0)()
    assert tr["x"]["type"] == "deterministic"
    actual_moments = get_moments(value)
    assert_allclose(actual_moments, expected_moments, atol=0.05) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:26,代码来源:test_reparam.py

示例2: _tree_zeros_like

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros_like [as 别名]
def _tree_zeros_like(tree):
  def f(x):
    return np.zeros_like(x)
  return tu.tree_map(f, tree) 
开发者ID:google,项目名称:spectral-density,代码行数:6,代码来源:hessian_computation.py

示例3: zeros_like

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros_like [as 别名]
def zeros_like(self, x, dtype=None, name=None):
        return np.zeros_like(x, dtype=dtype) 
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例4: adv_flux_superbee_wgrid

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import zeros_like [as 别名]
def adv_flux_superbee_wgrid(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt, dzw, cost, cosu, dt_tracer):
    """
    Calculates advection of a tracer defined on Wgrid
    """
    maskUtr = np.zeros_like(maskW)
    maskUtr = jax.ops.index_update(
        maskUtr, jax.ops.index[:-1, :, :],
        maskW[1:, :, :] * maskW[:-1, :, :]
    )

    adv_fe = np.zeros_like(maskW)
    adv_fe = jax.ops.index_update(
        adv_fe, jax.ops.index[1:-2, 2:-2, :],
        _adv_superbee(u_wgrid, var, maskUtr, dxt, 0, cost, cosu, dt_tracer)
    )

    maskVtr = np.zeros_like(maskW)
    maskVtr = jax.ops.index_update(
        maskVtr, jax.ops.index[:, :-1, :],
        maskW[:, 1:, :] * maskW[:, :-1, :]

    )
    adv_fn = np.zeros_like(maskW)
    adv_fn = jax.ops.index_update(
        adv_fn, jax.ops.index[2:-2, 1:-2, :],
        _adv_superbee(v_wgrid, var, maskVtr, dyt, 1, cost, cosu, dt_tracer)
    )

    maskWtr = np.zeros_like(maskW)
    maskWtr = jax.ops.index_update(
        maskWtr, jax.ops.index[:, :, :-1],
        maskW[:, :, 1:] * maskW[:, :, :-1]
    )
    adv_ft = np.zeros_like(maskW)
    adv_ft = jax.ops.index_update(
        adv_ft, jax.ops.index[2:-2, 2:-2, :-1],
        _adv_superbee(w_wgrid, var, maskWtr, dzw, 2, cost, cosu, dt_tracer)
    )

    return adv_fe, adv_fn, adv_ft 
开发者ID:dionhaefner,项目名称:pyhpc-benchmarks,代码行数:42,代码来源:tke_jax.py


注:本文中的jax.numpy.zeros_like方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。