本文簡要介紹python語言中 torchrec.models.dlrm.DLRM
的用法。
用法:
class torchrec.models.dlrm.DLRM(embedding_bag_collection: torchrec.modules.embedding_modules.EmbeddingBagCollection, dense_in_features: int, dense_arch_layer_sizes: List[int], over_arch_layer_sizes: List[int], dense_device: Optional[torch.device] = None)
embedding_bag_collection(torchrec.modules.embedding_modules.EmbeddingBagCollection) -用於定義
SparseArch
的嵌入包集合。dense_in_features(int) -密集輸入特征的維度。
dense_arch_layer_sizes(List[int]) -
DenseArch
的層大小。over_arch_layer_sizes(List[int]) -
OverArch
的層大小。InteractionArch
的輸出維度不應在此處手動指定。dense_device(可選的[torch.device]) -默認計算設備。
基礎:
torch.nn.modules.module.Module
Recsys 模型來自“個性化和推薦係統的深度學習推薦模型”(https://arxiv.org/abs/1906.00091)。通過學習每個特征的池化嵌入來處理稀疏特征。通過將密集特征投影到相同的嵌入空間來學習密集特征和稀疏特征之間的關係。此外,還學習稀疏特征之間的成對關係。
該模塊假設所有稀疏特征具有相同的嵌入維度(即每個EmbeddingBagConfig使用相同的embedding_dim)。
在整個模型的文檔中使用以下符號:
F:稀疏特征的數量
D:embedding_dimension 稀疏特征
B:批量大小
num_features:密集特征的數量
例子:
B = 2 D = 8 eb1_config = EmbeddingBagConfig( name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] ) eb2_config = EmbeddingBagConfig( name="t2", embedding_dim=D, num_embeddings=100, feature_names=["f2"], ) ebc_config = EmbeddingBagCollectionConfig(tables=[eb1_config, eb2_config]) ebc = EmbeddingBagCollection(config=ebc_config) model = DLRM( embedding_bag_collection=ebc, dense_in_features=100, dense_arch_layer_sizes=[20], over_arch_layer_sizes=[5, 1], ) features = torch.rand((B, 100)) # 0 1 # 0 [1,2] [4,5] # 1 [4,3] [2,9] # ^ # feature sparse_features = KeyedJaggedTensor.from_offsets_sync( keys=["f1", "f3"], values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), offsets=torch.tensor([0, 2, 4, 6, 8]), ) logits = model( dense_features=features, sparse_features=sparse_features, )
參數:
相關用法
- Python PyTorch DeQuantize用法及代碼示例
- Python PyTorch DistributedModelParallel用法及代碼示例
- Python PyTorch DistributedDataParallel用法及代碼示例
- Python PyTorch DenseArch用法及代碼示例
- Python PyTorch DeepFM用法及代碼示例
- Python PyTorch DistributedDataParallel.register_comm_hook用法及代碼示例
- Python PyTorch DataFrameMaker用法及代碼示例
- Python PyTorch DistributedSampler用法及代碼示例
- Python PyTorch DistributedDataParallel.join用法及代碼示例
- Python PyTorch Dropout用法及代碼示例
- Python PyTorch DistributedModelParallel.named_parameters用法及代碼示例
- Python PyTorch Dropout3d用法及代碼示例
- Python PyTorch DataParallel用法及代碼示例
- Python PyTorch DistributedModelParallel.state_dict用法及代碼示例
- Python PyTorch DistributedDataParallel.no_sync用法及代碼示例
- Python PyTorch Decompressor用法及代碼示例
- Python PyTorch Dropout2d用法及代碼示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代碼示例
- Python PyTorch DeepFM.forward用法及代碼示例
- Python PyTorch Dirichlet用法及代碼示例
- Python PyTorch Demultiplexer用法及代碼示例
- Python PyTorch DistributedOptimizer用法及代碼示例
- Python PyTorch DatasetFolder.find_classes用法及代碼示例
- Python PyTorch frexp用法及代碼示例
- Python PyTorch jvp用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchrec.models.dlrm.DLRM。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。