当前位置: 首页>>代码示例>>Python>>正文


Python lax.max方法代码示例

本文整理汇总了Python中jax.lax.max方法的典型用法代码示例。如果您正苦于以下问题:Python lax.max方法的具体用法?Python lax.max怎么用?Python lax.max使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.lax的用法示例。


在下文中一共展示了lax.max方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: jax_max_pool

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import max [as 别名]
def jax_max_pool(x, pool_size, strides, padding):
  return _pooling_general(x, lax.max, -jnp.inf, pool_size=pool_size,
                          strides=strides, padding=padding) 
开发者ID:yyht,项目名称:BERT,代码行数:5,代码来源:backend.py

示例2: softmax

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import max [as 别名]
def softmax(self, x, T=1.0):
        unnormalized = np.exp(x - x.max(-1, keepdims=True))
        return unnormalized / unnormalized.sum(-1, keepdims=True) 
开发者ID:sharadmv,项目名称:deepx,代码行数:5,代码来源:jax.py

示例3: pool2d

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import max [as 别名]
def pool2d(self, x, pool_size, strides=(1, 1),
               border_mode='valid', pool_mode='max'):
        dims = (1,) + pool_size + (1,)
        strides = (1,) + strides + (1,)
        return lax.reduce_window(x, -np.inf, lax.max, dims, strides, border_mode) 
开发者ID:sharadmv,项目名称:deepx,代码行数:7,代码来源:jax.py

示例4: reduce_max

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import max [as 别名]
def reduce_max(self, x, axis=None, keepdims=False):
        return np.max(x, axis=axis, keepdim=keepdims) 
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例5: max

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import max [as 别名]
def max(self, x, axis=None, keepdims=False):
        return self.reduce_max(x, axis=axis, keepdims=keepdims) 
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例6: poolNd

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import max [as 别名]
def poolNd(
    input,
    window_shape,
    reducer="MAX",
    strides=None,
    padding="VALID",
    init_val=None,
    rescalor=None,
):

    # set up the reducer
    if reducer == "MAX":
        reducer = jla.max
        rescalor = numpy.float32(1.0)
        init_val = -numpy.inf
    elif reducer == "SUM" or reducer == "AVG":
        reducer = jla.add
        if reducer == "AVG":
            rescalor = numpy.float32(1.0 / numpy.prod(window_shape))
        else:
            rescalor = numpy.float32(1.0)
        if init_val is None:
            init_val = 0.0

    # set up the window_shape
    if numpy.isscalar(window_shape):
        window_shape = (window_shape,) * input.ndim
    elif len(window_shape) != input.ndim:
        msg = "Given window_shape {} not the same length ".format(
            strides
        ) + "as input shape {}".format(input.ndim)
        raise ValueError(msg)

    # set up the strides
    if strides is None:
        strides = window_shape
    elif numpy.isscalar(strides):
        strides = (strides,) * len(window_shape)
    elif len(strides) != len(window_shape):
        msg = "Given strides {} not the same length ".format(
            strides
        ) + "as window_shape {}".format(window_shape)
        raise ValueError(msg)

    out = jla.reduce_window(
        operand=input * rescalor,
        init_value=init_val,
        computation=reducer,
        window_dimensions=window_shape,
        window_strides=strides,
        padding=padding,
    )
    return out 
开发者ID:SymJAX,项目名称:SymJAX,代码行数:55,代码来源:ops_nn.py


注:本文中的jax.lax.max方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。