当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.nn.max_pool用法及代码示例


对输入执行最大池化。

用法

tf.nn.max_pool(
    input, ksize, strides, padding, data_format=None, name=None
)

参数

  • input 如果 data_format 不以 "NC"(默认)开头,则为 N+2 阶张量,形状为 [batch_size] + input_spatial_shape + [num_channels],如果 data_format 以 "NC" 开头,则为 [batch_size, num_channels] + input_spatial_shape。池化仅发生在空间维度上。
  • ksize 长度为 1 , NN+2ints 的 int 或列表。输入张量的每个维度的窗口大小。
  • strides 长度为 1 , NN+2ints 的 int 或列表。输入张量的每个维度的滑动窗口的步幅。
  • padding 无论是string "SAME"或者"VALID"指示要使用的填充算法的类型,或指示每个维度开始和结束处的显式填充的列表。看这里了解更多信息。当使用显式填充并且data_format 是"NHWC", 这应该是形式[[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]].当使用显式填充并且 data_format 是"NCHW", 这应该是形式[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]].使用显式填充时,填充的大小不能大于滑动窗口大小。
  • data_format 一个字符串。指定通道尺寸。 N=1 可以是"NWC"(默认)或"NCW",N=2 可以是"NHWC"(默认)或"NCHW",N=3 可以是"NDHWC"(默认)或"NCDHW"。
  • name 操作的可选名称。

返回

  • data_format 指定的格式的 Tensor 。最大池化输出张量。

对于 ksize 的给定窗口,取该窗口内的最大值。用于减少计算量和防止过拟合。

考虑一个使用 2x2 非重叠窗口进行池化的示例:

matrix = tf.constant([
    [0, 0, 1, 7],
    [0, 2, 0, 0],
    [5, 2, 0, 0],
    [0, 0, 9, 8],
])
reshaped = tf.reshape(matrix, (1, 4, 4, 1))
tf.nn.max_pool(reshaped, ksize=2, strides=2, padding="SAME")
<tf.Tensor:shape=(1, 2, 2, 1), dtype=int32, numpy=
array([[[[2],
         [7]],
        [[5],
         [9]]]], dtype=int32)>

我们可以使用ksize 参数调整窗口大小。例如,如果我们要将窗口扩大到 3:

tf.nn.max_pool(reshaped, ksize=3, strides=2, padding="SAME")
<tf.Tensor:shape=(1, 2, 2, 1), dtype=int32, numpy=
array([[[[5],
         [7]],
        [[9],
         [9]]]], dtype=int32)>

现在,我们在两个集合点中增加了两个额外的大数(5 和 9)。

请注意,我们的窗口现在是重叠的,因为每次迭代我们仍然移动 2 个单位。这导致我们看到相同的 9 重复了两次,因为它是两个重叠窗口的一部分。

我们可以使用strides 参数调整每次迭代移动窗口的距离。将其更新为与我们的窗口大小相同的值可以消除重叠:

tf.nn.max_pool(reshaped, ksize=3, strides=3, padding="SAME")
<tf.Tensor:shape=(1, 2, 2, 1), dtype=int32, numpy=
array([[[[2],
         [7]],
        [[5],
         [9]]]], dtype=int32)>

因为窗口不能很好地适应我们的输入,所以在边添加了填充,给我们与使用 2x2 窗口时相同的结果。我们可以完全跳过填充,并通过将"VALID" 传递给padding 参数来简单地删除不完全适合我们输入的窗口:

tf.nn.max_pool(reshaped, ksize=3, strides=3, padding="VALID")
<tf.Tensor:shape=(1, 1, 1, 1), dtype=int32, numpy=array([[[[5]]]],
 dtype=int32)>

现在我们已经从左上角开始抓取了 3x3 窗口中的最大值。由于没有其他窗口适合我们的输入,因此它们被删除。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.nn.max_pool。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。