当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.layers.permute()用法及代码示例


模型中的层是模型构建的基本块,因为每一层都对下一层的输入和输出执行一些计算。
tf.layers.permute() 函数继承自 layer 类,用于以给定模式排列输入的维度,也用于将 RNN 和 convnet 连接在一起。在这篇文章中,我们将了解这个函数是如何工作的。

用法:

tf.layers.permute(agrs)

参数:

  • dims:它是一个整数数组,表示排列模式。它不包括批次维度。
  • inputShape:它用于创建和插入输入层。
  • batchInputShape:它用于创建和插入输入层。如果同时提到 inputShape 和 batchInputShape,则将使用 batchInputShape。
  • batchSize:它用于在指定 inputShape 而没有指定 batchInputShape 时创建 batchInputShape。
  • dtype:层的数据类型。
  • name:它代表图层的名称。
  • trainable:它是一个布尔值,表示是否通过拟合更新权重。
  • weights:它是一个权重数组,表示层的初始权重值。

返回值:置换

范例1:在这个例子中,我们将创建一个具有单层的模型,并且只将所需的参数传递给 permute() 函数。



Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
 
//create model
const model = tf.sequential();
 
//add layer into model and use permute() method
model.add(tf.layers.permute({
   dims:[2,1],
   inputShape:[8,8]
}));
 
//print outputShape
console.log("Output Shape"+model.outputShape);
 
//model summary()
model.summary()

输出:

Output Shape,8,8
_________________________________________________________________
Layer (type)                 Output shape              Param #    
=================================================================
permute_Permute11 (Permute)  [null,8,8]                0          
=================================================================
Total params:0
Trainable params:0
Non-trainable params:0
_________________________________________________________________

范例2:在这个例子中,我们将使用 permute() 函数创建具有 2 层第 1 层和第 2 层的模型,并将所有参数传递给它。

Javascript


// Importing the tensorflow.js library
//import * as tf from "@tensorflow/tfjs"
 
//create model
const model = tf.sequential();
 
//add layer into model and use permute() method
//layer 1
model.add(tf.layers.permute({
   dims:[2,1],
   inputShape:[10,64],
   dtype:'int32',
   name:'layer1',
   batchSize:2,
  trainable:true,
  inputDType:'int32'
   
}));
 
//add layer2
model.add(tf.layers.permute({
   dims:[2,1],
   inputShape:[8,16],
   dtype:'int32',
   name:'layer2',
   batchSize:2,
  trainable:true,
  inputDType:'int32'
   
}));
 
//print outputShape
console.log("Output Shape:"+model.outputShape);
 
//model summary()
model.summary()

输出:

Output Shape:2,10,64
_________________________________________________________________
Layer (type)                 Output shape              Param #    
=================================================================
layer1 (Permute)             [2,64,10]                 0          
_________________________________________________________________
layer2 (Permute)             [2,10,64]                 0          
=================================================================
Total params:0
Trainable params:0
Non-trainable params:0
_________________________________________________________________

参考文献:https://js.tensorflow.org/api/latest/#layers.permute




相关用法


注:本文由纯净天空筛选整理自abhijitmahajan772大神的英文原创作品 Tensorflow.js tf.layers.permute() Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。