當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch take_along_dim用法及代碼示例


本文簡要介紹python語言中 torch.take_along_dim 的用法。

用法:

torch.take_along_dim(input, indices, dim, *, out=None) → Tensor

參數

  • input(Tensor) -輸入張量。

  • indices(張量) -input 中的索引。必須有長 dtype。

  • dim(int) -要選擇的尺寸。

關鍵字參數

out(Tensor,可選的) -輸出張量。

沿著給定的 dimindices 的一維索引處從 input 中選擇值。

沿維度返回索引的函數,如 torch.argmax() torch.argsort() ,旨在與此函數一起使用。請參閱下麵的示例。

注意

這個函數類似於 NumPy 的 take_along_axis 。另見 torch.gather()

例子:

>>> t = torch.tensor([[10, 30, 20], [60, 40, 50]])
>>> max_idx = torch.argmax(t)
>>> torch.take_along_dim(t, max_idx)
tensor([60])
>>> sorted_idx = torch.argsort(t, dim=1)
>>> torch.take_along_dim(t, sorted_idx, dim=1)
tensor([[10, 20, 30],
        [40, 50, 60]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.take_along_dim。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。