當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Tensorflow.js tf.layers setWeights()用法及代碼示例

Tensorflow.js是由Google開發的開源庫,用於在瀏覽器或節點環境中運行機器學習模型以及深度學習神經網絡。

.setWeights() 方法用於根據給定的張量設置所述層的權重。

用法:

setWeights(weights)

參數:

  • weights:它是輸入張量的規定列表。它是 tf.Tensor[] 類型。其中,數組的數量及其形狀應等於所用層的規定權重的維度數量。換句話說,它必須等於 getWeights() 方法的結果。

返回值:它返回void。



範例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding a layer
model.add(tf.layers.dense({units:2, inputShape:[11]}));
  
// Calling setWeights() method
model.layers[0].setWeights([tf.truncatedNormal([11, 2]), tf.zeros([2])]);
  
// Compiling the model
model.compile({loss:'categoricalCrossentropy', optimizer:'sgd'});
  
// Printing output using getWeights() method
model.layers[0].getWeights()[0].print();

輸出:

Tensor
    [[-0.5969906, -0.1883931],
     [0.8569255 , -0.49416  ],
     [0.1157023 , 0.1150239 ],
     [-0.4052143, 1.9936075 ],
     [0.3090054 , 0.7212474 ],
     [0.4626641 , -0.7287846],
     [0.4352857 , -0.5195332],
     [0.4626429 , 0.0216295 ],
     [-0.1110666, -0.5997615],
     [-0.5083916, -0.3582681],
     [-0.2847465, 1.184485  ]]

這裏,truncatedNormal() 方法用於創建 tf.Tensor 以及從截斷正態分布中采樣的值,zeros() 方法用於創建 tf.Tensor 以及所有設置為 0 的元素,getWeights() 方法用於創建打印使用 setWeights() 方法設置的權重。

範例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding layers
model.add(tf.layers.dense({units:1, 
    inputShape:[5], batchSize:1, dtype:'int32'}));
model.add(tf.layers.dense({units:2, inputShape:[6], batchSize:5}));
model.add(tf.layers.dense({units:3, inputShape:[7], batchSize:8}));
model.add(tf.layers.dense({units:4, inputShape:[8], batchSize:12}));
  
// Calling setWeights() method
model.layers[0].setWeights([tf.ones([5, 1]), tf.zeros([1])]);
model.layers[1].setWeights([tf.ones([1, 2]), tf.zeros([2])]);
  
// Printing output using getWeights() method
model.layers[0].getWeights()[0].print();
model.layers[0].getWeights()[1].print();
model.layers[1].getWeights()[0].print();
model.layers[1].getWeights()[1].print();

輸出:

Tensor
    [[1],
     [1],
     [1],
     [1],
     [1]]
Tensor
    [0]
Tensor
     [[1, 1],]
Tensor
    [0, 0]

在這裏, ones() 方法用於創建一個 tf.Tensor 以及所有設置為 1 的元素。

參考:https://js.tensorflow.org/api/latest/#tf.layers.Layer.setWeights




相關用法


注:本文由純淨天空篩選整理自nidhi1352singh大神的英文原創作品 Tensorflow.js tf.layers setWeights() Method。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。