Tensorflow.js是由Google开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.trainOnBatch() 函数用于对特定批次的数据运行单独的梯度更新。
注意:此方法与 fit() 和 fitDataset() 的不同之处如下:
- 这种方法绝对适用于一批数据。
- 此方法仅返回损失和度量值,而不是按批次损失和度量值返回批次。
- 这种方法不支持像冗长和回调这样的细粒度选项。
用法:
trainOnBatch(x, y)
参数:
- x:规定的输入数据。它可以是 tf.Tensor、tf.Tensor[] 或 {[inputName:string]:tf.Tensor} 类型。它可以是以下任何一种:
- 指定的 tf.Tensor,或者如果指定的模型具有多个输入,则为 tf.Tensor 数组。
- 将输入名称绘制到匹配的 tf.Tensor 的对象,以防所述模型拥有命名输入。
- y:所述的目标数据。它可以是 tf.Tensor、tf.Tensor[] 或 {[inputName:string]:tf.Tensor} 类型。它必须是关于 x 的常数。
返回值:它返回 number 或 number[] 的承诺。
范例1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Training Model
const mymodel = tf.sequential(
{layers:[tf.layers.dense({units:2, inputShape:[2]})]});
// Compiling our model
const config = {optimizer:'sgd',
loss:'meanSquaredError'};
mymodel.compile(config);
// Test tensor and target tensor
const xs = tf.ones([3,2]);
const ys = tf.ones([3,2]);
// Calling trainOneBatch() method
const result = await mymodel.trainOnBatch(xs, ys);
// Printing output
console.log(result);
输出:
2.0696773529052734
范例2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
async function run() {
// Training Model
const mymodel = tf.sequential(
{layers:[tf.layers.dense({units:2, inputShape:[2],
activation:'sigmoid'})]});
// Compiling our model
const config = {optimizer:'sgd',
loss:'meanSquaredError'};
mymodel.compile(config);
// Test tensor and target tensor
const xs = tf.truncatedNormal([3,2]);
const ys = tf.randomNormal([3,2]);
// Calling trainOneBatch() method
const result = await mymodel.trainOnBatch(xs, ys);
// Printing output
console.log(JSON.stringify(+result));
}
// Function call
await run();
输出:
0.5935208797454834
参考: https://js.tensorflow.org/api/latest/#tf.LayersModel.trainOnBatch
相关用法
- Tensorflow.js tf.Sequential.trainOnBatch()用法及代码示例
- Tensorflow.js tf.Tensor.buffer()用法及代码示例
- Java String repeat()用法及代码示例
- Tensorflow.js tf.LayersModel.evaluate()用法及代码示例
- Tensorflow.js tf.data.Dataset.batch()用法及代码示例
- Tensorflow.js tf.Sequential.add()用法及代码示例
注:本文由纯净天空筛选整理自nidhi1352singh大神的英文原创作品 Tensorflow.js tf.LayersModel class .trainOnBatch() Method。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。