当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.nn.pool用法及代码示例


执行 N-D 池化操作。

用法

tf.compat.v1.nn.pool(
    input, window_shape, pooling_type, padding, dilation_rate=None, strides=None,
    name=None, data_format=None, dilations=None
)

参数

  • input 如果 data_format 不以 "NC"(默认)开头,则为 N+2 阶张量,形状为 [batch_size] + input_spatial_shape + [num_channels],如果 data_format 以 "NC" 开头,则为 [batch_size, num_channels] + input_spatial_shape。池化仅发生在空间维度上。
  • window_shape N 个整数的序列 >= 1。
  • pooling_type 指定池操作,必须是"AVG" 或"MAX"。
  • padding 填充算法,必须是"SAME" 或"VALID"。有关详细信息,请参阅tf.nn.convolution 的"returns" 部分。
  • dilation_rate 可选的。膨胀率。 N 个整数列表 >= 1。默认为 [1]N. 如果 dilation_rate 的任何值 > 1,则所有 strides 的值都必须为 1。
  • strides 可选的。 N 个整数的序列 >= 1。默认为 [1]N。如果任何 strides 值 > 1,则 dilation_rate 的所有值都必须为 1。
  • name 可选的。操作的名称。
  • data_format 一个字符串或无。指定input 和输出的通道维度是最后一个维度(默认,或者如果data_format 不以"NC" 开头),还是第二个维度(如果data_format 以"NC" 开头)。对于 N=1,有效值为 "NWC"(默认)和 "NCW"。对于 N=2,有效值为 "NHWC"(默认)和 "NCHW"。对于 N=3,有效值为 "NDHWC"(默认)和 "NCDHW"。
  • dilations dilation_rate 的别名

返回

  • N+2 阶张量,形状为 [batch_size] + output_spatial_shape + [num_channels]

    如果 data_format 为 None 或不以 "NC" 开头,或

    [batch_size, num_channels] + output_spatial_shape

    如果data_format以"NC"开头,其中output_spatial_shape取决于填充的值:

    如果填充 = "SAME":output_spatial_shape[i] = ceil(input_spatial_shape[i] /strides[i])

    如果填充 = "VALID":output_spatial_shape[i] = ceil((input_spatial_shape[i] - (window_shape[i] - 1) * dilation_rate[i]) /strides[i])。

抛出

  • ValueError 如果参数无效。

data_format 不以 "NC" 开头的情况下,计算 0

output[b, x[0], ..., x[N-1], c] =
    REDUCE_{z[0], ..., z[N-1]}
      input[b,
            x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
            ...
            x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1],
            c],

其中归约函数 REDUCE 取决于 pooling_type 的值,而 pad_before 是根据 padding 的值定义的,如tf.nn.convolution 的"returns" 部分所述,详情参见tf.nn.convolution。减少从不包括越界位置。

data_format"NC" 开头的情况下,input 和输出简单地转置如下:

pool(input, data_format, **kwargs) =
    tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
                      **kwargs),
                 [0, N+1] + range(1, N+1))

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.nn.pool。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。