當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch VectorCrossNet用法及代碼示例


本文簡要介紹python語言中 torchrec.modules.crossnet.VectorCrossNet 的用法。

用法:

class torchrec.modules.crossnet.VectorCrossNet(in_features: int, num_layers: int)

參數

  • in_features(int) -輸入的維度。

  • num_layers(int) -模塊中的層數。

基礎:torch.nn.modules.module.Module

矢量交叉網絡可以稱為DCN-V1

它也是一個專門的低秩交叉網絡,其中秩=1。在這個版本中,在每一層上,我們隻保留一個向量內核 W (Nx1),而不是保留兩個內核 W 和 V。我們使用點運算來計算特征的“crossing”效果,從而節省了兩次矩陣乘法以進一步降低計算成本並減少可學習參數的數量。

在每一層 l 上,張量被轉換為

其中 是向量, 表示逐元素乘法; 表示點操作。

例子:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = VectorCrossNet(num_layers=num_layers)
output = dcn(input)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchrec.modules.crossnet.VectorCrossNet。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。