當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Tensorflow.js tf.GraphModel.predict()用法及代碼示例

Tensorflow.js是由Google開發的開源庫,用於在瀏覽器或節點環境中運行機器學習模型以及深度學習神經網絡。

.predict() 函數用於實現有利於輸入張量的含義。

用法:

predict(inputs, config?)

Parameters: 

  • inputs:它是規定的輸入。它的類型是 (tf.Tensor|tf.Tensor[]|{[name:string]:tf.Tensor})。
  • config:它是規定的預測配置,用於定義批量大小以及輸出節點名稱。此外,目前圖模型忽略了批量大小的選擇。它是可選的並且是對象類型。
    • batchSize:指定的批次維度是可選的,並且是整數類型。如果未定義,則默認值為 32。
    • verbose:它是規定的詳細模式,其默認值為 false 並且是可選的。

返回值:它返回 tf.Tensor|tf.Tensor[]|{[name:string]:tf.Tensor}。



範例1:在這個例子中,我們從一個 URL 加載 MobileNetV2 並保存一個帶有零輸入的預測。

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input elements
const model_Url =
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
  
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(model_Url);
  
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
  
// Calling predict() method and 
// Printing output
mymodel.predict(inputs).print();

輸出:

Tensor
     [[-0.1800361, -0.4059965, 0.8190175, 
     ..., 
     -0.8953396, -1.0841646, 1.2912753],]

範例2:在這個例子中,我們從 TF Hub URL 加載 MobileNetV2 並保持一個帶有零輸入的預測。

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input elements
const model_Url =
'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
  
// Calling the loadGraphModel() method
const model = await tf.loadGraphModel(
        model_Url, {fromTFHub:true});
  
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
  
// Defining batchsize
const batchsize = 1;
  
// Defining verbose
const verbose = true;
  
// Calling predict() method and
// Printing output
model.predict(inputs, batchsize, verbose).print();

輸出:

Tensor
     [[-1.1690605, 0.0195426, 1.1962479, 
     ..., 
     -0.4825858, -0.0055641, 1.1937635],]

參考: https://js.tensorflow.org/api/latest/#tf.GraphModel.predict




相關用法


注:本文由純淨天空篩選整理自nidhi1352singh大神的英文原創作品 Tensorflow.js tf.GraphModel class .predict() Method。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。