當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch quantile用法及代碼示例

本文簡要介紹python語言中 torch.quantile 的用法。

用法:

torch.quantile(input, q, dim=None, keepdim=False, *, out=None) → Tensor

參數

  • input(Tensor) -輸入張量。

  • q(float或者Tensor) -[0, 1] 範圍內的值的標量或一維張量。

  • dim(int) -要減少的維度。

  • keepdim(bool) -輸出張量是否保留了dim

關鍵字參數

out(Tensor,可選的) -輸出張量。

沿維度 dim 計算 input 張量的每一行的 q-th 分位數。

為了計算分位數,我們將 [0, 1] 中的 q 映射到索引 [0, n] 的範圍內,以找到分位數在排序輸入中的位置。如果分位數位於兩個數據點 a < b 之間,索引 ij 按排序順序,則使用線性插值計算結果,如下所示:

a + (b - a) * fraction ,其中 fraction 是計算的分位數索引的小數部分。

如果 q 是一維張量,則輸出的第一個維度表示分位數並且大小等於 q 的大小,其餘維度是歸約後剩下的維度。

注意

默認情況下 dimNone 導致 input 張量在計算之前被展平。

例子:

>>> a = torch.randn(2, 3)
>>> a
tensor([[ 0.0795, -1.2117,  0.9765],
        [ 1.1707,  0.6706,  0.4884]])
>>> q = torch.tensor([0.25, 0.5, 0.75])
>>> torch.quantile(a, q, dim=1, keepdim=True)
tensor([[[-0.5661],
        [ 0.5795]],

        [[ 0.0795],
        [ 0.6706]],

        [[ 0.5280],
        [ 0.9206]]])
>>> torch.quantile(a, q, dim=1, keepdim=True).shape
torch.Size([3, 2, 1])
>>> a = torch.arange(4.)
>>> a
tensor([0., 1., 2., 3.])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.quantile。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。