当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch nonzero用法及代码示例


本文简要介绍python语言中 torch.nonzero 的用法。

用法:

torch.nonzero(input, *, out=None, as_tuple=False) → LongTensor or tuple of LongTensors

参数

input(Tensor) -输入张量。

关键字参数

out(LongTensor,可选的) -包含索引的输出张量

返回

如果 as_tupleFalse ,则输出张量包含索引。如果 as_tupleTrue ,则每个维度都有一个一维张量,包含沿该维度的每个非零元素的索引。

返回类型

LongTensor 或 LongTensor 的元组

注意

torch.nonzero(..., as_tuple=False)(默认)返回一个二维张量,其中每行都是非零值的索引。

torch.nonzero(..., as_tuple=True) 返回一维索引张量的元组,允许高级索引,因此 x[x.nonzero(as_tuple=True)] 给出张量 x 的所有非零值。在返回的元组中,每个索引张量都包含特定维度的非零索引。

有关这两种行为的更多详细信息,请参见下文。

input 在 CUDA 上时,torch.nonzero() 会导致 host-device 同步。

什么时候 as_tuple False (默认)

返回一个张量,其中包含 input 的所有非零元素的索引。结果中的每一行都包含 input 中非零元素的索引。结果按字典顺序排序,最后一个索引变化最快(C 风格)。

如果 input 具有 维度,则生成的索引张量 out 的大小为 ,其中 input 张量中非零元素的总数。

什么时候 as_tuple True

返回一维张量的元组,一个用于 input 中的每个维度,每个包含 input 的所有非零元素的索引(在该维度中)。

如果 input 具有 维度,则生成的元组包含大小为 张量,其中 input 张量中非零元素的总数。

作为一种特殊情况,当input 具有零维和非零标量值时,它被视为具有一个元素的一维张量。

例子:

>>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]))
tensor([[ 0],
        [ 1],
        [ 2],
        [ 4]])
>>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
...                             [0.0, 0.4, 0.0, 0.0],
...                             [0.0, 0.0, 1.2, 0.0],
...                             [0.0, 0.0, 0.0,-0.4]]))
tensor([[ 0,  0],
        [ 1,  1],
        [ 2,  2],
        [ 3,  3]])
>>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True)
(tensor([0, 1, 2, 4]),)
>>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
...                             [0.0, 0.4, 0.0, 0.0],
...                             [0.0, 0.0, 1.2, 0.0],
...                             [0.0, 0.0, 0.0,-0.4]]), as_tuple=True)
(tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))
>>> torch.nonzero(torch.tensor(5), as_tuple=True)
(tensor([0]),)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nonzero。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。