當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python numpy take_along_axis用法及代碼示例


本文簡要介紹 python 語言中 numpy.take_along_axis 的用法。

用法:

numpy.take_along_axis(arr, indices, axis)

通過匹配一維索引和數據切片從輸入數組中獲取值。

這會遍曆索引和數據數組中沿指定軸定向的匹配 1d 切片,並使用前者在後者中查找值。這些切片可以是不同的長度。

沿軸返回索引的函數(如 argsort argpartition )會為此函數生成合適的索引。

參數

arr ndarray (Ni…, M, Nk…)

源數組

indices ndarray (Ni…, J, Nk…)

arr 的每個 1d 切片的索引。這必須匹配 arr 的維度,但維度 Ni 和 Nj 隻需要針對 arr 進行廣播。

axis int

沿 1d 切片的軸。如果axis為None,則輸入數組被視為首先被展平為1d,以與 sort argsort 保持一致。

返回

輸出:ndarray(Ni…,J,Nk…)

索引的結果。

注意

這相當於(但比)以下使用 ndindex s_ ,將 iikk 設置為索引元組:

Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:]
J = indices.shape[axis]  # Need not equal M
out = np.empty(Ni + (J,) + Nk)

for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        a_1d       = a      [ii + s_[:,] + kk]
        indices_1d = indices[ii + s_[:,] + kk]
        out_1d     = out    [ii + s_[:,] + kk]
        for j in range(J):
            out_1d[j] = a_1d[indices_1d[j]]

等效地,消除內部循環,最後兩行將是:

out_1d[:] = a_1d[indices_1d]

例子

對於這個示例數組

>>> a = np.array([[10, 30, 20], [60, 40, 50]])

我們可以使用 sort 直接排序,也可以使用 argsort 和這個函數

>>> np.sort(a, axis=1)
array([[10, 20, 30],
       [40, 50, 60]])
>>> ai = np.argsort(a, axis=1); ai
array([[0, 2, 1],
       [1, 2, 0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 20, 30],
       [40, 50, 60]])

如果擴展尺寸,max 和 min 也是如此:

>>> np.expand_dims(np.max(a, axis=1), axis=1)
array([[30],
       [60]])
>>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1)
>>> ai
array([[1],
       [0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[30],
       [60]])

如果我們想同時得到最大值和最小值,我們可以先堆疊索引

>>> ai_min = np.expand_dims(np.argmin(a, axis=1), axis=1)
>>> ai_max = np.expand_dims(np.argmax(a, axis=1), axis=1)
>>> ai = np.concatenate([ai_min, ai_max], axis=1)
>>> ai
array([[0, 1],
       [1, 0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 30],
       [40, 60]])

相關用法


注:本文由純淨天空篩選整理自numpy.org大神的英文原創作品 numpy.take_along_axis。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。