當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Tensorflow.js tf.pool()用法及代碼示例

簡介:Tensorflow.js 是 Google 開發的一個開源庫,用於在瀏覽器或節點環境中運行機器學習模型以及深度學習神經網絡。

.pool() 函數用於執行 N-D 池化函數。

用法:

tf.pool(input, windowShape, poolingType, pad, dilations?, strides?)

參數:

  • input:指定的輸入張量,其等級為 4 或等級 3,形狀為:[batch, height, width, inChannels]。此外,如果等級為 3,則假定批次大小為 1。它可以是 tf.Tensor3D、tf.Tensor4D、TypedArray 或 Array 類型。
  • windowShape:規定的過濾器尺寸:[filterHeight, filterWidth]。它可以是 [number, number] 或 number 類型。如果 filterSize 是一個單數,那麽 filterHeight == filterWidth。
  • poolingType:規定的池化類型,可以是 ‘max’ 或 ‘avg’。
  • pad:規定的填充算法類型。它的類型可以是 valid、same、number 或 conv_util.ExplicitPadding。
    1. 在這裏,對於相同和步長 1,無論過濾器大小如何,輸出都將具有與輸入相同的大小。
    2. 對於,‘valid’,在濾波器尺寸大於1*1×1的情況下,輸出應小於輸入。
  • dilations:所述膨脹率:[dilationHeight, dilationWidth] 輸入值在膨脹池中的高度和寬度維度上進行采樣。默認值為 [1, 1]。此外,如果 dilations 是單個數字,則 dilationHeight == dilationWidth。如果它大於 1,那麽步幅的所有值都應該是 1。它是可選的,並且是 [number, number], number 類型。
  • strides:池化的規定步幅:[strideHeight, strideWidth]。如果 strides 是一個單數,那麽 strideHeight == strideWidth。它是可選的,類型為 [number, number] 或 number。

返回值:它返回 tf.Tensor3D 或 tf.Tensor4D。

範例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining input tensor
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
  
// Calling pool() method
const result = tf.pool(x, 3, 'avg', 'same', [1, 2], 1);
   
// Printing output
result.print();

輸出:

Tensor
    [[[0.4444444],
      [0.6666667]],

     [[0.4444444],
      [0.6666667]]]

範例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling pool() method
tf.tensor3d([1.2, 2.1, 3.0, -4], [2, 2, 1]).pool(3,
                    'conv_util.ExplicitPadding', 1, 1).print();

輸出:

Tensor
    [[[3],
      [3]],

     [[3],
      [3]]]

參考:https://js.tensorflow.org/api/latest/#pool


相關用法


注:本文由純淨天空篩選整理自nidhi1352singh大神的英文原創作品 Tensorflow.js tf.pool() Function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。