当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch DeepFM用法及代码示例


本文简要介绍python语言中 torchrec.modules.deepfm.DeepFM 的用法。

用法:

class torchrec.modules.deepfm.DeepFM(dense_module: torch.nn.modules.module.Module)

参数

dense_module(nn.Module) -任何可以在 DeepFM 中使用的定制模块(例如 MLP)。该模块的in_features必须等于元素计数。例如,输入嵌入是 [randn(3, 2, 3), randn(3, 4, 5)],in_features 应该是:2*3+4*5。

基础:torch.nn.modules.module.Module

这是DeepFM module

本模块不涵盖已发表论文的end-end 函数。相反,它仅涵盖出版物的深层部分。它用于学习high-order 特征交互。如果要学习低阶特征交互,请改用FactorizationMachine模块,它将共享该模块的相同嵌入输入。

为了支持建模灵活性,我们将关键组件定制为:

  • 与公开论文不同,我们将输入从原始稀疏

    特征到特征的嵌入。它允许嵌入维度和嵌入数量的灵活性,只要所有嵌入张量具有相同的批量大小。

  • 在公开论文之上,我们允许用户自定义隐藏层

    可以是任何模块,不仅限于 MLP。

模块的一般架构如下:

# 1 x 1 output
# ^
# pass into `dense_module`
# ^
# 1 x 90
# ^
# concat
# ^
# 1 x 20, 1 x 30, 1 x 40 list of embeddings

例子:

import torch
from torchrec.fb.modules.deepfm import DeepFM
from torchrec.fb.modules.mlp import LazyMLP
batch_size = 3
output_dim = 30
# the input embedding are a torch.Tensor of [batch_size, num_embeddings, embedding_dim]
input_embeddings = [
    torch.randn(batch_size, 2, 64),
    torch.randn(batch_size, 2, 32),
]
dense_module = nn.Linear(192, output_dim)
deepfm = DeepFM(dense_module=dense_module)
deep_fm_output = deepfm(embeddings=input_embeddings)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.modules.deepfm.DeepFM。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。