当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch VectorCrossNet用法及代码示例


本文简要介绍python语言中 torchrec.modules.crossnet.VectorCrossNet 的用法。

用法:

class torchrec.modules.crossnet.VectorCrossNet(in_features: int, num_layers: int)

参数

  • in_features(int) -输入的维度。

  • num_layers(int) -模块中的层数。

基础:torch.nn.modules.module.Module

矢量交叉网络可以称为DCN-V1

它也是一个专门的低秩交叉网络,其中秩=1。在这个版本中,在每一层上,我们只保留一个向量内核 W (Nx1),而不是保留两个内核 W 和 V。我们使用点运算来计算特征的“crossing”效果,从而节省了两次矩阵乘法以进一步降低计算成本并减少可学习参数的数量。

在每一层 l 上,张量被转换为

其中 是向量, 表示逐元素乘法; 表示点操作。

例子:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = VectorCrossNet(num_layers=num_layers)
output = dcn(input)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.modules.crossnet.VectorCrossNet。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。