Tensorflow.js 是 Google 开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。 Tensorflow.js tf.Sequential 类 .fitDataset() 方法用于使用数据集对象训练模型。
用法:
model.fitDataset(dataset, args);
参数:该方法包含以下参数:
- dataset:它是输入值的数据集。它可以是原始数据集、数组或对象。
- args:它包含以下值:
- epochs:它是训练模型期间训练数据集中的总通过次数。它是一个整数值。
- batchesPerEpoch:它定义了每个 epoch 中的批次数。它的值取决于批量大小,随着批量大小的增加而减小。
- verbose:它有助于显示每个时期的进度。如果值为 0 - 表示在 fit() 调用期间没有打印消息。如果值为 1 - 这意味着在 Node.js 中,它会打印进度条。在浏览器中它不显示任何操作。值 1 是默认值。 2 - 值 2 尚未实现。
- callbacks:它定义了在训练期间要调用的回调列表。变量可以有一个或多个这些回调onTrainBegin()、onTrainEnd()、onEpochBegin()、onEpochEnd()、onBatchBegin()、onBatchEnd()、onYield()。
- validationData:它用于在最终模型之间进行选择时给出最终模型的估计。这可以是以下任何一种:[ xVal, yVal ] 的数组,具有 { xs:xVal, ys:yVal } 形式元素的数据集对象。
- validationBatchSize:它是定义批次大小的数字。它用于验证批量大小。这意味着我们不能一次性放置超过这个值的所有数据集。其默认值为 32。
- validationBatches:它用于验证批次的样品。它用于在 epoch 的每一端绘制验证数据以进行验证。
- classWeight:它用于对损失函数进行加权。告诉模型更多地关注来自 under-represented 类的样本会很有用。
- initialEpoch:它用于定义开始训练的纪元值。这对于恢复之前的训练运行很有用。
- yieldEvery:它定义了将主线程交给其他任务的频率的配置。它可以是自动的,这意味着屈服发生在一定的帧率下。批次,如果值是这个,它会产生每个批次。 epoch,如果值是这个,它产生每个纪元。任何数字,如果该值是任何数字,它会产生每个数字毫秒。从不,如果值是这个,它永远不会产生。
返回值:承诺<历史>
范例1:在这个例子中,我们将使用数组数据集训练我们的模型。
Javascript
import * as tf from "@tensorflow/tfjs"
// Creating model
const gfg_Model = tf.sequential() ;
// Adding layer to model
const config = {units:1, inputShape:[2]}
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
// Compiling the model
const config2 = {optimizer:'sgd', loss:'meanSquaredError'}
gfg_Model.compile(config2);
// Creating Datasets for training
const array1 = [[1,2], [1,4], [1,3], [3,4]];
const array2 = [1, 1];
const arrData1 = tf.data.array(array1);
const arrData2 = tf.data.array(array2);
const config3 = {xs:arrData1, ys:arrData2}
const arrayDataset = tf.data.zip(config3)
const ArrayDataset = arrayDataset.batch(3).shuffle(6);
// Training the model
const Tm = await gfg_Model.fitDataset(ArrayDataset, { epochs:3 });
// Printing the loss after training
console.log("Loss " + ":" + Tm.history.loss[0]);
输出:
Loss:0.428712397813797
范例2:在这个例子中,我们将使用由 csv 文件制作的数据集来训练我们的模型。
Javascript
import * as tf from "@tensorflow/tfjs";
// Path for the CSV file
const gfg_CsvFile =
"https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv";
// Creating model
const gfg_Model = tf.sequential();
// Adding layer to model
const config = { units:1, inputShape:[12] };
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
// Compiling the model
const opt = tf.train.sgd(0.0001);
gfg_Model.compile({ optimizer:opt, loss:"meanSquaredError" });
// Here we want to predict column tax
const config2 = { columnConfigs:{ tax:{ isLabel:true } } };
const csvDataset = tf.data.csv(gfg_CsvFile, config2);
// Creating dataset for training
const flattenedDataset = csvDataset
.map(({ xs, ys }) => {
return { xs:Object.values(xs), ys:Object.values(ys) };
})
.batch(5);
// Training the model
const Tm = await gfg_Model.fitDataset(flattenedDataset, { epochs:5 });
for (let i = 0; i < 5; i++) {
console.log(Tm.history.loss[i]);
}
输出:
21489.68359375 8750.29296875 6632.365234375 5908.6171875 5546.45654296875
参考: https://js.tensorflow.org/api/latest/#tf.Sequential.fitDataset
相关用法
- Tensorflow.js tf.LayersModel.fitDataset()用法及代码示例
- Tensorflow.js tf.Tensor.buffer()用法及代码示例
- Java String repeat()用法及代码示例
- Tensorflow.js tf.LayersModel.evaluate()用法及代码示例
- Tensorflow.js tf.data.Dataset.batch()用法及代码示例
- Tensorflow.js tf.Sequential.add()用法及代码示例
注:本文由纯净天空筛选整理自satyam00so大神的英文原创作品 Tensorflow.js tf.Sequential class .fitDataset() Method。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。