Tensorflow.js 是 Google 開發的一個開源庫,用於在瀏覽器或節點環境中運行機器學習模型和深度學習神經網絡。 Tensorflow.js tf.Sequential 類 .fitDataset() 方法用於使用數據集對象訓練模型。
用法:
model.fitDataset(dataset, args);
參數:該方法包含以下參數:
- dataset:它是輸入值的數據集。它可以是原始數據集、數組或對象。
- args:它包含以下值:
- epochs:它是訓練模型期間訓練數據集中的總通過次數。它是一個整數值。
- batchesPerEpoch:它定義了每個 epoch 中的批次數。它的值取決於批量大小,隨著批量大小的增加而減小。
- verbose:它有助於顯示每個時期的進度。如果值為 0 - 表示在 fit() 調用期間沒有打印消息。如果值為 1 - 這意味著在 Node.js 中,它會打印進度條。在瀏覽器中它不顯示任何操作。值 1 是默認值。 2 - 值 2 尚未實現。
- callbacks:它定義了在訓練期間要調用的回調列表。變量可以有一個或多個這些回調onTrainBegin()、onTrainEnd()、onEpochBegin()、onEpochEnd()、onBatchBegin()、onBatchEnd()、onYield()。
- validationData:它用於在最終模型之間進行選擇時給出最終模型的估計。這可以是以下任何一種:[ xVal, yVal ] 的數組,具有 { xs:xVal, ys:yVal } 形式元素的數據集對象。
- validationBatchSize:它是定義批次大小的數字。它用於驗證批量大小。這意味著我們不能一次性放置超過這個值的所有數據集。其默認值為 32。
- validationBatches:它用於驗證批次的樣品。它用於在 epoch 的每一端繪製驗證數據以進行驗證。
- classWeight:它用於對損失函數進行加權。告訴模型更多地關注來自 under-represented 類的樣本會很有用。
- initialEpoch:它用於定義開始訓練的紀元值。這對於恢複之前的訓練運行很有用。
- yieldEvery:它定義了將主線程交給其他任務的頻率的配置。它可以是自動的,這意味著屈服發生在一定的幀率下。批次,如果值是這個,它會產生每個批次。 epoch,如果值是這個,它產生每個紀元。任何數字,如果該值是任何數字,它會產生每個數字毫秒。從不,如果值是這個,它永遠不會產生。
返回值:承諾<曆史>
範例1:在這個例子中,我們將使用數組數據集訓練我們的模型。
Javascript
import * as tf from "@tensorflow/tfjs"
// Creating model
const gfg_Model = tf.sequential() ;
// Adding layer to model
const config = {units:1, inputShape:[2]}
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
// Compiling the model
const config2 = {optimizer:'sgd', loss:'meanSquaredError'}
gfg_Model.compile(config2);
// Creating Datasets for training
const array1 = [[1,2], [1,4], [1,3], [3,4]];
const array2 = [1, 1];
const arrData1 = tf.data.array(array1);
const arrData2 = tf.data.array(array2);
const config3 = {xs:arrData1, ys:arrData2}
const arrayDataset = tf.data.zip(config3)
const ArrayDataset = arrayDataset.batch(3).shuffle(6);
// Training the model
const Tm = await gfg_Model.fitDataset(ArrayDataset, { epochs:3 });
// Printing the loss after training
console.log("Loss " + ":" + Tm.history.loss[0]);
輸出:
Loss:0.428712397813797
範例2:在這個例子中,我們將使用由 csv 文件製作的數據集來訓練我們的模型。
Javascript
import * as tf from "@tensorflow/tfjs";
// Path for the CSV file
const gfg_CsvFile =
"https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv";
// Creating model
const gfg_Model = tf.sequential();
// Adding layer to model
const config = { units:1, inputShape:[12] };
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
// Compiling the model
const opt = tf.train.sgd(0.0001);
gfg_Model.compile({ optimizer:opt, loss:"meanSquaredError" });
// Here we want to predict column tax
const config2 = { columnConfigs:{ tax:{ isLabel:true } } };
const csvDataset = tf.data.csv(gfg_CsvFile, config2);
// Creating dataset for training
const flattenedDataset = csvDataset
.map(({ xs, ys }) => {
return { xs:Object.values(xs), ys:Object.values(ys) };
})
.batch(5);
// Training the model
const Tm = await gfg_Model.fitDataset(flattenedDataset, { epochs:5 });
for (let i = 0; i < 5; i++) {
console.log(Tm.history.loss[i]);
}
輸出:
21489.68359375 8750.29296875 6632.365234375 5908.6171875 5546.45654296875
參考: https://js.tensorflow.org/api/latest/#tf.Sequential.fitDataset
相關用法
- Tensorflow.js tf.LayersModel.fitDataset()用法及代碼示例
- Tensorflow.js tf.Tensor.buffer()用法及代碼示例
- Java String repeat()用法及代碼示例
- Tensorflow.js tf.LayersModel.evaluate()用法及代碼示例
- Tensorflow.js tf.data.Dataset.batch()用法及代碼示例
- Tensorflow.js tf.Sequential.add()用法及代碼示例
注:本文由純淨天空篩選整理自satyam00so大神的英文原創作品 Tensorflow.js tf.Sequential class .fitDataset() Method。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。