當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Tensorflow.js tf.callbacks.earlyStopping()用法及代碼示例


Tensorflow.js是Google開發的開源庫,用於在瀏覽器或節點環境中運行機器學習模型和深度學習神經網絡。

Tensorflow.js tf.callbacks.earlyStopping() 是一個回調函數,用於在訓練數據停止改善時停止訓練。

用法:

tf.callbacks.earlyStopping(args);

參數:該方法接受以下參數。

  • args: 它是一個具有以下字段的對象:
    • monitor: 它應該是一個字符串。這是要監視的值。
    • minDelta: 它應該是一個數字。這是最小值,低於該值不被視為訓練的改進。
    • patience: 它應該是一個數字。它是當遇到低於 minDelta 的值時不應停止的次數。
    • verbose: 它應該是一個數字。這就是冗長的值。
    • mode: 它應該是以下三個之一:
      • “auto”:在自動模式下,方向是根據監控量的名稱自動推斷的。
      • “min”:在min模式下,當監測到的數據值停止減少時,訓練將停止。
      • “max”:在 max 模式下,當監控的數據值停止增加時,訓練將停止。
    • baseline: 它應該是一個數字。這個數字表明訓練何時跟不上這個值,訓練就會停止。它是受監控數量的結束線。
    • restoreBestWeights: 它應該是一個布爾值。它告訴我們是否從每個時期的監控數量中恢複最佳值。

返回值:它返回一個對象(EarlyStopping)。

下麵是該函數的一些示例。

示例 1:在此示例中,我們將看到如何在 fitDataset 中使用 tf.callbacks.earlyStopping() 函數:

Javascript


import * as tf from "@tensorflow/tfjs"; 
  
const xArray = [ 
    [1, 2, 3, 4], 
    [5, 6, 7, 8], 
    [8, 7, 6, 5], 
    [1, 2, 3, 4], 
]; 
  
const x1Array = [ 
    [0, 1, 0.5, 0], 
    [1, 0.5, 0, 1], 
    [0.5, 1, 1, 0], 
    [1, 0, 0, 1], 
]; 
  
const yArray = [1, 2, 3, 4]; 
const y1Array = [4, 3, 2, 1]; 
  
// Create a dataset from the JavaScript array. 
const xDataset = tf.data.array(xArray); 
const x1Dataset = tf.data.array(x1Array); 
const y1Dataset = tf.data.array(x1Array); 
const yDataset = tf.data.array(yArray); 
  
// Combining the Dataset with zip function 
const xyDataset = tf.data 
    .zip({ xs: xDataset, ys: yDataset }) 
    .batch(4) 
    .shuffle(4); 
const xy1Dataset = tf.data 
    .zip({ xs: x1Dataset, ys: y1Dataset }) 
    .batch(4) 
    .shuffle(4); 
  
// Creating model 
const model = tf.sequential(); 
model.add( 
    tf.layers.dense({ 
        units: 1, 
        inputShape: [4], 
    }) 
); 
  
// Compiling model 
model.compile({ loss: "meanSquaredError",  
    optimizer: "sgd", metrics: ["acc"] }); 
  
// Using tf.callbacks.earlyStopping in fitDataset. 
const history = await model.fitDataset(xyDataset, { 
    epochs: 10, 
    validationData: xy1Dataset, 
    callbacks: tf.callbacks.earlyStopping({  
        monitor: "val_acc" }), 
}); 
  
// Printing value 
console.log("The value of val_acc is :",  
    history.history.val_acc);

輸出:您獲得的值會有所不同,因為隨著訓練值的變化,val_acc 值會發生變化。

The value of val_acc is :0.4375,0.375

示例 2:在此示例中,我們將了解如何將 tf.callbacks.earlyStopping() 與 fit 結合使用:

Javascript


import * as tf from "@tensorflow/tfjs"; 
  
// Creating tensor for training 
const x = tf.tensor([5, 6, 7, 8, 9, 2], [3, 2]); 
const x1 = tf.tensor([8, 7, 6, 5, 2, 9], [3, 2]); 
const y = tf.tensor([1, 3, 3, 4, 4, 6, 6, 8, 9], [3, 3]); 
const y1 = tf.tensor([2, 2, 2, 1, 5, 5, 2, 3, 8], [3, 3]); 
  
// Creating model 
const model = tf.sequential(); 
  
model.add( 
    tf.layers.dense({ 
        units: 3, 
        inputShape: [2], 
    }) 
); 
  
// Compiling model 
model.compile({ loss: "meanSquaredError",  
    optimizer: "sgd", metrics: ["acc"] }); 
  
// Using tf.callbacks.earlyStopping in fit. 
const history = await model.fit(x, y, { 
    epochs: 10, 
    validationData: [x1, y1], 
    callbacks: tf.callbacks.earlyStopping({  
        monitor: "val_acc" }), 
}); 
  
// Printing value 
console.log("the value of val_acc is :",  
    history.history.val_acc);

輸出:執行代碼的值將會不同,因為訓練數據值會發生變化:

the value of val_acc is : 0.3333333432674408,0.3333333432674408

參考:https://js.tensorflow.org/api/latest/#callbacks.earlyStopping



相關用法


注:本文由純淨天空篩選整理自satyam00so大神的英文原創作品 Tensorflow.js tf.callbacks.earlyStopping() Function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。