当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.pool()用法及代码示例


简介:Tensorflow.js 是 Google 开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.pool() 函数用于执行 N-D 池化函数。

用法:

tf.pool(input, windowShape, poolingType, pad, dilations?, strides?)

参数:

  • input:指定的输入张量,其等级为 4 或等级 3,形状为:[batch, height, width, inChannels]。此外,如果等级为 3,则假定批次大小为 1。它可以是 tf.Tensor3D、tf.Tensor4D、TypedArray 或 Array 类型。
  • windowShape:规定的过滤器尺寸:[filterHeight, filterWidth]。它可以是 [number, number] 或 number 类型。如果 filterSize 是一个单数,那么 filterHeight == filterWidth。
  • poolingType:规定的池化类型,可以是 ‘max’ 或 ‘avg’。
  • pad:规定的填充算法类型。它的类型可以是 valid、same、number 或 conv_util.ExplicitPadding。
    1. 在这里,对于相同和步长 1,无论过滤器大小如何,输出都将具有与输入相同的大小。
    2. 对于,‘valid’,在滤波器尺寸大于1*1×1的情况下,输出应小于输入。
  • dilations:所述膨胀率:[dilationHeight, dilationWidth] 输入值在膨胀池中的高度和宽度维度上进行采样。默认值为 [1, 1]。此外,如果 dilations 是单个数字,则 dilationHeight == dilationWidth。如果它大于 1,那么步幅的所有值都应该是 1。它是可选的,并且是 [number, number], number 类型。
  • strides:池化的规定步幅:[strideHeight, strideWidth]。如果 strides 是一个单数,那么 strideHeight == strideWidth。它是可选的,类型为 [number, number] 或 number。

返回值:它返回 tf.Tensor3D 或 tf.Tensor4D。

范例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining input tensor
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
  
// Calling pool() method
const result = tf.pool(x, 3, 'avg', 'same', [1, 2], 1);
   
// Printing output
result.print();

输出:

Tensor
    [[[0.4444444],
      [0.6666667]],

     [[0.4444444],
      [0.6666667]]]

范例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling pool() method
tf.tensor3d([1.2, 2.1, 3.0, -4], [2, 2, 1]).pool(3,
                    'conv_util.ExplicitPadding', 1, 1).print();

输出:

Tensor
    [[[3],
      [3]],

     [[3],
      [3]]]

参考:https://js.tensorflow.org/api/latest/#pool


相关用法


注:本文由纯净天空筛选整理自nidhi1352singh大神的英文原创作品 Tensorflow.js tf.pool() Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。