本文簡要介紹 python 語言中 numpy.take_along_axis
的用法。
用法:
numpy.take_along_axis(arr, indices, axis)
通過匹配一維索引和數據切片從輸入數組中獲取值。
這會遍曆索引和數據數組中沿指定軸定向的匹配 1d 切片,並使用前者在後者中查找值。這些切片可以是不同的長度。
沿軸返回索引的函數(如
argsort
和argpartition
)會為此函數生成合適的索引。- 輸出:ndarray(Ni…,J,Nk…)
索引的結果。
參數:
返回:
注意:
這相當於(但比)以下使用
ndindex
和s_
,將ii
和kk
設置為索引元組:Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:] J = indices.shape[axis] # Need not equal M out = np.empty(Ni + (J,) + Nk) for ii in ndindex(Ni): for kk in ndindex(Nk): a_1d = a [ii + s_[:,] + kk] indices_1d = indices[ii + s_[:,] + kk] out_1d = out [ii + s_[:,] + kk] for j in range(J): out_1d[j] = a_1d[indices_1d[j]]
等效地,消除內部循環,最後兩行將是:
out_1d[:] = a_1d[indices_1d]
例子:
對於這個示例數組
>>> a = np.array([[10, 30, 20], [60, 40, 50]])
我們可以使用 sort 直接排序,也可以使用 argsort 和這個函數
>>> np.sort(a, axis=1) array([[10, 20, 30], [40, 50, 60]]) >>> ai = np.argsort(a, axis=1); ai array([[0, 2, 1], [1, 2, 0]]) >>> np.take_along_axis(a, ai, axis=1) array([[10, 20, 30], [40, 50, 60]])
如果擴展尺寸,max 和 min 也是如此:
>>> np.expand_dims(np.max(a, axis=1), axis=1) array([[30], [60]]) >>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1) >>> ai array([[1], [0]]) >>> np.take_along_axis(a, ai, axis=1) array([[30], [60]])
如果我們想同時得到最大值和最小值,我們可以先堆疊索引
>>> ai_min = np.expand_dims(np.argmin(a, axis=1), axis=1) >>> ai_max = np.expand_dims(np.argmax(a, axis=1), axis=1) >>> ai = np.concatenate([ai_min, ai_max], axis=1) >>> ai array([[0, 1], [1, 0]]) >>> np.take_along_axis(a, ai, axis=1) array([[10, 30], [40, 60]])
相關用法
- Python numpy take用法及代碼示例
- Python numpy tanh用法及代碼示例
- Python numpy tan用法及代碼示例
- Python numpy trim_zeros用法及代碼示例
- Python numpy testing.rundocs用法及代碼示例
- Python numpy testing.assert_warns用法及代碼示例
- Python numpy trace用法及代碼示例
- Python numpy testing.assert_array_almost_equal_nulp用法及代碼示例
- Python numpy tri用法及代碼示例
- Python numpy testing.assert_array_less用法及代碼示例
- Python numpy testing.assert_raises用法及代碼示例
- Python numpy true_divide用法及代碼示例
- Python numpy transpose用法及代碼示例
- Python numpy testing.assert_almost_equal用法及代碼示例
- Python numpy tile用法及代碼示例
- Python numpy testing.assert_approx_equal用法及代碼示例
- Python numpy testing.assert_allclose用法及代碼示例
- Python numpy testing.decorators.slow用法及代碼示例
- Python numpy trapz用法及代碼示例
- Python numpy testing.suppress_warnings用法及代碼示例
- Python numpy testing.assert_string_equal用法及代碼示例
- Python numpy testing.run_module_suite用法及代碼示例
- Python numpy testing.assert_array_max_ulp用法及代碼示例
- Python numpy testing.assert_equal用法及代碼示例
- Python numpy triu_indices用法及代碼示例
注:本文由純淨天空篩選整理自numpy.org大神的英文原創作品 numpy.take_along_axis。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。