當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch EmbeddingBagCollection用法及代碼示例


本文簡要介紹python語言中 torchrec.modules.embedding_modules.EmbeddingBagCollection 的用法。

用法:

class torchrec.modules.embedding_modules.EmbeddingBagCollection(tables: List[torchrec.modules.embedding_configs.EmbeddingBagConfig], is_weighted: bool = False, device: Optional[torch.device] = None)

參數

  • tables(List[EmbeddingBagConfig]) -嵌入表列表。

  • is_weighted(bool) -輸入KeyedJaggedTensor 是否加權。

  • device(可選的[torch.device]) -默認計算設備。

基礎:torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface

EmbeddingBagCollection 表示池化嵌入的集合 (EmbeddingBags)。

它以 KeyedJaggedTensor 的形式處理稀疏數據,其值的形式為 [F X B X L],其中:

  • F:特征(鍵)

  • B:批量大小

  • L:稀疏特征的長度(鋸齒狀)

並輸出一個 KeyedTensor,其值的格式為 [B * (F * D)] 其中:

  • F:特征(鍵)

  • D:每個特征(鍵)的嵌入維度

  • B:批量大小

例子:

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)

ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

pooled_embeddings = ebc(features)
print(pooled_embeddings.values())
tensor([[-0.6149,  0.0000, -0.3176],
[-0.8876,  0.0000, -1.5606],
[ 1.6805,  0.0000,  0.6810],
[-1.4206, -1.0409,  0.2249],
[ 0.1823, -0.4697,  1.3823],
[-0.2767, -0.9965, -0.1797],
[ 0.8864,  0.1315, -2.0724]], grad_fn=<TransposeBackward0>)
print(pooled_embeddings.keys())
['f1', 'f2']
print(pooled_embeddings.offset_per_key())
tensor([0, 3, 7])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchrec.modules.embedding_modules.EmbeddingBagCollection。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。