當前位置: 首頁>>代碼示例>>Python>>正文


Python lax.max方法代碼示例

本文整理匯總了Python中jax.lax.max方法的典型用法代碼示例。如果您正苦於以下問題:Python lax.max方法的具體用法?Python lax.max怎麽用?Python lax.max使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在jax.lax的用法示例。


在下文中一共展示了lax.max方法的6個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: jax_max_pool

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def jax_max_pool(x, pool_size, strides, padding):
  return _pooling_general(x, lax.max, -jnp.inf, pool_size=pool_size,
                          strides=strides, padding=padding) 
開發者ID:yyht,項目名稱:BERT,代碼行數:5,代碼來源:backend.py

示例2: softmax

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def softmax(self, x, T=1.0):
        unnormalized = np.exp(x - x.max(-1, keepdims=True))
        return unnormalized / unnormalized.sum(-1, keepdims=True) 
開發者ID:sharadmv,項目名稱:deepx,代碼行數:5,代碼來源:jax.py

示例3: pool2d

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def pool2d(self, x, pool_size, strides=(1, 1),
               border_mode='valid', pool_mode='max'):
        dims = (1,) + pool_size + (1,)
        strides = (1,) + strides + (1,)
        return lax.reduce_window(x, -np.inf, lax.max, dims, strides, border_mode) 
開發者ID:sharadmv,項目名稱:deepx,代碼行數:7,代碼來源:jax.py

示例4: reduce_max

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def reduce_max(self, x, axis=None, keepdims=False):
        return np.max(x, axis=axis, keepdim=keepdims) 
開發者ID:sharadmv,項目名稱:deepx,代碼行數:4,代碼來源:jax.py

示例5: max

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def max(self, x, axis=None, keepdims=False):
        return self.reduce_max(x, axis=axis, keepdims=keepdims) 
開發者ID:sharadmv,項目名稱:deepx,代碼行數:4,代碼來源:jax.py

示例6: poolNd

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import max [as 別名]
def poolNd(
    input,
    window_shape,
    reducer="MAX",
    strides=None,
    padding="VALID",
    init_val=None,
    rescalor=None,
):

    # set up the reducer
    if reducer == "MAX":
        reducer = jla.max
        rescalor = numpy.float32(1.0)
        init_val = -numpy.inf
    elif reducer == "SUM" or reducer == "AVG":
        reducer = jla.add
        if reducer == "AVG":
            rescalor = numpy.float32(1.0 / numpy.prod(window_shape))
        else:
            rescalor = numpy.float32(1.0)
        if init_val is None:
            init_val = 0.0

    # set up the window_shape
    if numpy.isscalar(window_shape):
        window_shape = (window_shape,) * input.ndim
    elif len(window_shape) != input.ndim:
        msg = "Given window_shape {} not the same length ".format(
            strides
        ) + "as input shape {}".format(input.ndim)
        raise ValueError(msg)

    # set up the strides
    if strides is None:
        strides = window_shape
    elif numpy.isscalar(strides):
        strides = (strides,) * len(window_shape)
    elif len(strides) != len(window_shape):
        msg = "Given strides {} not the same length ".format(
            strides
        ) + "as window_shape {}".format(window_shape)
        raise ValueError(msg)

    out = jla.reduce_window(
        operand=input * rescalor,
        init_value=init_val,
        computation=reducer,
        window_dimensions=window_shape,
        window_strides=strides,
        padding=padding,
    )
    return out 
開發者ID:SymJAX,項目名稱:SymJAX,代碼行數:55,代碼來源:ops_nn.py


注:本文中的jax.lax.max方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。