当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch FMInteractionArch用法及代码示例


本文简要介绍python语言中 torchrec.models.deepfm.FMInteractionArch 的用法。

用法:

class torchrec.models.deepfm.FMInteractionArch(fm_in_features: int, sparse_feature_names: List[str], deep_fm_dimension: int)

参数

  • fm_in_features(int) -DeepFM 中 dense_module 的输入维度。例如,如果输入嵌入是 [randn(3, 2, 3), randn(3, 4, 5)],则 fm_in_features 应为:2 * 3 + 4 * 5。

  • sparse_feature_names(List[str]) -F的长度。

  • deep_fm_dimension(int) -DeepFM 拱门中深度交互 (DI) 的输出。

基础:torch.nn.modules.module.Module

处理SparseArch (sparse_features) 和DenseArch (dense_features) 的输出,并根据DeepFM 论文的外部来源应用通用DeepFM 交互:https://arxiv.org/pdf/1703.04247.pdf

输出维度预计为 dense_features 的 cat ,D。

例子:

D = 3
B = 10
keys = ["f1", "f2"]
F = len(keys)
fm_inter_arch = FMInteractionArch(sparse_feature_names=keys)
dense_features = torch.rand((B, D))
sparse_features = KeyedTensor(
    keys=keys,
    length_per_key=[D, D],
    values=torch.rand((B, D * F)),
)
cat_fm_output = fm_inter_arch(dense_features, sparse_features)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.models.deepfm.FMInteractionArch。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。