当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.LayersModel.fit()用法及代码示例


Tensorflow.js是Google开发的开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。

tf.LayersModel 类的 .fit() 方法用于为固定数量的时期(数据集上的迭代)训练模型。

用法:

fit(x, y, args?)

参数:此方法接受以下参数。

  • x:tf.Tensor 包含所有输入数据。
  • y:tf.Tensor 包含所有输出数据。
  • args:它是对象类型,它的变量如下:
    1. batchSize:它定义了将通过训练传播的样本数量。
    2. epochs:它定义了训练数据数组的迭代。
    3. verbose:它有助于显示每个时期的进度。如果值为 0 - 表示在 fit() 调用期间没有打印消息。如果值为 1 - 这意味着在 Node-js 中,它会打印进度条。在浏览器中,它不显示任何操作。值 1 是默认值。 2 - 值 2 尚未实现。
    4. callbacks:它定义了在训练期间要调用的回调列表。变量可以有一个或多个这些回调onTrainBegin()、onTrainEnd()、onEpochBegin()、onEpochEnd()、onBatchBegin()、onBatchEnd()、onYield()。
    5. validationSplit:它使用户可以轻松地将训练数据集拆分为训练和验证。例如:如果它的值为 validation-Split = 0.5 ,则表示使用 shuffle 前最后 50% 的数据进行验证。
    6. validationData:它用于在最终模型之间进行选择时给出最终模型的估计。
    7. shuffle:该值定义了每个 epoch 之前数据的 shuffle。当stepsPerEpoch 不为空时,它不起作用。
    8. classWeight:它用于对损失函数进行加权。告诉模型更多地关注来自 under-represented 类的样本会很有用。
    9. sampleWeight:它是应用于每个样本的模型损失的权重数组。
    10. initialEpoch:它是值定义开始训练的时期。这对于恢复之前的训练运行很有用。
    11. stepsPerEpoch:它在宣布一个 epoch 完成并开始下一个 epoch 之前定义了许多批次的样本。如果未确定,则等于 1。
    12. validationSteps:如果指定了 stepsPerEpoch,则相关。停止前要验证的总步数。
    13. yieldEvery:它定义了将主线程交给其他任务的频率的配置。它可以是自动的,这意味着屈服发生在一定的帧率下。批次,如果值是这个,它会产生每个批次。 epoch,如果值是这个,它产生每个纪元。任何数字,如果该值是任何数字,则每毫秒产生一个数字。从不,如果该值是这个,则它永远不会产生。

返回值:它返回了历史的承诺。

范例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining model
const mymodel = tf.sequential({
     layers:[tf.layers.dense({units:2, inputShape:[6]})]
});
  
// Compiling the above model
mymodel.compile({optimizer:'sgd', loss:'meanSquaredError'});
  
// Using for loop
for (let i = 0; i < 4; i++) {
     
  // Calling fit() method
  const his = await mymodel.fit(tf.zeros([6, 6]), tf.ones([6, 2]), {
       batchSize:5,
       epochs:4
   });
    
    // Printing output
    console.log(his.history.loss[1]);
}

输出:

0.9574100375175476
0.8151942491531372
0.694103479385376
0.5909997820854187

范例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining model
const mymodel = tf.sequential({
     layers:[tf.layers.dense({units:2, inputShape:[6], 
                               activation:"sigmoid"})]});
  
// Compiling the above model
mymodel.compile({optimizer:'sgd', loss:'meanSquaredError'});   
  
// Calling fit() method
 const his = await mymodel.fit(tf.truncatedNormal([6, 6]), 
                             tf.randomNormal([6, 2]), { batchSize:5,
                             epochs:4, validationSplit:0.2, 
                             shuffle:true, initialEpoch:2, 
                             stepsPerEpoch:1, validationSteps:2});
    
// Printing output
console.log(JSON.stringify(his.history));

输出:

{"val_loss":[0.35800713300704956,0.35819053649902344],
"loss":[0.633269190788269,0.632409930229187]}

参考: https://js.tensorflow.org/api/latest/#tf.LayersModel.fit


相关用法


注:本文由纯净天空筛选整理自nidhi1352singh大神的英文原创作品 Tensorflow.js tf.LayersModel class .fit() Method。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。