當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch KeyedJaggedTensor用法及代碼示例

本文簡要介紹python語言中 torchrec.sparse.jagged_tensor.KeyedJaggedTensor 的用法。

用法:

class torchrec.sparse.jagged_tensor.KeyedJaggedTensor(*args, **kwargs)

參數

  • keys(List[str]) -鋸齒狀張量的關鍵。

  • values(torch.Tensor) -密集表示中的值張量。

  • weights(可選的[torch.Tensor]) -如果值有權重。與值具有相同形狀的張量。

  • lengths(可選的[torch.Tensor]) -鋸齒狀切片,以長度表示。

  • offsets(可選的[torch.Tensor]) -鋸齒狀切片,表示為累積偏移量。

  • stride(可選的[int]) -每批次的示例數。

  • length_per_key(可選的[List[int]]) -每個鍵的起始長度。

  • offset_per_key(可選的[List[int]]) -每個鍵的起始偏移量和最終偏移量。

  • index_per_key(可選的[字典[str,int]]) - 每個鍵的索引。

  • jt_dict(可選的[字典[str,JaggedTensor]]) -

基礎:torchrec.streamable.Pipelineable

表示(可選加權)鍵控鋸齒狀張量。

A JaggedTensor是一個張量鋸齒狀尺寸這是其切片可能具有不同長度的維度。在第一個維度上鍵控,在最後一個維度上呈鋸齒狀。

例如:

#              0       1        2  <-- dim_1
# "Feature0"   [V0,V1] None    [V2]
# "Feature1"   [V3]    [V4]    [V5,V6,V7]
#   ^
#  dim_0

dim_0: keyed dimension (ie. `Feature0`, `Feature1`)
dim_1: optional second dimension (ie. batch size)
dim_2: The jagged dimension which has slice lengths between 0-3 in the above example

We represent this data with following inputs:

values: torch.Tensor = [V0, V1, V2, V3, V4, V5, V6, V7], V == any tensor datatype
weights: torch.Tensor = [W0, W1, W2, W3, W4, W5, W6, W7], W == any tensor datatype
lengths: torch.Tensor = [2, 0, 1, 1, 1, 3], representing the jagged slice
offsets: torch.Tensor = [0, 2, 2, 3, 4, 5, 8], offsets from 0 for each jagged slice
keys: List[int] = ["Feature0", "Feature1"], which corresponds to each value of dim_0
index_per_key: Dict[str, int] = {"Feature0": 0, "Feature1": 1}, index for each key
offset_per_key: List[int] = [0, 3, 8], start offset for each key and final offset

實現是torch.jit.script-able

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchrec.sparse.jagged_tensor.KeyedJaggedTensor。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。