当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python PyTorch where用法及代码示例

本文简要介绍python语言中 torch.where 的用法。

用法:

torch.where(condition, x, y) → Tensor

参数

  • condition(BoolTensor) -当 True(非零)时,产生 x,否则产生 y

  • x(Tensor或者标量) -值(如果 x 是标量)或在 conditionTrue 的索引处选择的值

  • y(Tensor或者标量) -值(如果 y 是标量)或在 conditionFalse 的索引处选择的值

返回

形状等于 conditionxy 的广播形状的张量

返回类型

Tensor

返回从 xy 中选择的元素的张量,具体取决于 condition

操作定义为:

注意

张量 conditionxy 必须是可广播的。

注意

当前有效的标量和张量组合是 1. 浮点数 dtype 和 torch.double 的标量 2. 积分 dtype 和 torch.long 的标量 3. 复杂 dtype 和 torch.complex128 的标量

例子:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)
torch.where(condition) → tuple of LongTensor

torch.where(condition)torch.nonzero(condition, as_tuple=True) 相同。

注意

另见 torch.nonzero()

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.where。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。