當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python torchrec.modules.crossnet.CrossNet用法及代碼示例


用法:

class torchrec.modules.crossnet.CrossNet(in_features: int, num_layers: int)

參數

  • in_features(int) -輸入的維度。

  • num_layers(int) -模塊中的層數。

基礎:torch.nn.modules.module.Module

Cross Network

Cross Net 是對形狀為 的張量到相同形狀的 “crossing” 操作堆棧,有效地在輸入張量上創建 可學習多項式函數。

在這個模塊中,交叉操作是基於滿秩矩陣(NxN)定義的,這樣交叉效果可以覆蓋每一層的所有位。在每一層 l 上,張量轉換為:

其中 是方陣 表示逐元素乘法, 表示矩陣乘法。

例子:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = CrossNet(num_layers=num_layers)
output = dcn(input)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchrec.modules.crossnet.CrossNet。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。