當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Tensorflow.js tf.layers countParams()用法及代碼示例

Tensorflow.js是由Google開發的開源庫,用於在瀏覽器或節點環境中運行機器學習模型以及深度學習神經網絡。

.countParams() 函數用於在規定的權重中查找數字的絕對計數,例如 float32、int32。

用法:

countParams()

參數:該方法不持有任何參數。

返回值:它返回數字。



範例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding a layer
model.add(tf.layers.dense({units:2, inputShape:[11]}));
  
// Calling setWeights() method
model.layers[0].setWeights(
    [tf.truncatedNormal([11, 2]), tf.zeros([2])]);
  
// Calling countParams() method and also
// Printing output
console.log(model.layers[0].countParams());

輸出:這裏,truncatedNormal() 方法用於創建 tf.Tensor 以及從截斷正態分布中采樣的值,zeros() 方法用於創建 tf.Tensor 以及所有設置為 0 的元素,setWeights() 方法用於創建設置權重。

24

範例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding layers
model.add(tf.layers.dense({units:1, 
    inputShape:[5], batchSize:1, dtype:'int32'}));
model.add(tf.layers.dense({units:2, inputShape:[6], batchSize:5}));
model.add(tf.layers.dense({units:3, inputShape:[7], batchSize:8}));
model.add(tf.layers.dense({units:4, inputShape:[8], batchSize:12}));
  
// Calling setWeights() method
model.layers[0].setWeights([tf.ones([5, 1]), tf.zeros([1])]);
model.layers[1].setWeights([tf.ones([1, 2]), tf.zeros([2])]);
  
// Calling countParams() method and also
// Printing outputs
console.log(model.layers[0].countParams());
console.log(model.layers[1].countParams());
console.log(model.layers[2].countParams());

輸出:在這裏, ones() 方法用於創建一個 tf.Tensor 以及所有設置為 1 的元素。

6
4
9

參考: https://js.tensorflow.org/api/latest/#tf.layers.Layer.countParams




相關用法


注:本文由純淨天空篩選整理自nidhi1352singh大神的英文原創作品 Tensorflow.js tf.layers countParams() Method。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。