當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Tensorflow.js tf.layers.lstmCell()用法及代碼示例


Tensorflow.js是一個開放源代碼庫,由Google開發,用於在瀏覽器或節點環境中運行機器學習模型以及深度學習神經網絡。

layer.lstmCell() 函數用於 LSTM 的 Cell 類。它與 RNN 子類分開。

用法:

tf.layers.lstmCell(args)

參數:該函數包含一個 args 對象,該對象包含以下參數:

  • recurrentactivation:它用於激活循環步驟。
  • unitForgetBias:它是一個布爾值,用於初始化時的遺忘門。
  • implementation:它是一個指定實現模式的整數。 MODE 1 用於將其操作結構化為大量較小的點積和加法。 MODE 2 用於將它們批處理為更少、更大的操作。
  • units:無論是整數還是輸出空間的維數,它都是一個數字。
  • activation:它用於要使用的函數。
  • useBias:層是否使用偏置向量是一個布爾值。
  • kernelInitializer:它用於輸入的線性變換。
  • recurrentInitializer:它用於循環狀態的線性變換。
  • biasInitializer:它用於偏置向量。
  • kernelRegularizer:它是一個字符串,用於應用於核權重矩陣的正則化函數。
  • recurrentRegularizer:它是一個字符串,用於應用於 recurrent_kernel 權重矩陣的正則化函數。
  • biasRegularizer:它是一個字符串,用於應用於偏置向量的正則化函數。
  • kernelConstraint:它是一個字符串,用於應用於內核權重矩陣的約束函數。
  • recurrentConstraint:它是一個字符串,用於應用於 recurrentKernel 權重矩陣的約束函數。
  • iasConstraint:它是一個字符串,用於應用於偏置向量的約束函數。
  • dropout:它是一個介於 0 和 1 之間的數字。用於輸入線性變換的單位分數。
  • recurrentDropout:它是一個介於 0 和 1 之間的數字。為循環狀態的線性變換而要下降的單位的分數。
  • inputShape:它是一個數字,用於創建要在該層之前插入的輸入層。
  • batchInputShape:它是一個數字,用於創建要在該層之前插入的輸入層。
  • batchSize:它是一個用於構造batchInputShape 的數字。
  • dtype:此參數僅適用於輸入層。
  • name:它是一個用於圖層的字符串。
  • trainable:它是一個布爾值,用於該層的權重是否可通過擬合更新。
  • weights:它是層的初始權重值。
  • inputDType:它用於舊版支持。它不適用於新代碼。

返回值:它返回 LSTMCell。



範例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling the layers.lstmCell 
// function and printing the output
const cell = tf.layers.lstmCell({units:3});
const input = tf.input({shape:[120]});
const output = cell.apply(input);
  
console.log(JSON.stringify(output.shape));

輸出:

[null, 120]

範例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
const cells = [
   tf.layers.lstmCell({units:6}),
   tf.layers.lstmCell({units:10}),
];
const rnn = tf.layers.rnn(
  {cell:cells, returnSequences:true}
);
  
// Create an input with 10 time steps 
// and a length-20 vector at each step
const input = tf.input({shape:[40, 60]});
const output = rnn.apply(input);
  
console.log(JSON.stringify(output.shape));

輸出:

[null, 30, 8]

參考: https://js.tensorflow.org/api/latest/#layers.lstmCell




相關用法


注:本文由純淨天空篩選整理自sweetyty大神的英文原創作品 Tensorflow.js tf.layers.lstmCell() Function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。