当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python PyTorch EmbeddingBagCollection用法及代码示例

本文简要介绍python语言中 torchrec.quant.embedding_modules.EmbeddingBagCollection 的用法。

用法:

class torchrec.quant.embedding_modules.EmbeddingBagCollection(table_name_to_quantized_weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]], embedding_configs: List[torchrec.modules.embedding_configs.EmbeddingBagConfig], is_weighted: bool, device: torch.device)

返回

KeyedTensor

基础:torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface

EmbeddingBagCollection 表示池化嵌入 (EmbeddingBags) 的集合。此 EmbeddingBagCollection 经过量化以降低精度。它依赖于 fbgemm 量化运算

它以 KeyedJaggedTensor 的形式处理稀疏数据,其值的形式为 [F X B X L] F:特征(键) B:批量大小 L:稀疏特征的长度(锯齿状)

并输出 KeyedTensor,其值的形式为 [B * (F * D)],其中 F:特征(键) D:每个特征(键)的嵌入维度 B:批量大小

构造函数参数:

table_name_to_quantized_weights (Dict[str, Tuple[Tensor, Tensor]]): 表到量化权重的映射 embedding_configs (List[EmbeddingBagConfig]): 嵌入表列表 is_weighted: (bool): 是否输入KeyedJaggedTensor是加权设备:(可选[torch.device]):默认计算设备

调用参数:

特征:KeyedJaggedTensor,

例子:

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

ebc.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(
        dtype=torch.qint8
    ),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)

qebc = QuantEmbeddingBagCollection.from_float(ebc)
quantized_embeddings = qebc(features)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.quant.embedding_modules.EmbeddingBagCollection。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。