当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python PyTorch KeyedTensor用法及代码示例

本文简要介绍python语言中 torchrec.sparse.jagged_tensor.KeyedTensor 的用法。

用法:

class torchrec.sparse.jagged_tensor.KeyedTensor(*args, **kwargs)

参数

  • keys(List[str]) -键列表

  • length_per_key(List[int]) -每个键沿键维度的长度

  • values(torch.Tensor) -密集张量,通常沿关键维度连接

  • key_dim(int) -关键维度,零索引 - 默认为 1(通常 B 是 0 维度)

基础:torchrec.streamable.Pipelineable

KeyedTensor 保存一个密集张量的串联列表,每个张量都可以通过键访问。键控尺寸可以是可变长度 (length_per_key)。常见用例包括存储不同维度的池化嵌入。

实现是torch.jit.script-able

例子:

# kt is KeyedTensor holding

#                         0           1           2
#     "Embedding A"    [1,1]       [1,1]        [1,1]
#     "Embedding B"    [2,1,2]     [2,1,2]      [2,1,2]
#     "Embedding C"    [3,1,2,3]   [3,1,2,3]    [3,1,2,3]
# tensor_list = [
#         torch.tensor([[1,1]] * 3),
#         torch.tensor([[2,1,2]] * 3),
#         torch.tensor([[3,1,2,3]] * 3),
#     ]
keys = ["Embedding A", "Embedding B", "Embedding C"]
kt = KeyedTensor.from_tensor_list(keys, tensor_list)
kt.values()
    tensor([[1, 1, 2, 1, 2, 3, 1, 2, 3],
    [1, 1, 2, 1, 2, 3, 1, 2, 3],
    [1, 1, 2, 1, 2, 3, 1, 2, 3]])
kt["Embedding B"]
    tensor([[2, 1, 2],
    [2, 1, 2],
    [2, 1, 2]])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.sparse.jagged_tensor.KeyedTensor。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。