當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Tensorflow.js tf.GraphModel.save()用法及代碼示例


簡介:Tensorflow.js 是 Google 開發的一個開源庫,用於在瀏覽器或節點環境中運行機器學習模型以及深度學習神經網絡。

.save() 函數用於保存所述 GraphModel 的結構和/或權重。

注意:

  • IOHandler 是一個對象,它擁有與指定的準確簽名相關的保存方法。
  • save 方法控製順序數據的累積或傳輸,即描述模型拓撲的工件以及特定介質上或通過特定介質的權重,如本地存儲文件、文件下載、Web 瀏覽器中的 IndexedDB 以及 HTTP 請求服務器。
  • TensorFlow.js 啟用 IOHandler 實現以支持許多重複使用的保存介質,例如 tf.io.browserDownloads() 和 tf.io.browserLocalStorage。
  • 此外,該方法還允許我們應用特定類型的 IOHandler,例如 URL-like 字符串技術,例如“localstorage://”和“indexeddb://”。

用法:

save(handlerOrURL, config?)


Parameters: 

  • handlerOrURL:IOHandler 的聲明實例或類似的 URL,基於設計的字符串技術支持 IOHandler。它是 io.IOHandler|string 類型。
  • config:所述選項以保存所述模型。它是可選的並且是對象類型。它下麵有兩個參數,如下所示:
  1. trainableOnly:它說明是否隻保存了所述模型的可訓練權重,而忽略了不可訓練的權重。它是 Boolean 類型,默認為 false。
  2. includeOptimizer:它說明是否存儲指定的優化器。它是 Boolean 類型,默認為 false。

返回值:它返回 io.SaveResult 的承諾。

範例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining model url
const model_Url =
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
  
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(model_Url);
  
// Calling save() method
const output = await mymodel.save('downloads://mymodel');
  
// Printing output
console.log(output)

輸出:

{
  "modelArtifactsInfo":{
    "dateSaved":"2021-08-19T12:00:15.603Z",
    "modelTopologyType":"JSON",
    "modelTopologyBytes":90375,
    "weightSpecsBytes":15791,
    "weightDataBytes":13984940
  }
}

範例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json');
  
// Calling save() method with all its
// parameters
const output = await mymodel.save('downloads://mymodel', true, true);
  
// Printing output
console.log(JSON.stringify(output))

輸出:

{"modelArtifactsInfo":{"dateSaved":"2021-08-19T12:05:35.906Z",
"modelTopologyType":"JSON","modelTopologyBytes":90375,
"weightSpecsBytes":15791,"weightDataBytes":13984940}}

參考: https://js.tensorflow.org/api/latest/#tf.GraphModel.save




相關用法


注:本文由純淨天空篩選整理自nidhi1352singh大神的英文原創作品 Tensorflow.js tf.GraphModel class .save() Method。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。