當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch KeyedTensor用法及代碼示例

本文簡要介紹python語言中 torchrec.sparse.jagged_tensor.KeyedTensor 的用法。

用法:

class torchrec.sparse.jagged_tensor.KeyedTensor(*args, **kwargs)

參數

  • keys(List[str]) -鍵列表

  • length_per_key(List[int]) -每個鍵沿鍵維度的長度

  • values(torch.Tensor) -密集張量,通常沿關鍵維度連接

  • key_dim(int) -關鍵維度,零索引 - 默認為 1(通常 B 是 0 維度)

基礎:torchrec.streamable.Pipelineable

KeyedTensor 保存一個密集張量的串聯列表,每個張量都可以通過鍵訪問。鍵控尺寸可以是可變長度 (length_per_key)。常見用例包括存儲不同維度的池化嵌入。

實現是torch.jit.script-able

例子:

# kt is KeyedTensor holding

#                         0           1           2
#     "Embedding A"    [1,1]       [1,1]        [1,1]
#     "Embedding B"    [2,1,2]     [2,1,2]      [2,1,2]
#     "Embedding C"    [3,1,2,3]   [3,1,2,3]    [3,1,2,3]
# tensor_list = [
#         torch.tensor([[1,1]] * 3),
#         torch.tensor([[2,1,2]] * 3),
#         torch.tensor([[3,1,2,3]] * 3),
#     ]
keys = ["Embedding A", "Embedding B", "Embedding C"]
kt = KeyedTensor.from_tensor_list(keys, tensor_list)
kt.values()
    tensor([[1, 1, 2, 1, 2, 3, 1, 2, 3],
    [1, 1, 2, 1, 2, 3, 1, 2, 3],
    [1, 1, 2, 1, 2, 3, 1, 2, 3]])
kt["Embedding B"]
    tensor([[2, 1, 2],
    [2, 1, 2],
    [2, 1, 2]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchrec.sparse.jagged_tensor.KeyedTensor。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。