當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch Tensor.scatter_add_用法及代碼示例


本文簡要介紹python語言中 torch.Tensor.scatter_add_ 的用法。

用法:

Tensor.scatter_add_(dim, index, src) → Tensor

參數

  • dim(int) -索引的軸

  • index(LongTensor) -要分散和添加的元素的索引可以為空或與 src 具有相同的維度。當為空時,該操作返回 self 不變。

  • src(Tensor) -要分散和添加的源元素

以與 scatter_() 類似的方式將來自張量 other 的所有值添加到 selfindex 張量中指定的索引處。對於 src 中的每個值,將其添加到 self 中的索引中,該索引由 src 中的索引指定 dimension != dimindex 中的相應值 dimension = dim

對於 3-D 張量,self 更新為:

self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

selfindexsrc 應該具有相同的維數。還需要 index.size(d) <= src.size(d) 用於所有維度 d ,並且 index.size(d) <= self.size(d) 用於所有維度 d != dim 。請注意,indexsrc 不廣播。

注意

當給定 CUDA 設備上的張量時,此操作可能會表現得不確定。有關詳細信息,請參閱重現性。

注意

向後傳遞僅針對 src.shape == index.shape 實施。

例子:

>>> src = torch.ones((2, 5))
>>> index = torch.tensor([[0, 1, 2, 0, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[1., 0., 0., 1., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.]])
>>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[2., 0., 0., 1., 1.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 1., 1.]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.Tensor.scatter_add_。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。