Tensorflow.js是由Google開發的開源庫,用於在瀏覽器或節點環境中運行機器學習模型以及深度學習神經網絡。
.trainOnBatch() 函數用於對特定批次的數據運行單獨的梯度更新。
注意:此方法與 fit() 和 fitDataset() 的不同之處如下:
- 這種方法絕對適用於一批數據。
- 此方法僅返回損失和度量值,而不是按批次損失和度量值返回批次。
- 這種方法不支持像冗長和回調這樣的細粒度選項。
用法:
trainOnBatch(x, y)
參數:
- x:規定的輸入數據。它可以是 tf.Tensor、tf.Tensor[] 或 {[inputName:string]:tf.Tensor} 類型。它可以是以下任何一種:
- 指定的 tf.Tensor,或者如果指定的模型具有多個輸入,則為 tf.Tensor 數組。
- 將輸入名稱繪製到匹配的 tf.Tensor 的對象,以防所述模型擁有命名輸入。
- y:所述的目標數據。它可以是 tf.Tensor、tf.Tensor[] 或 {[inputName:string]:tf.Tensor} 類型。它必須是關於 x 的常數。
返回值:它返回 number 或 number[] 的承諾。
範例1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Training Model
const mymodel = tf.sequential(
{layers:[tf.layers.dense({units:2, inputShape:[2]})]});
// Compiling our model
const config = {optimizer:'sgd',
loss:'meanSquaredError'};
mymodel.compile(config);
// Test tensor and target tensor
const xs = tf.ones([3,2]);
const ys = tf.ones([3,2]);
// Calling trainOneBatch() method
const result = await mymodel.trainOnBatch(xs, ys);
// Printing output
console.log(result);
輸出:
2.0696773529052734
範例2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
async function run() {
// Training Model
const mymodel = tf.sequential(
{layers:[tf.layers.dense({units:2, inputShape:[2],
activation:'sigmoid'})]});
// Compiling our model
const config = {optimizer:'sgd',
loss:'meanSquaredError'};
mymodel.compile(config);
// Test tensor and target tensor
const xs = tf.truncatedNormal([3,2]);
const ys = tf.randomNormal([3,2]);
// Calling trainOneBatch() method
const result = await mymodel.trainOnBatch(xs, ys);
// Printing output
console.log(JSON.stringify(+result));
}
// Function call
await run();
輸出:
0.5935208797454834
參考: https://js.tensorflow.org/api/latest/#tf.LayersModel.trainOnBatch
相關用法
- Tensorflow.js tf.Sequential.trainOnBatch()用法及代碼示例
- Tensorflow.js tf.Tensor.buffer()用法及代碼示例
- Java String repeat()用法及代碼示例
- Tensorflow.js tf.LayersModel.evaluate()用法及代碼示例
- Tensorflow.js tf.data.Dataset.batch()用法及代碼示例
- Tensorflow.js tf.Sequential.add()用法及代碼示例
注:本文由純淨天空篩選整理自nidhi1352singh大神的英文原創作品 Tensorflow.js tf.LayersModel class .trainOnBatch() Method。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。