当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Tensorflow.js tf.basicLSTMCell()用法及代码示例

Tensorflow.js是Google开发的开源工具包,用于在浏览器或节点平台上执行机器学习模型和深度学习神经网络。它还使开发人员能够在 JavaScript 中创建机器学习模型,并直接在浏览器中或通过 Node.js 使用它们。

tf.basicLSTMCell() 函数计算 BasicLSTMCell 的下一个状态和输出。

用法:

tf.basicLSTMCell (forgetBias, lstmKernel, lstmBias, data, c, h)

参数:

  • forgetBias:单元格的遗忘偏差。
  • lstmKernel:单元格的权重。
  • lstmBias:细胞的偏差。
  • data:单元格的输入。
  • c:先前细胞状态的数组。
  • h:先前单元输出的数组。

返回:[tf.Tensor2D,tf.Tensor2D]

示例 1:

Javascript


import * as tf from "@tensorflow/tfjs"; 
  
const data = tf.tensor2d([7, 51, 50, 54, 24, 1, 48, 75], [4, 2]); 
const kernel = tf.tensor2d([49, 62, 47, 93, 12, 80,  
    24, 89, 34, 8, 96, 74, 56, 42, 32, 53, 7, 87, 35, 54], [5, 4]); 
const state = tf.tensor2d([97, 56, 32, 29, 57, 6, 8, 75, 26, 20, 1, 17], [4, 3]); 
const output = tf.tensor2d([27, 77, 90, 72, 9, 8, 94, 41, 89, 51, 18, 60], [4, 3]); 
const basicLSTMCell = tf.basicLSTMCell(0.8, kernel, 2.2, data, state, output); 
  
console.log(basicLSTMCella)

输出:

[
 Tensor {
   kept: false,
   isDisposedInternal: false,
   shape: [ 4, 3 ],
   dtype: 'float32',
   size: 12,
   strides: [ 3 ],
   dataId: { id: 19 },
   id: 19,
   rankType: '2',
   scopeId: 0
 },
 Tensor {
   kept: false,
   isDisposedInternal: false,
   shape: [ 4, 3 ],
   dtype: 'float32',
   size: 12,
   strides: [ 3 ],
   dataId: { id: 22 },
   id: 22,
   rankType: '2',
   scopeId: 0
 }
]

示例2:

Javascript


import * as tf from "@tensorflow/tfjs"; 
  
const data = tf.tensor2d([70, 10, 62,  
    55, 74, 85, 66, 9], [4, 2]); 
  
const kernel = tf.tensor2d([10, 82, 93, 83,  
    49, 73, 45, 77, 56, 29, 32, 2, 24,  
    39, 34, 91, 95, 61, 76, 69], [5, 4]); 
  
const state = tf.tensor2d([29, 40, 79, 61,  
    5, 34, 78, 47, 86, 74, 46, 28], [4, 3]); 
  
const output = tf.tensor2d([25, 55, 33, 85,  
    82, 65, 20, 75, 54, 59, 50, 3], [4, 3]); 
  
const basicLSTMCell = tf.basicLSTMCell(1.0,  
    kernel, 2.0, data, state, output); 
  
const input = tf.input({ shape: [4, 2] }); 
const simpleRNNLayer = tf.layers.simpleRNN({ 
    units: 4, 
    returnSequences: true, 
    returnState: true, 
    cell: basicLSTMCell 
}); 
  
let outputs, finalState; 
  
[outputs, finalState] = simpleRNNLayer.apply(input); 
  
const model = tf.model({ 
    inputs: input, 
    outputs: outputs 
}); 
  
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [1, 4, 2]); 
  
model.predict(x).print();

输出:

Tensor
   [[[0.8135326, -0.8665518, 0.946215 , 0.8714994],
     [0.9547493, -0.9747651, 0.9873405, 0.9995403],
     [0.9983249, -0.9986398, 0.9996439, 0.9999973],
     [0.9999447, -0.9999344, 0.9999925, 1        ]]]

参考: https://js.tensorflow.org/api/latest/#basicLSTMCell



相关用法


注:本文由纯净天空筛选整理自aayushmohansinha大神的英文原创作品 Tensorflow.js tf.basicLSTMCell() Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。