當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python torchrec.modules.crossnet.LowRankCrossNet用法及代碼示例


用法:

class torchrec.modules.crossnet.LowRankCrossNet(in_features: int, num_layers: int, low_rank: int = 1)

參數

  • in_features(int) -輸入的維度。

  • num_layers(int) -模塊中的層數。

  • low_rank(int) -交叉矩陣的秩設置(默認 = 0)。值必須始終 >= 0。

基礎:torch.nn.modules.module.Module

低秩交叉網是一種高效的交叉網。代替在每一層使用全秩交叉矩陣 (NxN),它將使用兩個內核 ,其中 r << N ,以簡化矩陣乘法。

在每一層 l 上,張量轉換為:

其中 是向量, 表示逐元素乘法, 表示矩陣乘法。

注意

排名r 應該明智地選擇。通常,我們期望r < N/2 節省計算量;我們應該期望 保持全等級交叉網絡的準確性。

例子:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = LowRankCrossNet(num_layers=num_layers, low_rank=3)
output = dcn(input)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchrec.modules.crossnet.LowRankCrossNet。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。