當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch index_select用法及代碼示例

本文簡要介紹python語言中 torch.index_select 的用法。

用法:

torch.index_select(input, dim, index, *, out=None) → Tensor

參數

  • input(Tensor) -輸入張量。

  • dim(int) -我們索引的維度

  • index(IntTensor或者LongTensor) -包含要索引的索引的一維張量

關鍵字參數

out(Tensor,可選的) -輸出張量。

返回一個新張量,該張量使用 index 中的條目沿維度 dim 索引 input 張量,該條目是 LongTensor

返回的張量與原始張量 (input) 具有相同的維數。第 dim 維度的大小與 index 的長度相同;其他維度的大小與原始張量中的大小相同。

注意

返回的張量不是使用與原始張量相同的存儲。如果out如果形狀與預期不同,我們會默默地將其更改為正確的形狀,並在必要時重新分配底層存儲。

例子:

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.index_select。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。