Tensorflow.js是Google开发的开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。它还可以帮助开发人员使用JavaScript语言开发ML模型,并可以直接在浏览器或Node.js中使用ML。
tf.data.Dataset.batch() 函数用于将元素分组为批次。
用法:
tf.data.Dataset.batch(batchSize, smallLastBatch?)
参数:
- batchSize:应该在一个批次中的元素。
- smallLastBatch:如果为 true,则最终批次将在元素少于 batchSize 时发出元素,反之亦然。默认值为真。提供此值是可选的。
返回值:它返回一个 tf.data.Dataset。
范例1:在这个例子中,我们将采用一个大小为 6 的数组,并将其分成多个批次,每个批次有 3 个元素。
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Creating an array
const gfg = tf.data.array(
[10, 20, 30, 40, 50, 60]
).batch(3);
// Printing the elements
await gfg.forEachAsync(
element => element.print()
);
输出:
"Tensor [10, 20, 30]" "Tensor [40, 50, 60]"
范例2:这次我们将取 8 个元素,并尝试将它们分批拆分,每个 3 个元素。
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Creating an array
const gfg = tf.data.array(
[10, 20, 30, 40, 50, 60, 70, 80]
).batch(3);
// Printing the elements
await gfg.forEachAsync(
element => element.print()
);
输出:
"Tensor [10, 20, 30]" "Tensor [40, 50, 60]" "Tensor [70, 80]"
由于 smallLastBatch 的默认值默认为 true,因此我们看到了具有 2 个元素的第三批。
范例3:这次我们将把 smallLastBatch 参数作为 false 传递。
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Creating an array
const gfg = tf.data.array(
[10, 20, 30, 40, 50, 60, 70, 80]
).batch(3, false);
// Printing the elements
await gfg.forEachAsync(
element => element.print()
);
输出:
"Tensor [10, 20, 30]" "Tensor [40, 50, 60]"
由于 smallLastBatch 的默认值是 false,我们没有看到第三批,因为最后一批中只有 2 个元素小于指定的批大小 3。
参考: https://js.tensorflow.org/api/latest/#tf.data.Dataset.batch
相关用法
- Tensorflow.js tf.Tensor.buffer()用法及代码示例
- Java String repeat()用法及代码示例
- Tensorflow.js tf.LayersModel.evaluate()用法及代码示例
- Tensorflow.js tf.Sequential.add()用法及代码示例
- p5.js Element class()用法及代码示例
注:本文由纯净天空筛选整理自parasmadan15大神的英文原创作品 Tensorflow.js tf.data.Dataset class .batch() Method。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。