当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch EmbeddingCollection用法及代码示例


本文简要介绍python语言中 torchrec.modules.embedding_modules.EmbeddingCollection 的用法。

用法:

class torchrec.modules.embedding_modules.EmbeddingCollection(tables: List[torchrec.modules.embedding_configs.EmbeddingConfig], device: Optional[torch.device] = None)

参数

  • tables(List[EmbeddingBagConfig]) -嵌入表列表。

  • device(可选的[torch.device]) -默认计算设备。

基础:torch.nn.modules.module.Module

EmbeddingCollection 表示非池化嵌入的集合。

它以 [F X B X L] 形式的 KeyedJaggedTensor 形式处理稀疏数据,其中:

  • F:特征(键)

  • B:批量大小

  • L:稀疏特征的长度(可变)

并输出 Dict[feature (key), JaggedTensor] 。每个 JaggedTensor 包含 (B * L) X D 形式的值,其中:

  • B:批量大小

  • L:稀疏特征的长度(锯齿状)

  • D:每个特征(键)的嵌入维度和长度的形式为 L

例子:

e1_config = EmbeddingConfig(
    name="t1", embedding_dim=2, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)
ec_config = EmbeddingCollectionConfig(tables=[e1_config, e2_config])

ec = EmbeddingCollection(config=ec_config)

#     0       1        2  <-- batch
# 0   [0,1] None    [2]
# 1   [3]    [4]    [5,6,7]
# ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
feature_embeddings = ec(features)
print(feature_embeddings['f2'].values())
tensor([[-0.2050,  0.5478,  0.6054],
[ 0.7352,  0.3210, -3.0399],
[ 0.1279, -0.1756, -0.4130],
[ 0.7519, -0.4341, -0.0499],
[ 0.9329, -1.0697, -0.8095]], grad_fn=<EmbeddingBackward>)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.modules.embedding_modules.EmbeddingCollection。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。