當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.nn.pool用法及代碼示例


執行 N-D 池化操作。

用法

tf.compat.v1.nn.pool(
    input, window_shape, pooling_type, padding, dilation_rate=None, strides=None,
    name=None, data_format=None, dilations=None
)

參數

  • input 如果 data_format 不以 "NC"(默認)開頭,則為 N+2 階張量,形狀為 [batch_size] + input_spatial_shape + [num_channels],如果 data_format 以 "NC" 開頭,則為 [batch_size, num_channels] + input_spatial_shape。池化僅發生在空間維度上。
  • window_shape N 個整數的序列 >= 1。
  • pooling_type 指定池操作,必須是"AVG" 或"MAX"。
  • padding 填充算法,必須是"SAME" 或"VALID"。有關詳細信息,請參閱tf.nn.convolution 的"returns" 部分。
  • dilation_rate 可選的。膨脹率。 N 個整數列表 >= 1。默認為 [1]N. 如果 dilation_rate 的任何值 > 1,則所有 strides 的值都必須為 1。
  • strides 可選的。 N 個整數的序列 >= 1。默認為 [1]N。如果任何 strides 值 > 1,則 dilation_rate 的所有值都必須為 1。
  • name 可選的。操作的名稱。
  • data_format 一個字符串或無。指定input 和輸出的通道維度是最後一個維度(默認,或者如果data_format 不以"NC" 開頭),還是第二個維度(如果data_format 以"NC" 開頭)。對於 N=1,有效值為 "NWC"(默認)和 "NCW"。對於 N=2,有效值為 "NHWC"(默認)和 "NCHW"。對於 N=3,有效值為 "NDHWC"(默認)和 "NCDHW"。
  • dilations dilation_rate 的別名

返回

  • N+2 階張量,形狀為 [batch_size] + output_spatial_shape + [num_channels]

    如果 data_format 為 None 或不以 "NC" 開頭,或

    [batch_size, num_channels] + output_spatial_shape

    如果data_format以"NC"開頭,其中output_spatial_shape取決於填充的值:

    如果填充 = "SAME":output_spatial_shape[i] = ceil(input_spatial_shape[i] /strides[i])

    如果填充 = "VALID":output_spatial_shape[i] = ceil((input_spatial_shape[i] - (window_shape[i] - 1) * dilation_rate[i]) /strides[i])。

拋出

  • ValueError 如果參數無效。

data_format 不以 "NC" 開頭的情況下,計算 0

output[b, x[0], ..., x[N-1], c] =
    REDUCE_{z[0], ..., z[N-1]}
      input[b,
            x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
            ...
            x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1],
            c],

其中歸約函數 REDUCE 取決於 pooling_type 的值,而 pad_before 是根據 padding 的值定義的,如tf.nn.convolution 的"returns" 部分所述,詳情參見tf.nn.convolution。減少從不包括越界位置。

data_format"NC" 開頭的情況下,input 和輸出簡單地轉置如下:

pool(input, data_format, **kwargs) =
    tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
                      **kwargs),
                 [0, N+1] + range(1, N+1))

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.nn.pool。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。