當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch nonzero用法及代碼示例


本文簡要介紹python語言中 torch.nonzero 的用法。

用法:

torch.nonzero(input, *, out=None, as_tuple=False) → LongTensor or tuple of LongTensors

參數

input(Tensor) -輸入張量。

關鍵字參數

out(LongTensor,可選的) -包含索引的輸出張量

返回

如果 as_tupleFalse ,則輸出張量包含索引。如果 as_tupleTrue ,則每個維度都有一個一維張量,包含沿該維度的每個非零元素的索引。

返回類型

LongTensor 或 LongTensor 的元組

注意

torch.nonzero(..., as_tuple=False)(默認)返回一個二維張量,其中每行都是非零值的索引。

torch.nonzero(..., as_tuple=True) 返回一維索引張量的元組,允許高級索引,因此 x[x.nonzero(as_tuple=True)] 給出張量 x 的所有非零值。在返回的元組中,每個索引張量都包含特定維度的非零索引。

有關這兩種行為的更多詳細信息,請參見下文。

input 在 CUDA 上時,torch.nonzero() 會導致 host-device 同步。

什麽時候 as_tuple False (默認)

返回一個張量,其中包含 input 的所有非零元素的索引。結果中的每一行都包含 input 中非零元素的索引。結果按字典順序排序,最後一個索引變化最快(C 風格)。

如果 input 具有 維度,則生成的索引張量 out 的大小為 ,其中 input 張量中非零元素的總數。

什麽時候 as_tuple True

返回一維張量的元組,一個用於 input 中的每個維度,每個包含 input 的所有非零元素的索引(在該維度中)。

如果 input 具有 維度,則生成的元組包含大小為 張量,其中 input 張量中非零元素的總數。

作為一種特殊情況,當input 具有零維和非零標量值時,它被視為具有一個元素的一維張量。

例子:

>>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]))
tensor([[ 0],
        [ 1],
        [ 2],
        [ 4]])
>>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
...                             [0.0, 0.4, 0.0, 0.0],
...                             [0.0, 0.0, 1.2, 0.0],
...                             [0.0, 0.0, 0.0,-0.4]]))
tensor([[ 0,  0],
        [ 1,  1],
        [ 2,  2],
        [ 3,  3]])
>>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True)
(tensor([0, 1, 2, 4]),)
>>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
...                             [0.0, 0.4, 0.0, 0.0],
...                             [0.0, 0.0, 1.2, 0.0],
...                             [0.0, 0.0, 0.0,-0.4]]), as_tuple=True)
(tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))
>>> torch.nonzero(torch.tensor(5), as_tuple=True)
(tensor([0]),)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nonzero。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。