當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.nn.max_pool用法及代碼示例


對輸入執行最大池化。

用法

tf.nn.max_pool(
    input, ksize, strides, padding, data_format=None, name=None
)

參數

  • input 如果 data_format 不以 "NC"(默認)開頭,則為 N+2 階張量,形狀為 [batch_size] + input_spatial_shape + [num_channels],如果 data_format 以 "NC" 開頭,則為 [batch_size, num_channels] + input_spatial_shape。池化僅發生在空間維度上。
  • ksize 長度為 1 , NN+2ints 的 int 或列表。輸入張量的每個維度的窗口大小。
  • strides 長度為 1 , NN+2ints 的 int 或列表。輸入張量的每個維度的滑動窗口的步幅。
  • padding 無論是string "SAME"或者"VALID"指示要使用的填充算法的類型,或指示每個維度開始和結束處的顯式填充的列表。看這裏了解更多信息。當使用顯式填充並且data_format 是"NHWC", 這應該是形式[[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]].當使用顯式填充並且 data_format 是"NCHW", 這應該是形式[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]].使用顯式填充時,填充的大小不能大於滑動窗口大小。
  • data_format 一個字符串。指定通道尺寸。 N=1 可以是"NWC"(默認)或"NCW",N=2 可以是"NHWC"(默認)或"NCHW",N=3 可以是"NDHWC"(默認)或"NCDHW"。
  • name 操作的可選名稱。

返回

  • data_format 指定的格式的 Tensor 。最大池化輸出張量。

對於 ksize 的給定窗口,取該窗口內的最大值。用於減少計算量和防止過擬合。

考慮一個使用 2x2 非重疊窗口進行池化的示例:

matrix = tf.constant([
    [0, 0, 1, 7],
    [0, 2, 0, 0],
    [5, 2, 0, 0],
    [0, 0, 9, 8],
])
reshaped = tf.reshape(matrix, (1, 4, 4, 1))
tf.nn.max_pool(reshaped, ksize=2, strides=2, padding="SAME")
<tf.Tensor:shape=(1, 2, 2, 1), dtype=int32, numpy=
array([[[[2],
         [7]],
        [[5],
         [9]]]], dtype=int32)>

我們可以使用ksize 參數調整窗口大小。例如,如果我們要將窗口擴大到 3:

tf.nn.max_pool(reshaped, ksize=3, strides=2, padding="SAME")
<tf.Tensor:shape=(1, 2, 2, 1), dtype=int32, numpy=
array([[[[5],
         [7]],
        [[9],
         [9]]]], dtype=int32)>

現在,我們在兩個集合點中增加了兩個額外的大數(5 和 9)。

請注意,我們的窗口現在是重疊的,因為每次迭代我們仍然移動 2 個單位。這導致我們看到相同的 9 重複了兩次,因為它是兩個重疊窗口的一部分。

我們可以使用strides 參數調整每次迭代移動窗口的距離。將其更新為與我們的窗口大小相同的值可以消除重疊:

tf.nn.max_pool(reshaped, ksize=3, strides=3, padding="SAME")
<tf.Tensor:shape=(1, 2, 2, 1), dtype=int32, numpy=
array([[[[2],
         [7]],
        [[5],
         [9]]]], dtype=int32)>

因為窗口不能很好地適應我們的輸入,所以在邊添加了填充,給我們與使用 2x2 窗口時相同的結果。我們可以完全跳過填充,並通過將"VALID" 傳遞給padding 參數來簡單地刪除不完全適合我們輸入的窗口:

tf.nn.max_pool(reshaped, ksize=3, strides=3, padding="VALID")
<tf.Tensor:shape=(1, 1, 1, 1), dtype=int32, numpy=array([[[[5]]]],
 dtype=int32)>

現在我們已經從左上角開始抓取了 3x3 窗口中的最大值。由於沒有其他窗口適合我們的輸入,因此它們被刪除。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.nn.max_pool。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。