當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Tensorflow.js tf.Sequential.trainOnBatch()用法及代碼示例


Tensorflow.js是由Google開發的開源庫,用於在瀏覽器或節點環境中運行機器學習模型以及深度學習神經網絡。

.trainOnBatch() 函數用於對特定批次的數據運行單獨的梯度更新。

注意:

此方法與 fit() 和 fitDataset() 的不同之處如下:

  • 這種方法絕對適用於一批數據。
  • 此方法僅返回損失和度量值,而不是按批次損失和度量值返回批次。
  • 此方法不支持 fine-grained 選項,如冗長和回調。


用法:

trainOnBatch(x, y)

參數:

  • x:規定的輸入數據。它可以是 tf.Tensor、tf.Tensor[] 或 {[inputName:string]:tf.Tensor} 類型。它可以是以下任何一種:
    1. 指定的 tf.Tensor,或者如果指定的模型具有多個輸入,則為 tf.Tensor 數組。
    2. 將輸入名稱繪製到匹配的 tf.Tensor 的對象,以防所述模型擁有命名輸入。
  • y:所述的目標數據。它可以是 tf.Tensor、tf.Tensor[] 或 {[inputName:string]:tf.Tensor} 類型。它必須是關於 x 的常數。

返回值:它返回 number 或 number[] 的承諾。

範例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Training Model 
const gfg = tf.sequential();
    
// Adding layer to model  
const layer = tf.layers.dense({units:3, 
               inputShape:[5]});
   gfg.add(layer);
      
// Compiling our model 
const config = {optimizer:'sgd', 
              loss:'meanSquaredError'};
  gfg.compile(config);
  
// Test tensor and target tensor
const xs = tf.ones([3, 5]);
const ys = tf.ones([3, 3]);
      
// Calling trainOneBatch() method
const result = await gfg.trainOnBatch(xs, ys);
  
// Printing output
console.log(result);

輸出:

0.3589147925376892

範例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
async function run() {
  
  // Training Model 
  const gfg = tf.sequential();
    
  // Adding layer to model  
  const layer = tf.layers.dense({units:2, 
               inputShape:[2]});
  gfg.add(layer);
      
  // Compiling our model 
  const config = {optimizer:'sgd', 
              loss:'meanSquaredError'};
  gfg.compile(config);
  
  // Test tensor and target tensor
  const xs = tf.truncatedNormal([3, 2]);
  const ys = tf.randomNormal([3, 2]);
      
  // Calling trainOneBatch() method
  const result = await gfg.trainOnBatch(xs, ys);
  
  // Printing output
  console.log(JSON.stringify(+result));
}
    
// Function call
await run();

輸出:

1.6889342069625854

參考: https://js.tensorflow.org/api/latest/#tf.Sequential.trainOnBatch




相關用法


注:本文由純淨天空篩選整理自nidhi1352singh大神的英文原創作品 Tensorflow.js tf.Sequential class .trainOnBatch() Method。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。