当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch gather用法及代码示例


本文简要介绍python语言中 torch.gather 的用法。

用法:

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

参数

  • input(Tensor) -源张量

  • dim(int) -索引的轴

  • index(LongTensor) -要收集的元素的索引

关键字参数

  • sparse_grad(bool,可选的) -如果 True ,梯度 w.r.t. input 将是一个稀疏张量。

  • out(Tensor,可选的) -目标张量

沿 dim 指定的轴收集值。

对于 3-D 张量,输出由下式指定:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

inputindex 必须具有相同的维数。对于所有维度 d != dim ,还需要 index.size(d) <= input.size(d)out 将具有与 index 相同的形状。请注意,inputindex 不会相互广播。

例子:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.gather。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。