Tensorflow.js 是谷歌开发的开源 JavaScript 库,用于在浏览器和 node.js 环境中运行和训练机器学习模型和深度学习神经网络。
均方误差是预测值和实际值之间的平方差的平均值。结果总是正数,万一为 0.0,但永远不会变成负数。在 tensorflow.js 库中,我们使用 tf.losses.meanSquaredError() 函数来计算两个张量之间的均方误差。
用法:
tf.losses.meanSquaredError(labels, predictions, weights?, reduction?)
参数:
- labels:这是计算预测差异的实际输出张量。它可以是 tf.tensor、typedArray 或普通数组。
- predictions:这是与标签具有相同维度的预测输出张量。它是 tf.tensor 或 typedArray 或普通数组。
- weights:这可以是一个秩张量,或者等于标签的秩以便它可以广播,也可以是 0。它是可选的。
- reduction:对损失应用减少。它是可选的。
返回值:tf.Tensor 由 meansquaredError 函数计算。
范例1:在这个例子中,我们将两个二维张量作为标签,另一个作为预测,然后找到这两个的均方误差。
Javascript
// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
// Defining label tensor
const y_true = tf.tensor2d([
[0., 1., 0.],
[0., 0., 0.]
]);
// Defining prediction tensor
const y_pred = tf.tensor2d([
[1., 1., 0.],
[1., 0., 0 ]
]);
// Calculating mean squared error
const mse = tf.losses.meanSquaredError(y_true,y_pred)
// Printing the output
mse.print()
输出:
Tensor 0.3333
例2:同理,我们再举一个例子,在meanSquaredError函数中取rank的权重作为labels的权重,然后计算均方误差。
Javascript
// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
// Defining label tensor
const y_true = tf.tensor2d(
[0., 1., 0., 0., 0., 0., 1.,
0., 1., 1., 0., 1.], [4, 3]
);
// Defining predicted tensor
const y_pred = tf.tensor2d(
[1., 1., 0., 1., 0., 0., 1.,
1., 1., 0., 0., 1.], [4, 3]
);
// Calculating meansquared error
const mse = tf.losses.meanSquaredError(
y_true, y_pred, [0.7, 0.3, 0.2],)
mse.print()
输出:
Tensor 0.2000
范例3:在设计模型的编译函数中,我们使用“均方误差”作为损失参数。以下是一个简单的神经网络,我们在其中进行计算。
Javascript
// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
// Define the model
const model = tf.sequential({
layers:[tf.layers.dense({
units:1, inputShape:[12]
})],
});
// In model compilation we pass
// meanSquaredError as the parameter
model.compile(
{ optimizer:"adam", loss:"meanSquaredError" },
(metrics = ["accuracy"])
);
// Evaluate the model which was compiled above
// computation is done in batches of size 4
const result = model.evaluate(
tf.ones([10, 12]), tf.ones([10, 1]), {
batchSize:4,
}
);
// Print the result
result.print();
输出:
Tensor 0.4817
参考:https://js.tensorflow.org/api/3.6.0/#metrics.meanSquaredError
相关用法
- PHP imagecreatetruecolor()用法及代码示例
- p5.js year()用法及代码示例
- d3.js d3.utcTuesdays()用法及代码示例
- PHP ImagickDraw getTextAlignment()用法及代码示例
- PHP Ds\Sequence last()用法及代码示例
- PHP Imagick floodFillPaintImage()用法及代码示例
- PHP array_udiff_uassoc()用法及代码示例
- PHP geoip_continent_code_by_name()用法及代码示例
- d3.js d3.map.set()用法及代码示例
- PHP GmagickPixel setcolor()用法及代码示例
- PHP opendir()用法及代码示例
- PHP cal_to_jd()用法及代码示例
- d3.js d3.bisectLeft()用法及代码示例
- PHP stream_get_transports()用法及代码示例
注:本文由纯净天空筛选整理自barnadipdey2510大神的英文原创作品 Tensorflow.js tf.losses.meanSquaredError() Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。