用法:
mxnet.ndarray.op.linalg_gemm2(A=None, B=None, transpose_a=_Null, transpose_b=_Null, alpha=_Null, axis=_Null, out=None, name=None, **kwargs)
- A:(
NDArray
) - 輸入矩陣的張量 - B:(
NDArray
) - 輸入矩陣的張量 - transpose_a:(
boolean
,
optional
,
default=0
) - 與第一個輸入 (A) 的轉置相乘。 - transpose_b:(
boolean
,
optional
,
default=0
) - 與第二個輸入 (B) 的轉置相乘。 - alpha:(
double
,
optional
,
default=1
) - 標量因子乘以 A*B。 - axis:(
int
,
optional
,
default='-2'
) - 對應於矩陣行索引的軸。 - out:(
NDArray
,
optional
) - 輸出 NDArray 來保存結果。
- A:(
out:- 此函數的輸出。
NDArray 或 NDArray 列表
參數:
返回:
返回類型:
執行一般矩陣乘法。輸入是張量
A
,B
,每個維度n >= 2
並且在前導n-2
維度上具有相同的形狀。如果
n=2
,則執行 BLAS3 函數gemm
:out
=alpha
*op
(A
) *op
(B
)這裏
alpha
是一個標量參數,而op()
是單位或矩陣轉置(取決於transpose_a
,transpose_b
)。如果
n>2
,gemm
對一批矩陣單獨執行。矩陣的列索引由張量的最後一個維度給出,行索引由axis
參數指定的軸給出。默認情況下,尾隨二維將用於矩陣編碼。對於非默認軸參數,執行的操作相當於一係列 swapaxes/gemm/swapaxes 調用。例如讓
A
,B
為 5 維張量。然後 gemm(A
,B
, axis=1) 等價於以下內容,而沒有額外的交換軸操作的開銷:A1 = swapaxes(A, dim1=1, dim2=3) B1 = swapaxes(B, dim1=1, dim2=3) C = gemm2(A1, B1) C = swapaxis(C, dim1=1, dim2=3)
當輸入數據為 float32 類型且環境變量 MXNET_CUDA_ALLOW_TENSOR_CORE 和 MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION 設置為 1 時,此運算符將嘗試使用 pseudo-float16 精度(float32 math with float16 I /O) 精度,以便在合適的 NVIDIA GPU 上使用張量核心。這有時可以顯著加快速度。
注意:
該運算符僅支持 float32 和 float64 數據類型。
例子:
Single matrix multiply A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] gemm2(A, B, transpose_b=True, alpha=2.0) = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]] Batch matrix multiply A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] gemm2(A, B, transpose_b=True, alpha=2.0) = [[[4.0]], [[0.04 ]]]
相關用法
- Python mxnet.ndarray.op.linalg_gemm用法及代碼示例
- Python mxnet.ndarray.op.linalg_gelqf用法及代碼示例
- Python mxnet.ndarray.op.linalg_det用法及代碼示例
- Python mxnet.ndarray.op.linalg_potrf用法及代碼示例
- Python mxnet.ndarray.op.linalg_potri用法及代碼示例
- Python mxnet.ndarray.op.linalg_slogdet用法及代碼示例
- Python mxnet.ndarray.op.linalg_trsm用法及代碼示例
- Python mxnet.ndarray.op.linalg_sumlogdiag用法及代碼示例
- Python mxnet.ndarray.op.linalg_extractdiag用法及代碼示例
- Python mxnet.ndarray.op.linalg_extracttrian用法及代碼示例
- Python mxnet.ndarray.op.linalg_inverse用法及代碼示例
- Python mxnet.ndarray.op.linalg_trmm用法及代碼示例
- Python mxnet.ndarray.op.linalg_makediag用法及代碼示例
- Python mxnet.ndarray.op.linalg_maketrian用法及代碼示例
- Python mxnet.ndarray.op.linalg_syrk用法及代碼示例
- Python mxnet.ndarray.op.log_softmax用法及代碼示例
- Python mxnet.ndarray.op.uniform用法及代碼示例
- Python mxnet.ndarray.op.sample_negative_binomial用法及代碼示例
- Python mxnet.ndarray.op.khatri_rao用法及代碼示例
- Python mxnet.ndarray.op.unravel_index用法及代碼示例
注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.ndarray.op.linalg_gemm2。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。