当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.LayersModel.predict()用法及代码示例


Tensorflow.js是由Google开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.predict() 函数用于生成给定输入实例的输出估计。此外,这里的计算是成套进行的。其中,由于只需要tensorflow.js的核心后端,目前不支持step操作。

用法:

predict(x, args?)

Parameters: 

  • x:它是规定的输入数据,如张量,否则是 tf.Tensors 数组,以防模型有各种输入。它可以是 tf.Tensor 或 tf.Tensor[] 类型。
  • args:它是规定的 ModelPredictArgs 包含可选字段的对象。
    1. batchSize:它是指定的批处理维度,它是整数类型。如果未定义,默认值为 32。
    2. verbose:它是规定的详细模式,其默认值为 false。

返回值:它返回 tf.Tensor 对象或 tf.Tensor[]。



范例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining model
const Mod = tf.sequential({
   layers:[tf.layers.dense({units:2, inputShape:[30]})]
});
  
// Calling predict() method and
// Printing output
Mod.predict(tf.randomNormal([6, 30])).print();

输出:

Tensor
    [[-0.7650393, -0.8317917],
     [-0.7274997, 1.827635  ],
     [-0.9398478, -0.2998275],
     [-1.0945926, -1.9154934],
     [0.0067322 , -1.9220339],
     [0.2052939 , 0.6488774 ]]

范例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling predict() method and
// Printing output
tf.sequential({
   layers:[tf.layers.dense({units:3, inputShape:[10]})]
}).predict(tf.truncatedNormal([2, 10]), {batchSize:2}, true).print();

输出:

Tensor
    [[0.2670097, -1.2741219, -0.3159108],
     [0.9108799, -0.1305539, -0.1370454]]

参考: https://js.tensorflow.org/api/latest/#tf.LayersModel.predict

相关用法


注:本文由纯净天空筛选整理自nidhi1352singh大神的英文原创作品 Tensorflow.js tf.LayersModel class .predict() Method。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。