当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python PyTorch KeyedJaggedTensor用法及代码示例

本文简要介绍python语言中 torchrec.sparse.jagged_tensor.KeyedJaggedTensor 的用法。

用法:

class torchrec.sparse.jagged_tensor.KeyedJaggedTensor(*args, **kwargs)

参数

  • keys(List[str]) -锯齿状张量的关键。

  • values(torch.Tensor) -密集表示中的值张量。

  • weights(可选的[torch.Tensor]) -如果值有权重。与值具有相同形状的张量。

  • lengths(可选的[torch.Tensor]) -锯齿状切片,以长度表示。

  • offsets(可选的[torch.Tensor]) -锯齿状切片,表示为累积偏移量。

  • stride(可选的[int]) -每批次的示例数。

  • length_per_key(可选的[List[int]]) -每个键的起始长度。

  • offset_per_key(可选的[List[int]]) -每个键的起始偏移量和最终偏移量。

  • index_per_key(可选的[字典[str,int]]) - 每个键的索引。

  • jt_dict(可选的[字典[str,JaggedTensor]]) -

基础:torchrec.streamable.Pipelineable

表示(可选加权)键控锯齿状张量。

A JaggedTensor是一个张量锯齿状尺寸这是其切片可能具有不同长度的维度。在第一个维度上键控,在最后一个维度上呈锯齿状。

例如:

#              0       1        2  <-- dim_1
# "Feature0"   [V0,V1] None    [V2]
# "Feature1"   [V3]    [V4]    [V5,V6,V7]
#   ^
#  dim_0

dim_0: keyed dimension (ie. `Feature0`, `Feature1`)
dim_1: optional second dimension (ie. batch size)
dim_2: The jagged dimension which has slice lengths between 0-3 in the above example

We represent this data with following inputs:

values: torch.Tensor = [V0, V1, V2, V3, V4, V5, V6, V7], V == any tensor datatype
weights: torch.Tensor = [W0, W1, W2, W3, W4, W5, W6, W7], W == any tensor datatype
lengths: torch.Tensor = [2, 0, 1, 1, 1, 3], representing the jagged slice
offsets: torch.Tensor = [0, 2, 2, 3, 4, 5, 8], offsets from 0 for each jagged slice
keys: List[int] = ["Feature0", "Feature1"], which corresponds to each value of dim_0
index_per_key: Dict[str, int] = {"Feature0": 0, "Feature1": 1}, index for each key
offset_per_key: List[int] = [0, 3, 8], start offset for each key and final offset

实现是torch.jit.script-able

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.sparse.jagged_tensor.KeyedJaggedTensor。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。