当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.layers.lstmCell()用法及代码示例


Tensorflow.js是一个开放源代码库,由Google开发,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

layer.lstmCell() 函数用于 LSTM 的 Cell 类。它与 RNN 子类分开。

用法:

tf.layers.lstmCell(args)

参数:该函数包含一个 args 对象,该对象包含以下参数:

  • recurrentactivation:它用于激活循环步骤。
  • unitForgetBias:它是一个布尔值,用于初始化时的遗忘门。
  • implementation:它是一个指定实现模式的整数。 MODE 1 用于将其操作结构化为大量较小的点积和加法。 MODE 2 用于将它们批处理为更少、更大的操作。
  • units:无论是整数还是输出空间的维数,它都是一个数字。
  • activation:它用于要使用的函数。
  • useBias:层是否使用偏置向量是一个布尔值。
  • kernelInitializer:它用于输入的线性变换。
  • recurrentInitializer:它用于循环状态的线性变换。
  • biasInitializer:它用于偏置向量。
  • kernelRegularizer:它是一个字符串,用于应用于核权重矩阵的正则化函数。
  • recurrentRegularizer:它是一个字符串,用于应用于 recurrent_kernel 权重矩阵的正则化函数。
  • biasRegularizer:它是一个字符串,用于应用于偏置向量的正则化函数。
  • kernelConstraint:它是一个字符串,用于应用于内核权重矩阵的约束函数。
  • recurrentConstraint:它是一个字符串,用于应用于 recurrentKernel 权重矩阵的约束函数。
  • iasConstraint:它是一个字符串,用于应用于偏置向量的约束函数。
  • dropout:它是一个介于 0 和 1 之间的数字。用于输入线性变换的单位分数。
  • recurrentDropout:它是一个介于 0 和 1 之间的数字。为循环状态的线性变换而要下降的单位的分数。
  • inputShape:它是一个数字,用于创建要在该层之前插入的输入层。
  • batchInputShape:它是一个数字,用于创建要在该层之前插入的输入层。
  • batchSize:它是一个用于构造batchInputShape 的数字。
  • dtype:此参数仅适用于输入层。
  • name:它是一个用于图层的字符串。
  • trainable:它是一个布尔值,用于该层的权重是否可通过拟合更新。
  • weights:它是层的初始权重值。
  • inputDType:它用于旧版支持。它不适用于新代码。

返回值:它返回 LSTMCell。



范例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling the layers.lstmCell 
// function and printing the output
const cell = tf.layers.lstmCell({units:3});
const input = tf.input({shape:[120]});
const output = cell.apply(input);
  
console.log(JSON.stringify(output.shape));

输出:

[null, 120]

范例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
const cells = [
   tf.layers.lstmCell({units:6}),
   tf.layers.lstmCell({units:10}),
];
const rnn = tf.layers.rnn(
  {cell:cells, returnSequences:true}
);
  
// Create an input with 10 time steps 
// and a length-20 vector at each step
const input = tf.input({shape:[40, 60]});
const output = rnn.apply(input);
  
console.log(JSON.stringify(output.shape));

输出:

[null, 30, 8]

参考: https://js.tensorflow.org/api/latest/#layers.lstmCell




相关用法


注:本文由纯净天空筛选整理自sweetyty大神的英文原创作品 Tensorflow.js tf.layers.lstmCell() Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。