當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch EmbeddingBagCollection用法及代碼示例

本文簡要介紹python語言中 torchrec.quant.embedding_modules.EmbeddingBagCollection 的用法。

用法:

class torchrec.quant.embedding_modules.EmbeddingBagCollection(table_name_to_quantized_weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]], embedding_configs: List[torchrec.modules.embedding_configs.EmbeddingBagConfig], is_weighted: bool, device: torch.device)

返回

KeyedTensor

基礎:torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface

EmbeddingBagCollection 表示池化嵌入 (EmbeddingBags) 的集合。此 EmbeddingBagCollection 經過量化以降低精度。它依賴於 fbgemm 量化運算

它以 KeyedJaggedTensor 的形式處理稀疏數據,其值的形式為 [F X B X L] F:特征(鍵) B:批量大小 L:稀疏特征的長度(鋸齒狀)

並輸出 KeyedTensor,其值的形式為 [B * (F * D)],其中 F:特征(鍵) D:每個特征(鍵)的嵌入維度 B:批量大小

構造函數參數:

table_name_to_quantized_weights (Dict[str, Tuple[Tensor, Tensor]]): 表到量化權重的映射 embedding_configs (List[EmbeddingBagConfig]): 嵌入表列表 is_weighted: (bool): 是否輸入KeyedJaggedTensor是加權設備:(可選[torch.device]):默認計算設備

調用參數:

特征:KeyedJaggedTensor,

例子:

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

ebc.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(
        dtype=torch.qint8
    ),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)

qebc = QuantEmbeddingBagCollection.from_float(ebc)
quantized_embeddings = qebc(features)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchrec.quant.embedding_modules.EmbeddingBagCollection。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。