当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.layers.stackedRNNCells()用法及代码示例


介绍: Tensorflow.js是 Google 开发的开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。

Tensorflow.js tf.layers.stackedRNNCells() 函数用于堆叠 RNN 单元并使它们表现为单个单元。

用法:

tf.layers.stackedRNNCells(arge); 

参数:上述方法接受以下参数:

  • args: 这是一个对象类型。它有以下字段:
    • cells: 它是一个应该堆叠在一起的实例 RNNCell 数组。
    • InputShape: 它应该为空或数字数组。它用于创建插入到该层之前的输入层。它仅用于输入层。
    • batchinputShape: 它应该为 null 的数字数组。它用于创建插入到该层之前的输入层。它比 inputShape 具有更高的优先级,因此如果定义了 batchinputShape,它将用于创建输入层。
    • batchSize: 它应该是一个数字。如果缺少batchinputShape,则用于使用InputShape创建batchinputShape,这将是[batchSize,...inputSize]。
    • dtype: 它是输入层的数据类型。该输入层的默认数据类型是 float32。
    • name: 它应该是一个字符串。它定义输入层的名称。
    • weights: 应该是张量。其中定义了输入层的初始权重值。
    • inputDtype: 它应该是数据类型。它用于支持旧版。

返回:它返回一个对象(StackedRNNCells)。

示例 1:在此示例中,我们将看到简单的 RNNCell 如何与 tf.layers.stackedRNNCells() 堆叠并作为单个 RNNCell 工作:

Javascript


import * as tf from "@tensorflow/tfjs"
// Creating RNNcells for stack
const cell1 = tf.layers.simpleRNNCell({units: 2});
const cell2 = tf.layers.simpleRNNCell({units: 4});
// Stack all the RNNCells 
const cell = tf.layers.stackedRNNCells({ cells: [cell1, cell2]});
const input = tf.input({shape: [8]});
const output = cell.apply(input);
console.log(JSON.stringify(output.shape));

输出:

[null,8]

示例 2:在此示例中,我们将借助 stackedRNNCells 将多个单元组合成一个堆叠 RNN 单元,并用于创建 RNN。

Javascript


import * as tf from "@tensorflow/tfjs";
// Creating simple RNNCell for stacking together
const cell1 = tf.layers.simpleRNNCell({ units: 4 });
const cell2 = tf.layers.simpleRNNCell({ units: 8 });
const cell3 = tf.layers.simpleRNNCell({ units: 12 });
const cell4 = tf.layers.simpleRNNCell({ units: 16 });
const stacked_cell = tf.layers.stackedRNNCells({
    cells: [cell1, cell2, cell3, cell4],
    name: "Stacked_RNN",
    dtype: "int32",
});
const rnn = tf.layers.rnn({ cell: stacked_cell, returnSequences: true });
// Create input with 10 steps and 20 length vector at each step.
const input = tf.input({ shape: [8, 32] });
const output = rnn.apply(input);
console.log("Shape of output should be in : ", JSON.stringify(output.shape));

输出:

Shape of output should be in :  [null,8,16]

参考:https://js.tensorflow.org/api/latest/#layers.stackedRNNCells



相关用法


注:本文由纯净天空筛选整理自satyam00so大神的英文原创作品 Tensorflow.js tf.layers.stackedRNNCells() Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。