本文简要介绍python语言中 torchrec.quant.embedding_modules.EmbeddingBagCollection
的用法。
用法:
class torchrec.quant.embedding_modules.EmbeddingBagCollection(table_name_to_quantized_weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]], embedding_configs: List[torchrec.modules.embedding_configs.EmbeddingBagConfig], is_weighted: bool, device: torch.device)
KeyedTensor
基础:
torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface
EmbeddingBagCollection 表示池化嵌入 (EmbeddingBags) 的集合。此 EmbeddingBagCollection 经过量化以降低精度。它依赖于 fbgemm 量化运算
它以 KeyedJaggedTensor 的形式处理稀疏数据,其值的形式为 [F X B X L] F:特征(键) B:批量大小 L:稀疏特征的长度(锯齿状)
并输出 KeyedTensor,其值的形式为 [B * (F * D)],其中 F:特征(键) D:每个特征(键)的嵌入维度 B:批量大小
- 构造函数参数:
table_name_to_quantized_weights (Dict[str, Tuple[Tensor, Tensor]]): 表到量化权重的映射 embedding_configs (List[EmbeddingBagConfig]): 嵌入表列表 is_weighted: (bool): 是否输入KeyedJaggedTensor是加权设备:(可选[torch.device]):默认计算设备
- 调用参数:
特征:KeyedJaggedTensor,
例子:
table_0 = EmbeddingBagConfig( name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] ) table_1 = EmbeddingBagConfig( name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] ) ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) # 0 1 2 <-- batch # "f1" [0,1] None [2] # "f2" [3] [4] [5,6,7] # ^ # feature features = KeyedJaggedTensor( keys=["f1", "f2"], values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), ) ebc.qconfig = torch.quantization.QConfig( activation=torch.quantization.PlaceholderObserver.with_args( dtype=torch.qint8 ), weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), ) qebc = QuantEmbeddingBagCollection.from_float(ebc) quantized_embeddings = qebc(features)
返回:
相关用法
- Python PyTorch EmbeddingBagCollection.state_dict用法及代码示例
- Python PyTorch EmbeddingBagCollection用法及代码示例
- Python PyTorch EmbeddingBagCollection.named_buffers用法及代码示例
- Python PyTorch EmbeddingBag用法及代码示例
- Python PyTorch EmbeddingBag.from_pretrained用法及代码示例
- Python PyTorch Embedding用法及代码示例
- Python PyTorch EmbeddingCollection用法及代码示例
- Python PyTorch Embedding.from_pretrained用法及代码示例
- Python PyTorch ELU用法及代码示例
- Python PyTorch EndOnDiskCacheHolder用法及代码示例
- Python PyTorch Enumerator用法及代码示例
- Python PyTorch ElasticAgent用法及代码示例
- Python PyTorch EtcdServer用法及代码示例
- Python PyTorch EtcdRendezvousHandler用法及代码示例
- Python PyTorch Exponential用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
- Python PyTorch cholesky用法及代码示例
- Python PyTorch vdot用法及代码示例
- Python PyTorch ScaledDotProduct.__init__用法及代码示例
- Python PyTorch gumbel_softmax用法及代码示例
- Python PyTorch get_tokenizer用法及代码示例
- Python PyTorch saved_tensors_hooks用法及代码示例
- Python PyTorch positive用法及代码示例
- Python PyTorch renorm用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.quant.embedding_modules.EmbeddingBagCollection。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。