当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.callbacks.earlyStopping()用法及代码示例


Tensorflow.js是Google开发的开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。

Tensorflow.js tf.callbacks.earlyStopping() 是一个回调函数,用于在训练数据停止改善时停止训练。

用法:

tf.callbacks.earlyStopping(args);

参数:该方法接受以下参数。

  • args: 它是一个具有以下字段的对象:
    • monitor: 它应该是一个字符串。这是要监视的值。
    • minDelta: 它应该是一个数字。这是最小值,低于该值不被视为训练的改进。
    • patience: 它应该是一个数字。它是当遇到低于 minDelta 的值时不应停止的次数。
    • verbose: 它应该是一个数字。这就是冗长的值。
    • mode: 它应该是以下三个之一:
      • “auto”:在自动模式下,方向是根据监控量的名称自动推断的。
      • “min”:在min模式下,当监测到的数据值停止减少时,训练将停止。
      • “max”:在 max 模式下,当监控的数据值停止增加时,训练将停止。
    • baseline: 它应该是一个数字。这个数字表明训练何时跟不上这个值,训练就会停止。它是受监控数量的结束线。
    • restoreBestWeights: 它应该是一个布尔值。它告诉我们是否从每个时期的监控数量中恢复最佳值。

返回值:它返回一个对象(EarlyStopping)。

下面是该函数的一些示例。

示例 1:在此示例中,我们将看到如何在 fitDataset 中使用 tf.callbacks.earlyStopping() 函数:

Javascript


import * as tf from "@tensorflow/tfjs"; 
  
const xArray = [ 
    [1, 2, 3, 4], 
    [5, 6, 7, 8], 
    [8, 7, 6, 5], 
    [1, 2, 3, 4], 
]; 
  
const x1Array = [ 
    [0, 1, 0.5, 0], 
    [1, 0.5, 0, 1], 
    [0.5, 1, 1, 0], 
    [1, 0, 0, 1], 
]; 
  
const yArray = [1, 2, 3, 4]; 
const y1Array = [4, 3, 2, 1]; 
  
// Create a dataset from the JavaScript array. 
const xDataset = tf.data.array(xArray); 
const x1Dataset = tf.data.array(x1Array); 
const y1Dataset = tf.data.array(x1Array); 
const yDataset = tf.data.array(yArray); 
  
// Combining the Dataset with zip function 
const xyDataset = tf.data 
    .zip({ xs: xDataset, ys: yDataset }) 
    .batch(4) 
    .shuffle(4); 
const xy1Dataset = tf.data 
    .zip({ xs: x1Dataset, ys: y1Dataset }) 
    .batch(4) 
    .shuffle(4); 
  
// Creating model 
const model = tf.sequential(); 
model.add( 
    tf.layers.dense({ 
        units: 1, 
        inputShape: [4], 
    }) 
); 
  
// Compiling model 
model.compile({ loss: "meanSquaredError",  
    optimizer: "sgd", metrics: ["acc"] }); 
  
// Using tf.callbacks.earlyStopping in fitDataset. 
const history = await model.fitDataset(xyDataset, { 
    epochs: 10, 
    validationData: xy1Dataset, 
    callbacks: tf.callbacks.earlyStopping({  
        monitor: "val_acc" }), 
}); 
  
// Printing value 
console.log("The value of val_acc is :",  
    history.history.val_acc);

输出:您获得的值会有所不同,因为随着训练值的变化,val_acc 值会发生变化。

The value of val_acc is :0.4375,0.375

示例 2:在此示例中,我们将了解如何将 tf.callbacks.earlyStopping() 与 fit 结合使用:

Javascript


import * as tf from "@tensorflow/tfjs"; 
  
// Creating tensor for training 
const x = tf.tensor([5, 6, 7, 8, 9, 2], [3, 2]); 
const x1 = tf.tensor([8, 7, 6, 5, 2, 9], [3, 2]); 
const y = tf.tensor([1, 3, 3, 4, 4, 6, 6, 8, 9], [3, 3]); 
const y1 = tf.tensor([2, 2, 2, 1, 5, 5, 2, 3, 8], [3, 3]); 
  
// Creating model 
const model = tf.sequential(); 
  
model.add( 
    tf.layers.dense({ 
        units: 3, 
        inputShape: [2], 
    }) 
); 
  
// Compiling model 
model.compile({ loss: "meanSquaredError",  
    optimizer: "sgd", metrics: ["acc"] }); 
  
// Using tf.callbacks.earlyStopping in fit. 
const history = await model.fit(x, y, { 
    epochs: 10, 
    validationData: [x1, y1], 
    callbacks: tf.callbacks.earlyStopping({  
        monitor: "val_acc" }), 
}); 
  
// Printing value 
console.log("the value of val_acc is :",  
    history.history.val_acc);

输出:执行代码的值将会不同,因为训练数据值会发生变化:

the value of val_acc is : 0.3333333432674408,0.3333333432674408

参考:https://js.tensorflow.org/api/latest/#callbacks.earlyStopping



相关用法


注:本文由纯净天空筛选整理自satyam00so大神的英文原创作品 Tensorflow.js tf.callbacks.earlyStopping() Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。