Tensorflow.js是由Google開發的開源庫,用於在瀏覽器或節點環境中運行機器學習模型以及深度學習神經網絡。
.predict() 函數用於生成給定輸入實例的輸出估計。此外,這裏的計算是成套進行的。其中,由於隻需要tensorflow.js的核心後端,目前不支持step操作。
用法:
predict(x, args?)
Parameters:
- x:它是規定的輸入數據,如張量,否則是 tf.Tensors 數組,以防模型有各種輸入。它可以是 tf.Tensor 或 tf.Tensor[] 類型。
- args:它是規定的 ModelPredictArgs 包含可選字段的對象。
- batchSize:它是指定的批處理維度,它是整數類型。如果未定義,默認值為 32。
- verbose:它是規定的詳細模式,其默認值為 false。
返回值:它返回 tf.Tensor 對象或 tf.Tensor[]。
範例1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining model
const Mod = tf.sequential({
layers:[tf.layers.dense({units:2, inputShape:[30]})]
});
// Calling predict() method and
// Printing output
Mod.predict(tf.randomNormal([6, 30])).print();
輸出:
Tensor [[-0.7650393, -0.8317917], [-0.7274997, 1.827635 ], [-0.9398478, -0.2998275], [-1.0945926, -1.9154934], [0.0067322 , -1.9220339], [0.2052939 , 0.6488774 ]]
範例2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling predict() method and
// Printing output
tf.sequential({
layers:[tf.layers.dense({units:3, inputShape:[10]})]
}).predict(tf.truncatedNormal([2, 10]), {batchSize:2}, true).print();
輸出:
Tensor [[0.2670097, -1.2741219, -0.3159108], [0.9108799, -0.1305539, -0.1370454]]
參考: https://js.tensorflow.org/api/latest/#tf.LayersModel.predict
相關用法
- Tensorflow.js tf.GraphModel.predict()用法及代碼示例
- Tensorflow.js tf.Tensor.buffer()用法及代碼示例
- Java String repeat()用法及代碼示例
- Tensorflow.js tf.LayersModel.evaluate()用法及代碼示例
- Tensorflow.js tf.data.Dataset.batch()用法及代碼示例
- Tensorflow.js tf.Sequential.add()用法及代碼示例
注:本文由純淨天空篩選整理自nidhi1352singh大神的英文原創作品 Tensorflow.js tf.LayersModel class .predict() Method。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。