當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch gather用法及代碼示例


本文簡要介紹python語言中 torch.gather 的用法。

用法:

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

參數

  • input(Tensor) -源張量

  • dim(int) -索引的軸

  • index(LongTensor) -要收集的元素的索引

關鍵字參數

  • sparse_grad(bool,可選的) -如果 True ,梯度 w.r.t. input 將是一個稀疏張量。

  • out(Tensor,可選的) -目標張量

沿 dim 指定的軸收集值。

對於 3-D 張量,輸出由下式指定:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

inputindex 必須具有相同的維數。對於所有維度 d != dim ,還需要 index.size(d) <= input.size(d)out 將具有與 index 相同的形狀。請注意,inputindex 不會相互廣播。

例子:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.gather。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。