Tensorflow.js是Google开发的开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。
tf.LayersModel 类的 .fit() 方法用于为固定数量的时期(数据集上的迭代)训练模型。
用法:
fit(x, y, args?)
参数:此方法接受以下参数。
- x:tf.Tensor 包含所有输入数据。
- y:tf.Tensor 包含所有输出数据。
- args:它是对象类型,它的变量如下:
- batchSize:它定义了将通过训练传播的样本数量。
- epochs:它定义了训练数据数组的迭代。
- verbose:它有助于显示每个时期的进度。如果值为 0 - 表示在 fit() 调用期间没有打印消息。如果值为 1 - 这意味着在 Node-js 中,它会打印进度条。在浏览器中,它不显示任何操作。值 1 是默认值。 2 - 值 2 尚未实现。
- callbacks:它定义了在训练期间要调用的回调列表。变量可以有一个或多个这些回调onTrainBegin()、onTrainEnd()、onEpochBegin()、onEpochEnd()、onBatchBegin()、onBatchEnd()、onYield()。
- validationSplit:它使用户可以轻松地将训练数据集拆分为训练和验证。例如:如果它的值为 validation-Split = 0.5 ,则表示使用 shuffle 前最后 50% 的数据进行验证。
- validationData:它用于在最终模型之间进行选择时给出最终模型的估计。
- shuffle:该值定义了每个 epoch 之前数据的 shuffle。当stepsPerEpoch 不为空时,它不起作用。
- classWeight:它用于对损失函数进行加权。告诉模型更多地关注来自 under-represented 类的样本会很有用。
- sampleWeight:它是应用于每个样本的模型损失的权重数组。
- initialEpoch:它是值定义开始训练的时期。这对于恢复之前的训练运行很有用。
- stepsPerEpoch:它在宣布一个 epoch 完成并开始下一个 epoch 之前定义了许多批次的样本。如果未确定,则等于 1。
- validationSteps:如果指定了 stepsPerEpoch,则相关。停止前要验证的总步数。
- yieldEvery:它定义了将主线程交给其他任务的频率的配置。它可以是自动的,这意味着屈服发生在一定的帧率下。批次,如果值是这个,它会产生每个批次。 epoch,如果值是这个,它产生每个纪元。任何数字,如果该值是任何数字,则每毫秒产生一个数字。从不,如果该值是这个,则它永远不会产生。
返回值:它返回了历史的承诺。
范例1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining model
const mymodel = tf.sequential({
layers:[tf.layers.dense({units:2, inputShape:[6]})]
});
// Compiling the above model
mymodel.compile({optimizer:'sgd', loss:'meanSquaredError'});
// Using for loop
for (let i = 0; i < 4; i++) {
// Calling fit() method
const his = await mymodel.fit(tf.zeros([6, 6]), tf.ones([6, 2]), {
batchSize:5,
epochs:4
});
// Printing output
console.log(his.history.loss[1]);
}
输出:
0.9574100375175476 0.8151942491531372 0.694103479385376 0.5909997820854187
范例2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining model
const mymodel = tf.sequential({
layers:[tf.layers.dense({units:2, inputShape:[6],
activation:"sigmoid"})]});
// Compiling the above model
mymodel.compile({optimizer:'sgd', loss:'meanSquaredError'});
// Calling fit() method
const his = await mymodel.fit(tf.truncatedNormal([6, 6]),
tf.randomNormal([6, 2]), { batchSize:5,
epochs:4, validationSplit:0.2,
shuffle:true, initialEpoch:2,
stepsPerEpoch:1, validationSteps:2});
// Printing output
console.log(JSON.stringify(his.history));
输出:
{"val_loss":[0.35800713300704956,0.35819053649902344], "loss":[0.633269190788269,0.632409930229187]}
参考: https://js.tensorflow.org/api/latest/#tf.LayersModel.fit
相关用法
注:本文由纯净天空筛选整理自nidhi1352singh大神的英文原创作品 Tensorflow.js tf.LayersModel class .fit() Method。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。