當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch where用法及代碼示例

本文簡要介紹python語言中 torch.where 的用法。

用法:

torch.where(condition, x, y) → Tensor

參數

  • condition(BoolTensor) -當 True(非零)時,產生 x,否則產生 y

  • x(Tensor或者標量) -值(如果 x 是標量)或在 conditionTrue 的索引處選擇的值

  • y(Tensor或者標量) -值(如果 y 是標量)或在 conditionFalse 的索引處選擇的值

返回

形狀等於 conditionxy 的廣播形狀的張量

返回類型

Tensor

返回從 xy 中選擇的元素的張量,具體取決於 condition

操作定義為:

注意

張量 conditionxy 必須是可廣播的。

注意

當前有效的標量和張量組合是 1. 浮點數 dtype 和 torch.double 的標量 2. 積分 dtype 和 torch.long 的標量 3. 複雜 dtype 和 torch.complex128 的標量

例子:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)
torch.where(condition) → tuple of LongTensor

torch.where(condition)torch.nonzero(condition, as_tuple=True) 相同。

注意

另見 torch.nonzero()

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.where。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。