當前位置: 首頁>>代碼示例>>Python>>正文


Python lax.reduce_window方法代碼示例

本文整理匯總了Python中jax.lax.reduce_window方法的典型用法代碼示例。如果您正苦於以下問題:Python lax.reduce_window方法的具體用法?Python lax.reduce_window怎麽用?Python lax.reduce_window使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在jax.lax的用法示例。


在下文中一共展示了lax.reduce_window方法的7個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: _pooling_general

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import reduce_window [as 別名]
def _pooling_general(inputs, reducer, init_val, rescaler=None,
                     pool_size=(2, 2), strides=None, padding="VALID"):
  """Helper: general pooling computation used in pooling layers later."""
  spatial_strides = strides or (1,) * len(pool_size)
  rescale = rescaler(pool_size, spatial_strides, padding) if rescaler else None
  dims = (1,) + pool_size + (1,)  # NHWC
  strides = (1,) + spatial_strides + (1,)
  out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding)
  return rescale(out, inputs) if rescale else out 
開發者ID:yyht,項目名稱:BERT,代碼行數:11,代碼來源:backend.py

示例2: _normalize_by_window_size

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import reduce_window [as 別名]
def _normalize_by_window_size(dims, spatial_strides, padding):  # pylint: disable=invalid-name
  def rescale(outputs, inputs):
    one = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype)
    window_sizes = lax.reduce_window(
        one, 0., lax.add, dims, spatial_strides, padding)
    return outputs / window_sizes[..., jnp.newaxis]
  return rescale 
開發者ID:yyht,項目名稱:BERT,代碼行數:9,代碼來源:backend.py

示例3: _pooling_general

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import reduce_window [as 別名]
def _pooling_general(inputs, reducer, init_val, rescaler=None,
                     pool_size=(2, 2), strides=None, padding='VALID'):
  """Helper: general pooling computation used in pooling layers later."""
  spatial_strides = strides or (1,) * len(pool_size)
  rescale = rescaler(pool_size, spatial_strides, padding) if rescaler else None
  dims = (1,) + pool_size + (1,)  # NHWC
  strides = (1,) + spatial_strides + (1,)
  out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding)
  return rescale(out, inputs) if rescale else out  # pylint: disable=not-callable 
開發者ID:google,項目名稱:trax,代碼行數:11,代碼來源:jax.py

示例4: _pool

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import reduce_window [as 別名]
def _pool(reducer, init_val, rescaler=None):
    def Pool(window_shape, strides=None, padding='VALID'):
        """Layer construction function for a pooling layer."""
        strides = strides or (1,) * len(window_shape)
        rescale = rescaler(window_shape, strides, padding) if rescaler else None
        dims = (1,) + window_shape + (1,)  # NHWC
        strides = (1,) + strides + (1,)

        def pool(inputs):
            out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding)
            return rescale(out, inputs) if rescale else out

        return pool

    return Pool 
開發者ID:JuliusKunze,項目名稱:jaxnet,代碼行數:17,代碼來源:modules.py

示例5: _normalize_by_window_size

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import reduce_window [as 別名]
def _normalize_by_window_size(dims, strides, padding):
    def rescale(outputs, inputs):
        one = np.ones(inputs.shape[1:-1], dtype=inputs.dtype)
        window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
        return outputs / window_sizes[..., np.newaxis]

    return rescale 
開發者ID:JuliusKunze,項目名稱:jaxnet,代碼行數:9,代碼來源:modules.py

示例6: pool2d

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import reduce_window [as 別名]
def pool2d(self, x, pool_size, strides=(1, 1),
               border_mode='valid', pool_mode='max'):
        dims = (1,) + pool_size + (1,)
        strides = (1,) + strides + (1,)
        return lax.reduce_window(x, -np.inf, lax.max, dims, strides, border_mode) 
開發者ID:sharadmv,項目名稱:deepx,代碼行數:7,代碼來源:jax.py

示例7: poolNd

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import reduce_window [as 別名]
def poolNd(
    input,
    window_shape,
    reducer="MAX",
    strides=None,
    padding="VALID",
    init_val=None,
    rescalor=None,
):

    # set up the reducer
    if reducer == "MAX":
        reducer = jla.max
        rescalor = numpy.float32(1.0)
        init_val = -numpy.inf
    elif reducer == "SUM" or reducer == "AVG":
        reducer = jla.add
        if reducer == "AVG":
            rescalor = numpy.float32(1.0 / numpy.prod(window_shape))
        else:
            rescalor = numpy.float32(1.0)
        if init_val is None:
            init_val = 0.0

    # set up the window_shape
    if numpy.isscalar(window_shape):
        window_shape = (window_shape,) * input.ndim
    elif len(window_shape) != input.ndim:
        msg = "Given window_shape {} not the same length ".format(
            strides
        ) + "as input shape {}".format(input.ndim)
        raise ValueError(msg)

    # set up the strides
    if strides is None:
        strides = window_shape
    elif numpy.isscalar(strides):
        strides = (strides,) * len(window_shape)
    elif len(strides) != len(window_shape):
        msg = "Given strides {} not the same length ".format(
            strides
        ) + "as window_shape {}".format(window_shape)
        raise ValueError(msg)

    out = jla.reduce_window(
        operand=input * rescalor,
        init_value=init_val,
        computation=reducer,
        window_dimensions=window_shape,
        window_strides=strides,
        padding=padding,
    )
    return out 
開發者ID:SymJAX,項目名稱:SymJAX,代碼行數:55,代碼來源:ops_nn.py


注:本文中的jax.lax.reduce_window方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。