本文整理汇总了Python中jax.lax.reduce_window方法的典型用法代码示例。如果您正苦于以下问题：Python lax.reduce_window方法的具体用法？Python lax.reduce_window怎么用？Python lax.reduce_window使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.lax的用法示例。
在下文中一共展示了lax.reduce_window方法的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _pooling_general
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import reduce_window [as 别名]
def _pooling_general(inputs, reducer, init_val, rescaler=None,
pool_size=(2, 2), strides=None, padding="VALID"):
"""Helper: general pooling computation used in pooling layers later."""
spatial_strides = strides or (1,) * len(pool_size)
rescale = rescaler(pool_size, spatial_strides, padding) if rescaler else None
dims = (1,) + pool_size + (1,) # NHWC
strides = (1,) + spatial_strides + (1,)
out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding)
return rescale(out, inputs) if rescale else out
示例2: _normalize_by_window_size
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import reduce_window [as 别名]
def _normalize_by_window_size(dims, spatial_strides, padding): # pylint: disable=invalid-name
def rescale(outputs, inputs):
one = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype)
window_sizes = lax.reduce_window(
one, 0., lax.add, dims, spatial_strides, padding)
return outputs / window_sizes[..., jnp.newaxis]
return rescale
示例3: _pooling_general
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import reduce_window [as 别名]
def _pooling_general(inputs, reducer, init_val, rescaler=None,
pool_size=(2, 2), strides=None, padding='VALID'):
"""Helper: general pooling computation used in pooling layers later."""
spatial_strides = strides or (1,) * len(pool_size)
rescale = rescaler(pool_size, spatial_strides, padding) if rescaler else None
dims = (1,) + pool_size + (1,) # NHWC
strides = (1,) + spatial_strides + (1,)
out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding)
return rescale(out, inputs) if rescale else out # pylint: disable=not-callable
示例4: _pool
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import reduce_window [as 别名]
def _pool(reducer, init_val, rescaler=None):
def Pool(window_shape, strides=None, padding='VALID'):
"""Layer construction function for a pooling layer."""
strides = strides or (1,) * len(window_shape)
rescale = rescaler(window_shape, strides, padding) if rescaler else None
dims = (1,) + window_shape + (1,) # NHWC
strides = (1,) + strides + (1,)
def pool(inputs):
out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding)
return rescale(out, inputs) if rescale else out
return pool
return Pool
示例5: _normalize_by_window_size
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import reduce_window [as 别名]
def _normalize_by_window_size(dims, strides, padding):
def rescale(outputs, inputs):
one = np.ones(inputs.shape[1:-1], dtype=inputs.dtype)
window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
return outputs / window_sizes[..., np.newaxis]
return rescale
示例6: pool2d
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import reduce_window [as 别名]
def pool2d(self, x, pool_size, strides=(1, 1),
           border_mode='valid', pool_mode='max'):
    """2-D max or average pooling over axes 1 and 2 of ``x``.

    The window and strides get a size-1 entry prepended and appended, so the
    first and last axes of ``x`` are not pooled (NHWC layout assumed).

    Args:
        x: 4-D input array.
        pool_size: 2-sequence giving the pooling window.
        strides: 2-sequence giving the pooling strides.
        border_mode: forwarded as the ``padding`` argument of
            ``lax.reduce_window``.
        pool_mode: ``'max'`` or ``'avg'``.

    Returns:
        The pooled array.

    Raises:
        ValueError: if ``pool_mode`` is neither ``'max'`` nor ``'avg'``.
    """
    dims = (1,) + tuple(pool_size) + (1,)
    window_strides = (1,) + tuple(strides) + (1,)
    if pool_mode == 'max':
        return lax.reduce_window(x, -np.inf, lax.max, dims, window_strides,
                                 border_mode)
    if pool_mode == 'avg':
        sums = lax.reduce_window(x, 0., lax.add, dims, window_strides,
                                 border_mode)
        # Divide by the number of real (non-padding) cells in each window so
        # padded borders are not biased towards zero.
        counts = lax.reduce_window(lax.full_like(x, 1), 0., lax.add, dims,
                                   window_strides, border_mode)
        return sums / counts
    raise ValueError(
        "Invalid pool_mode: {}. Expected 'max' or 'avg'.".format(pool_mode))
示例7: poolNd
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import reduce_window [as 别名]
def poolNd(
    input,
    window_shape,
    reducer="MAX",
    strides=None,
    padding="VALID",
    init_val=None,
    rescalor=None,
):
    """Pool ``input`` with a window that spans every one of its axes.

    Args:
        input: array to pool. There is no implicit batch/channel axis; the
            window needs one entry per axis of ``input``.
        window_shape: an int (reused for every axis) or a sequence of
            length ``input.ndim``.
        reducer: ``"MAX"``, ``"SUM"``, ``"AVG"``, or a binary reduction
            callable accepted by ``reduce_window``.
        strides: ``None`` (defaults to ``window_shape``, i.e. non-overlapping
            windows), an int reused for every axis, or a sequence with the
            same length as ``window_shape``.
        padding: padding spec forwarded to ``reduce_window``.
        init_val: initial value of the reduction. It is always ``-inf`` for
            ``"MAX"``, defaults to ``0.0`` for ``"SUM"``/``"AVG"``, and is
            required for a callable ``reducer``.
        rescalor: factor multiplied into ``input`` before reducing. The
            string reducers override it. For a callable ``reducer``, ``None``
            means no rescaling.

    Returns:
        The pooled array. ``"AVG"`` divides by the full window size, so
        windows overlapping padding also average over the padded zeros.

    Raises:
        ValueError: if ``window_shape`` or ``strides`` has the wrong length,
            ``reducer`` is an unknown string, or ``init_val`` is missing for
            a callable ``reducer``.
    """
    # Normalise the window first; the AVG factor below must be computed from
    # the full per-axis window, not from a scalar shorthand.
    if numpy.isscalar(window_shape):
        window_shape = (window_shape,) * input.ndim
    elif len(window_shape) != input.ndim:
        msg = "Given window_shape {} not the same length ".format(
            window_shape
        ) + "as input shape {}".format(input.ndim)
        raise ValueError(msg)
    window_shape = tuple(window_shape)

    # set up the strides
    if strides is None:
        strides = window_shape
    elif numpy.isscalar(strides):
        strides = (strides,) * len(window_shape)
    elif len(strides) != len(window_shape):
        msg = "Given strides {} not the same length ".format(
            strides
        ) + "as window_shape {}".format(window_shape)
        raise ValueError(msg)
    strides = tuple(strides)

    # set up the reducer. AVG vs SUM must be decided from the string before
    # ``reducer`` is replaced by the lax primitive.
    if isinstance(reducer, str):
        if reducer == "MAX":
            reducer = jla.max
            rescalor = numpy.float32(1.0)
            init_val = -numpy.inf
        elif reducer in ("SUM", "AVG"):
            if reducer == "AVG":
                rescalor = numpy.float32(1.0 / numpy.prod(window_shape))
            else:
                rescalor = numpy.float32(1.0)
            reducer = jla.add
            if init_val is None:
                init_val = 0.0
        else:
            raise ValueError(
                "Unknown reducer {!r}; expected 'MAX', 'SUM', 'AVG' "
                "or a callable".format(reducer)
            )
    elif init_val is None:
        raise ValueError("init_val must be given when reducer is a callable")

    # ``input * None`` would fail, so a missing rescalor means "no scaling".
    operand = input if rescalor is None else input * rescalor
    return jla.reduce_window(
        operand=operand,
        init_value=init_val,
        computation=reducer,
        window_dimensions=window_shape,
        window_strides=strides,
        padding=padding,
    )