當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Tensorflow.js tf.layers.stackedRNNCells()用法及代碼示例

介紹: Tensorflow.js是 Google 開發的開源庫,用於在瀏覽器或節點環境中運行機器學習模型和深度學習神經網絡。

Tensorflow.js tf.layers.stackedRNNCells() 函數用於堆疊 RNN 單元並使它們表現為單個單元。

用法:

tf.layers.stackedRNNCells(arge); 

參數:上述方法接受以下參數:

  • args: 這是一個對象類型。它有以下字段:
    • cells: 它是一個應該堆疊在一起的實例 RNNCell 數組。
    • InputShape: 它應該為空或數字數組。它用於創建插入到該層之前的輸入層。它僅用於輸入層。
    • batchinputShape: 它應該為 null 的數字數組。它用於創建插入到該層之前的輸入層。它比 inputShape 具有更高的優先級,因此如果定義了 batchinputShape,它將用於創建輸入層。
    • batchSize: 它應該是一個數字。如果缺少batchinputShape,則用於使用InputShape創建batchinputShape,這將是[batchSize,...inputSize]。
    • dtype: 它是輸入層的數據類型。該輸入層的默認數據類型是 float32。
    • name: 它應該是一個字符串。它定義輸入層的名稱。
    • weights: 應該是張量。其中定義了輸入層的初始權重值。
    • inputDtype: 它應該是數據類型。它用於支持舊版。

返回:它返回一個對象(StackedRNNCells)。

示例 1:在此示例中,我們將看到簡單的 RNNCell 如何與 tf.layers.stackedRNNCells() 堆疊並作為單個 RNNCell 工作:

Javascript


import * as tf from "@tensorflow/tfjs"
// Creating RNNcells for stack
const cell1 = tf.layers.simpleRNNCell({units: 2});
const cell2 = tf.layers.simpleRNNCell({units: 4});
// Stack all the RNNCells 
const cell = tf.layers.stackedRNNCells({ cells: [cell1, cell2]});
const input = tf.input({shape: [8]});
const output = cell.apply(input);
console.log(JSON.stringify(output.shape));

輸出:

[null,8]

示例 2:在此示例中,我們將借助 stackedRNNCells 將多個單元組合成一個堆疊 RNN 單元,並用於創建 RNN。

Javascript


import * as tf from "@tensorflow/tfjs";
// Creating simple RNNCell for stacking together
const cell1 = tf.layers.simpleRNNCell({ units: 4 });
const cell2 = tf.layers.simpleRNNCell({ units: 8 });
const cell3 = tf.layers.simpleRNNCell({ units: 12 });
const cell4 = tf.layers.simpleRNNCell({ units: 16 });
const stacked_cell = tf.layers.stackedRNNCells({
    cells: [cell1, cell2, cell3, cell4],
    name: "Stacked_RNN",
    dtype: "int32",
});
const rnn = tf.layers.rnn({ cell: stacked_cell, returnSequences: true });
// Create input with 10 steps and 20 length vector at each step.
const input = tf.input({ shape: [8, 32] });
const output = rnn.apply(input);
console.log("Shape of output should be in : ", JSON.stringify(output.shape));

輸出:

Shape of output should be in :  [null,8,16]

參考:https://js.tensorflow.org/api/latest/#layers.stackedRNNCells



相關用法


注:本文由純淨天空篩選整理自satyam00so大神的英文原創作品 Tensorflow.js tf.layers.stackedRNNCells() Function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。