當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch EmbeddingCollection用法及代碼示例


本文簡要介紹python語言中 torchrec.modules.embedding_modules.EmbeddingCollection 的用法。

用法:

class torchrec.modules.embedding_modules.EmbeddingCollection(tables: List[torchrec.modules.embedding_configs.EmbeddingConfig], device: Optional[torch.device] = None)

參數

  • tables(List[EmbeddingBagConfig]) -嵌入表列表。

  • device(可選的[torch.device]) -默認計算設備。

基礎:torch.nn.modules.module.Module

EmbeddingCollection 表示非池化嵌入的集合。

它以 [F X B X L] 形式的 KeyedJaggedTensor 形式處理稀疏數據,其中:

  • F:特征(鍵)

  • B:批量大小

  • L:稀疏特征的長度(可變)

並輸出 Dict[feature (key), JaggedTensor] 。每個 JaggedTensor 包含 (B * L) X D 形式的值,其中:

  • B:批量大小

  • L:稀疏特征的長度(鋸齒狀)

  • D:每個特征(鍵)的嵌入維度和長度的形式為 L

例子:

e1_config = EmbeddingConfig(
    name="t1", embedding_dim=2, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)
ec_config = EmbeddingCollectionConfig(tables=[e1_config, e2_config])

ec = EmbeddingCollection(config=ec_config)

#     0       1        2  <-- batch
# 0   [0,1] None    [2]
# 1   [3]    [4]    [5,6,7]
# ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
feature_embeddings = ec(features)
print(feature_embeddings['f2'].values())
tensor([[-0.2050,  0.5478,  0.6054],
[ 0.7352,  0.3210, -3.0399],
[ 0.1279, -0.1756, -0.4130],
[ 0.7519, -0.4341, -0.0499],
[ 0.9329, -1.0697, -0.8095]], grad_fn=<EmbeddingBackward>)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchrec.modules.embedding_modules.EmbeddingCollection。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。