当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.layers countParams()用法及代码示例


Tensorflow.js是由Google开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.countParams() 函数用于在规定的权重中查找数字的绝对计数,例如 float32、int32。

用法:

countParams()

参数:该方法不持有任何参数。

返回值:它返回数字。



范例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding a layer
model.add(tf.layers.dense({units:2, inputShape:[11]}));
  
// Calling setWeights() method
model.layers[0].setWeights(
    [tf.truncatedNormal([11, 2]), tf.zeros([2])]);
  
// Calling countParams() method and also
// Printing output
console.log(model.layers[0].countParams());

输出:这里,truncatedNormal() 方法用于创建 tf.Tensor 以及从截断正态分布中采样的值,zeros() 方法用于创建 tf.Tensor 以及所有设置为 0 的元素,setWeights() 方法用于创建设置权重。

24

范例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding layers
model.add(tf.layers.dense({units:1, 
    inputShape:[5], batchSize:1, dtype:'int32'}));
model.add(tf.layers.dense({units:2, inputShape:[6], batchSize:5}));
model.add(tf.layers.dense({units:3, inputShape:[7], batchSize:8}));
model.add(tf.layers.dense({units:4, inputShape:[8], batchSize:12}));
  
// Calling setWeights() method
model.layers[0].setWeights([tf.ones([5, 1]), tf.zeros([1])]);
model.layers[1].setWeights([tf.ones([1, 2]), tf.zeros([2])]);
  
// Calling countParams() method and also
// Printing outputs
console.log(model.layers[0].countParams());
console.log(model.layers[1].countParams());
console.log(model.layers[2].countParams());

输出:在这里, ones() 方法用于创建一个 tf.Tensor 以及所有设置为 1 的元素。

6
4
9

参考: https://js.tensorflow.org/api/latest/#tf.layers.Layer.countParams




相关用法


注:本文由纯净天空筛选整理自nidhi1352singh大神的英文原创作品 Tensorflow.js tf.layers countParams() Method。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。