当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.GraphModel.save()用法及代码示例


简介:Tensorflow.js 是 Google 开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.save() 函数用于保存所述 GraphModel 的结构和/或权重。

注意:

  • IOHandler 是一个对象,它拥有与指定的准确签名相关的保存方法。
  • save 方法控制顺序数据的累积或传输,即描述模型拓扑的工件以及特定介质上或通过特定介质的权重,如本地存储文件、文件下载、Web 浏览器中的 IndexedDB 以及 HTTP 请求服务器。
  • TensorFlow.js 启用 IOHandler 实现以支持许多重复使用的保存介质,例如 tf.io.browserDownloads() 和 tf.io.browserLocalStorage。
  • 此外,该方法还允许我们应用特定类型的 IOHandler,例如 URL-like 字符串技术,例如“localstorage://”和“indexeddb://”。

用法:

save(handlerOrURL, config?)


Parameters: 

  • handlerOrURL:IOHandler 的声明实例或类似的 URL,基于设计的字符串技术支持 IOHandler。它是 io.IOHandler|string 类型。
  • config:所述选项以保存所述模型。它是可选的并且是对象类型。它下面有两个参数,如下所示:
  1. trainableOnly:它说明是否只保存了所述模型的可训练权重,而忽略了不可训练的权重。它是 Boolean 类型,默认为 false。
  2. includeOptimizer:它说明是否存储指定的优化器。它是 Boolean 类型,默认为 false。

返回值:它返回 io.SaveResult 的承诺。

范例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining model url
const model_Url =
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
  
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(model_Url);
  
// Calling save() method
const output = await mymodel.save('downloads://mymodel');
  
// Printing output
console.log(output)

输出:

{
  "modelArtifactsInfo":{
    "dateSaved":"2021-08-19T12:00:15.603Z",
    "modelTopologyType":"JSON",
    "modelTopologyBytes":90375,
    "weightSpecsBytes":15791,
    "weightDataBytes":13984940
  }
}

范例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json');
  
// Calling save() method with all its
// parameters
const output = await mymodel.save('downloads://mymodel', true, true);
  
// Printing output
console.log(JSON.stringify(output))

输出:

{"modelArtifactsInfo":{"dateSaved":"2021-08-19T12:05:35.906Z",
"modelTopologyType":"JSON","modelTopologyBytes":90375,
"weightSpecsBytes":15791,"weightDataBytes":13984940}}

参考: https://js.tensorflow.org/api/latest/#tf.GraphModel.save




相关用法


注:本文由纯净天空筛选整理自nidhi1352singh大神的英文原创作品 Tensorflow.js tf.GraphModel class .save() Method。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。