当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.layers setWeights()用法及代码示例


Tensorflow.js是由Google开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.setWeights() 方法用于根据给定的张量设置所述层的权重。

用法:

setWeights(weights)

参数:

  • weights:它是输入张量的规定列表。它是 tf.Tensor[] 类型。其中,数组的数量及其形状应等于所用层的规定权重的维度数量。换句话说,它必须等于 getWeights() 方法的结果。

返回值:它返回void。



范例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding a layer
model.add(tf.layers.dense({units:2, inputShape:[11]}));
  
// Calling setWeights() method
model.layers[0].setWeights([tf.truncatedNormal([11, 2]), tf.zeros([2])]);
  
// Compiling the model
model.compile({loss:'categoricalCrossentropy', optimizer:'sgd'});
  
// Printing output using getWeights() method
model.layers[0].getWeights()[0].print();

输出:

Tensor
    [[-0.5969906, -0.1883931],
     [0.8569255 , -0.49416  ],
     [0.1157023 , 0.1150239 ],
     [-0.4052143, 1.9936075 ],
     [0.3090054 , 0.7212474 ],
     [0.4626641 , -0.7287846],
     [0.4352857 , -0.5195332],
     [0.4626429 , 0.0216295 ],
     [-0.1110666, -0.5997615],
     [-0.5083916, -0.3582681],
     [-0.2847465, 1.184485  ]]

这里,truncatedNormal() 方法用于创建 tf.Tensor 以及从截断正态分布中采样的值,zeros() 方法用于创建 tf.Tensor 以及所有设置为 0 的元素,getWeights() 方法用于创建打印使用 setWeights() 方法设置的权重。

范例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding layers
model.add(tf.layers.dense({units:1, 
    inputShape:[5], batchSize:1, dtype:'int32'}));
model.add(tf.layers.dense({units:2, inputShape:[6], batchSize:5}));
model.add(tf.layers.dense({units:3, inputShape:[7], batchSize:8}));
model.add(tf.layers.dense({units:4, inputShape:[8], batchSize:12}));
  
// Calling setWeights() method
model.layers[0].setWeights([tf.ones([5, 1]), tf.zeros([1])]);
model.layers[1].setWeights([tf.ones([1, 2]), tf.zeros([2])]);
  
// Printing output using getWeights() method
model.layers[0].getWeights()[0].print();
model.layers[0].getWeights()[1].print();
model.layers[1].getWeights()[0].print();
model.layers[1].getWeights()[1].print();

输出:

Tensor
    [[1],
     [1],
     [1],
     [1],
     [1]]
Tensor
    [0]
Tensor
     [[1, 1],]
Tensor
    [0, 0]

在这里, ones() 方法用于创建一个 tf.Tensor 以及所有设置为 1 的元素。

参考:https://js.tensorflow.org/api/latest/#tf.layers.Layer.setWeights




相关用法


注:本文由纯净天空筛选整理自nidhi1352singh大神的英文原创作品 Tensorflow.js tf.layers setWeights() Method。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。