當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Tensorflow.js tf.losses.meanSquaredError()用法及代碼示例


Tensorflow.js 是穀歌開發的開源 JavaScript 庫,用於在瀏覽器和 node.js 環境中運行和訓練機器學習模型和深度學習神經網絡。

均方誤差是預測值和實際值之間的平方差的平均值。結果總是正數,萬一為 0.0,但永遠不會變成負數。在 tensorflow.js 庫中,我們使用 tf.losses.meanSquaredError() 函數來計算兩個張量之間的均方誤差。

用法:

tf.losses.meanSquaredError(labels, predictions, weights?, reduction?)

參數:

  • labels:這是計算預測差異的實際輸出張量。它可以是 tf.tensor、typedArray 或普通數組。
  • predictions:這是與標簽具有相同維度的預測輸出張量。它是 tf.tensor 或 typedArray 或普通數組。
  • weights:這可以是一個秩張量,或者等於標簽的秩以便它可以廣播,也可以是 0。它是可選的。
  • reduction:對損失應用減少。它是可選的。

返回值:tf.Tensor 由 meansquaredError 函數計算。



範例1:在這個例子中,我們將兩個二維張量作為標簽,另一個作為預測,然後找到這兩個的均方誤差。

Javascript


// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
  
// Defining label tensor
const y_true = tf.tensor2d([
    [0., 1., 0.], 
    [0., 0., 0.]
]);
  
// Defining prediction tensor
const y_pred = tf.tensor2d([
    [1., 1., 0.], 
    [1., 0., 0 ]
]);
  
// Calculating mean squared error
const mse = tf.losses.meanSquaredError(y_true,y_pred)
  
// Printing the output
mse.print()

輸出:

Tensor
    0.3333

例2:同理,我們再舉一個例子,在meanSquaredError函數中取rank的權重作為labels的權重,然後計算均方誤差。

Javascript


// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
  
// Defining label tensor
const y_true = tf.tensor2d(
    [0., 1., 0., 0., 0., 0., 1., 
    0., 1., 1., 0., 1.], [4, 3]
);
  
// Defining predicted tensor
const y_pred = tf.tensor2d(
    [1., 1., 0., 1., 0., 0., 1., 
    1., 1., 0., 0., 1.], [4, 3]
);
  
// Calculating meansquared error
const mse = tf.losses.meanSquaredError(
        y_true, y_pred, [0.7, 0.3, 0.2],)
  
mse.print()

輸出:

Tensor
    0.2000

範例3:在設計模型的編譯函數中,我們使用“均方誤差”作為損失參數。以下是一個簡單的神經網絡,我們在其中進行計算。

Javascript


// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
  
// Define the model
const model = tf.sequential({
    layers:[tf.layers.dense({ 
        units:1, inputShape:[12] 
    })],
});
  
// In model compilation we pass
// meanSquaredError as the parameter
  
model.compile(
    { optimizer:"adam", loss:"meanSquaredError" },
    (metrics = ["accuracy"])
);
  
// Evaluate the model which was compiled above
// computation is done in batches of size 4
const result = model.evaluate(
    tf.ones([10, 12]), tf.ones([10, 1]), {
        batchSize:4,
    }
);
  
// Print the result
result.print();

輸出:

Tensor
    0.4817

參考:https://js.tensorflow.org/api/3.6.0/#metrics.meanSquaredError

相關用法


注:本文由純淨天空篩選整理自barnadipdey2510大神的英文原創作品 Tensorflow.js tf.losses.meanSquaredError() Function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。