本文整理匯總了Python中jax.lax.max方法的典型用法代碼示例。如果您正苦於以下問題:Python lax.max方法的具體用法?Python lax.max怎麽用?Python lax.max使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類jax.lax
的用法示例。
在下文中一共展示了lax.max方法的6個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。
示例1: jax_max_pool
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def jax_max_pool(x, pool_size, strides, padding):
return _pooling_general(x, lax.max, -jnp.inf, pool_size=pool_size,
strides=strides, padding=padding)
示例2: softmax
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def softmax(self, x, T=1.0):
unnormalized = np.exp(x - x.max(-1, keepdims=True))
return unnormalized / unnormalized.sum(-1, keepdims=True)
示例3: pool2d
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def pool2d(self, x, pool_size, strides=(1, 1),
border_mode='valid', pool_mode='max'):
dims = (1,) + pool_size + (1,)
strides = (1,) + strides + (1,)
return lax.reduce_window(x, -np.inf, lax.max, dims, strides, border_mode)
示例4: reduce_max
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def reduce_max(self, x, axis=None, keepdims=False):
return np.max(x, axis=axis, keepdim=keepdims)
示例5: max
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def max(self, x, axis=None, keepdims=False):
return self.reduce_max(x, axis=axis, keepdims=keepdims)
示例6: poolNd
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def poolNd(
input,
window_shape,
reducer="MAX",
strides=None,
padding="VALID",
init_val=None,
rescalor=None,
):
# set up the reducer
if reducer == "MAX":
reducer = jla.max
rescalor = numpy.float32(1.0)
init_val = -numpy.inf
elif reducer == "SUM" or reducer == "AVG":
reducer = jla.add
if reducer == "AVG":
rescalor = numpy.float32(1.0 / numpy.prod(window_shape))
else:
rescalor = numpy.float32(1.0)
if init_val is None:
init_val = 0.0
# set up the window_shape
if numpy.isscalar(window_shape):
window_shape = (window_shape,) * input.ndim
elif len(window_shape) != input.ndim:
msg = "Given window_shape {} not the same length ".format(
strides
) + "as input shape {}".format(input.ndim)
raise ValueError(msg)
# set up the strides
if strides is None:
strides = window_shape
elif numpy.isscalar(strides):
strides = (strides,) * len(window_shape)
elif len(strides) != len(window_shape):
msg = "Given strides {} not the same length ".format(
strides
) + "as window_shape {}".format(window_shape)
raise ValueError(msg)
out = jla.reduce_window(
operand=input * rescalor,
init_value=init_val,
computation=reducer,
window_dimensions=window_shape,
window_strides=strides,
padding=padding,
)
return out