当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch EmbeddingBagCollection用法及代码示例


本文简要介绍python语言中 torchrec.modules.embedding_modules.EmbeddingBagCollection 的用法。

用法:

class torchrec.modules.embedding_modules.EmbeddingBagCollection(tables: List[torchrec.modules.embedding_configs.EmbeddingBagConfig], is_weighted: bool = False, device: Optional[torch.device] = None)

参数

  • tables(List[EmbeddingBagConfig]) -嵌入表列表。

  • is_weighted(bool) -输入KeyedJaggedTensor 是否加权。

  • device(可选的[torch.device]) -默认计算设备。

基础:torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface

EmbeddingBagCollection 表示池化嵌入的集合 (EmbeddingBags)。

它以 KeyedJaggedTensor 的形式处理稀疏数据,其值的形式为 [F X B X L],其中:

  • F:特征(键)

  • B:批量大小

  • L:稀疏特征的长度(锯齿状)

并输出一个 KeyedTensor,其值的格式为 [B * (F * D)] 其中:

  • F:特征(键)

  • D:每个特征(键)的嵌入维度

  • B:批量大小

例子:

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)

ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

pooled_embeddings = ebc(features)
print(pooled_embeddings.values())
tensor([[-0.6149,  0.0000, -0.3176],
[-0.8876,  0.0000, -1.5606],
[ 1.6805,  0.0000,  0.6810],
[-1.4206, -1.0409,  0.2249],
[ 0.1823, -0.4697,  1.3823],
[-0.2767, -0.9965, -0.1797],
[ 0.8864,  0.1315, -2.0724]], grad_fn=<TransposeBackward0>)
print(pooled_embeddings.keys())
['f1', 'f2']
print(pooled_embeddings.offset_per_key())
tensor([0, 3, 7])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.modules.embedding_modules.EmbeddingBagCollection。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。