当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.layers.gruCell()用法及代码示例


Tensorflow.js是一个开放源代码库,由Google开发,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.layers.gruCell() 函数用于为 GRU 创建一个单元类。

用法:

tf.layers.gruCell (args)

参数:

  • recurrentActivation: 它是一个张量输入,是用于循环步骤的激活函数,默认为 hard sigmoid。如果传递 null,则不应用激活。
  • implementation:它是一个张量输入,有两种实现方式:
    1. 首先,模式将其操作结构为大量较小的点积和添加。
    2. 其次,模式会将它们分批处理成更少、更大的操作。这些模式将在不同硬件和不同应用程序上具有不同的性能配置文件。
  • resetAfter:它是一个张量输入,可以是 GRU 约定是否在矩阵乘法之后或之前应用重置门,其中 false=”before” 和 true=”after”。
  • units:它是一个具有正整数单位的张量输入,它是输出空间的维数。
  • activation:它是一个张量输入,是一个要使用的激活函数,默认为双曲正切。如果您传递 null,则将应用线性激活。
  • useBias:它是一个张量输入,其中偏置向量用于该层。
  • KernelInitializer:它是一个张量输入,是内核权重矩阵的初始化器,用于输入的线性变换。
  • recurrentInitializer:它是一个张量输入,是 recurrentKernel 权重矩阵的初始化器,用于循环状态的线性变换。
  • biasInitializer:它是一个张量输入,是偏置向量的初始化器。
  • kernelRegularizer:它是一个张量输入,其中正则化函数应用于核权重矩阵。
  • recurrentRegularizer:它是一个张量输入,其中正则化函数应用于 recurrent_kernel 权重矩阵。
  • biasRegularizer:它是一个张量输入,其中正则化函数应用于偏置向量。
  • kernelConstraint:它是一个张量输入,其中约束函数应用于核权重矩阵。
  • recurrentConstraint:它是一个张量输入,其中约束函数应用于 recurrentKernel 权重矩阵。
  • biasConstraint:它是一个张量输入,其中约束函数应用于偏置向量。
  • dropout:它是一个张量输入,其中为输入的线性变换和介于 0 和 1 之间的浮点数而要丢弃的单位的分数。
  • recurrentDropout:它是一个张量输入,其中用于循环状态和 0 和 1 之间的浮点数的线性变换的单位的分数。
  • inputShape:它是一个张量输入,将用于创建一个输入层以在此层之前插入(如果已定义)。它仅适用于输入层。
  • batchInputShape:它是一个张量输入,将用于创建一个输入层以在该层之前插入(如果已定义)。它仅适用于输入层。
  • batchSize:它是一个张量输入,其中 batchSize 用于构造 batchInputShape ,如果指定了 inputShape 而未指定 batchInputShape 。
  • dType:它是一个张量输入,是该层的数据类型,默认为 ‘float32’。
  • name:这是一个张量输入,是该层的名称。
  • trainable:它是一个张量输入,默认为 true,无论该层的权重是否可通过拟合更新。
  • weight:它是一个张量输入,可以是层的初始权重值。
  • inputDType:它是一个具有传统支持的张量输入,不用于新代码。

返回值:它返回 GRUCell。



范例1:在这个例子中,GRUCell 与 RNN 子类 GRU 的不同之处在于它的 apply 方法只获取单个时间步的输入数据并在时间步返回单元格的输出,而 GRU 在多个时间步上获取输入数据。

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining input elements
const cell = tf.layers.gruCell({units:3});
const input = tf.input({shape:[11]});
const output = cell.apply(input);
  
console.log(JSON.stringify(output.shape));

输出:

[null,11]

范例2:在这个例子中,GRUCell 的实例可用于构建 RNN 层。此工作流最典型的用途是将多个单元组合成一个堆叠的 RNN 单元(即内部的 StackedRNNCell)并使用它来创建一个 RNN。

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining input elements
const cells = [
   tf.layers.gruCell({units:8}),
   tf.layers.gruCell({units:16}),
];
const rnn = tf.layers.rnn({
    cell:cells, 
    returnSequences:true
});
  
// Create an input with 20 time steps and 
// a length-30 vector at each step.
const input = tf.input({shape:[20, 30]});
const output = rnn.apply(input);
  
console.log(JSON.stringify(output.shape));

输出:

[null,20,16]

参考:https://js.tensorflow.org/api/latest/#layers.gruCell




相关用法


注:本文由纯净天空筛选整理自arorakashish0911大神的英文原创作品 Tensorflow.js tf.layers.gruCell() Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。