当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.maxPoolWithArgmax()用法及代码示例


Tensorflow.js是由Google开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.maxPoolWithArgmax() 函数用于确定图像的 2D 最大池化以及 argmax 列表,即索引。其中,argmax 中的索引是水平的,以便位置 [b, y, x, c] 处的峰值变为压缩索引:(y * width + x) * channels + c 以防 include_batch_in_index 为假,如果 include_batch_in_index为真,则为 ((b * height + y) * width + x) * channels +c。此外,在展平之前,返回的索引始终在 [0, height) x [0, width) 中。

用法:

tf.maxPoolWithArgmax(x, filterSize, 
    strides, pad, includeBatchInIndex?)

参数:

  • x:指定的输入张量,其等级为 4 或等级 3,形状为:[batch, height, width, inChannels]。此外,如果等级为 3,则假定批次大小为 1。它可以是 tf.Tensor4D、TypedArray 或 Array 类型。
  • filterSize:指定形状的过滤器大小:[filterHeight, filterWidth]。如果过滤器大小是一个奇异数,那么 filterHeight == filterWidth。它可以是 [number, number] 或 number 类型。
  • strides:形状池的规定步幅:[strideHeight, strideWidth]。如果 strides 是一个单数,那么 strideHeight == strideWidth。它可以是 [number, number] 或 number 类型。
  • pad:规定的填充算法类型。它可以是类型 valid、same 或 number。
    • 在这里,对于相同和步长 1,输出将具有与输入相同的大小,而与滤波器大小无关。
    • 对于,‘valid’,在滤波器尺寸大于1*1×1的情况下,输出应小于输入。
  • includeBatchInIndex:它是可选的并且是布尔类型。

返回值:它返回 {[name:string]:tf.Tensor}。

范例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining input tensor
const x = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1]);
  
// Calling maxPoolWithArgmax() method
const result = tf.maxPoolWithArgmax(x, 3, 2, 'same');
  
// Printing output
console.log(result)

输出:

{
  "result":{
    "kept":false,
    "isDisposedInternal":false,
    "shape":[
      2,
      1,
      1,
      1
    ],
    "dtype":"float32",
    "size":2,
    "strides":[
      1,
      1,
      1
    ],
    "dataId":{
      "id":20
    },
    "id":20,
    "rankType":"4",
    "scopeId":14
  },
  "indexes":{
    "kept":false,
    "isDisposedInternal":false,
    "shape":[
      2,
      1,
      1,
      1
    ],
    "dtype":"float32",
    "size":2,
    "strides":[
      1,
      1,
      1
    ],
    "dataId":{
      "id":21
    },
    "id":21,
    "rankType":"4",
    "scopeId":14
  }
}

范例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling maxPoolWithArgmax() method
console.log(tf.maxPoolWithArgmax(
    tf.tensor4d([1.1, 2.1, 3.1, 4.1], 
    [1, 2, 2, 1]), [1, 2], [1, 1], 
    'valid', true
));

输出:

{
  "result":{
    "kept":false,
    "isDisposedInternal":false,
    "shape":[
      1,
      2,
      1,
      1
    ],
    "dtype":"float32",
    "size":2,
    "strides":[
      2,
      1,
      1
    ],
    "dataId":{
      "id":80
    },
    "id":80,
    "rankType":"4",
    "scopeId":54
  },
  "indexes":{
    "kept":false,
    "isDisposedInternal":false,
    "shape":[
      1,
      2,
      1,
      1
    ],
    "dtype":"float32",
    "size":2,
    "strides":[
      2,
      1,
      1
    ],
    "dataId":{
      "id":81
    },
    "id":81,
    "rankType":"4",
    "scopeId":54
  }
}

参考:https://js.tensorflow.org/api/latest/#maxPoolWithArgmax


相关用法


注:本文由纯净天空筛选整理自nidhi1352singh大神的英文原创作品 Tensorflow.js tf.maxPoolWithArgmax() Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。