用法:
mxnet.symbol.linalg.gemm2(A=None, B=None, transpose_a=_Null, transpose_b=_Null, alpha=_Null, axis=_Null, name=None, attr=None, out=None, **kwargs)
- A:(
Symbol
) - 輸入矩陣的張量 - B:(
Symbol
) - 輸入矩陣的張量 - transpose_a:(
boolean
,
optional
,
default=0
) - 與第一個輸入 (A) 的轉置相乘。 - transpose_b:(
boolean
,
optional
,
default=0
) - 與第二個輸入 (B) 的轉置相乘。 - alpha:(
double
,
optional
,
default=1
) - 標量因子乘以 A*B。 - axis:(
int
,
optional
,
default='-2'
) - 對應於矩陣行索引的軸。 - name:(
string
,
optional.
) - 結果符號的名稱。
- A:(
結果符號。
參數:
返回:
返回類型:
執行一般矩陣乘法。輸入是張量
A
,B
,每個維度n >= 2
並且在前導n-2
維度上具有相同的形狀。如果
n=2
,則執行 BLAS3 函數gemm
:out
=alpha
*op
(A
) *op
(B
)這裏
alpha
是一個標量參數,而op()
是單位或矩陣轉置(取決於transpose_a
,transpose_b
)。如果
n>2
,gemm
對一批矩陣單獨執行。矩陣的列索引由張量的最後一個維度給出,行索引由axis
參數指定的軸給出。默認情況下,尾隨二維將用於矩陣編碼。對於非默認軸參數,執行的操作相當於一係列 swapaxes/gemm/swapaxes 調用。例如讓
A
,B
為 5 維張量。然後 gemm(A
,B
, axis=1) 等價於以下內容,而沒有額外的交換軸操作的開銷:A1 = swapaxes(A, dim1=1, dim2=3) B1 = swapaxes(B, dim1=1, dim2=3) C = gemm2(A1, B1) C = swapaxis(C, dim1=1, dim2=3)
當輸入數據為 float32 類型且環境變量 MXNET_CUDA_ALLOW_TENSOR_CORE 和 MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION 設置為 1 時,此運算符將嘗試使用 pseudo-float16 精度(float32 math with float16 I /O) 精度,以便在合適的 NVIDIA GPU 上使用張量核心。這有時可以顯著加快速度。
注意:
該運算符僅支持 float32 和 float64 數據類型。
例子:
Single matrix multiply A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] gemm2(A, B, transpose_b=True, alpha=2.0) = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]] Batch matrix multiply A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] gemm2(A, B, transpose_b=True, alpha=2.0) = [[[4.0]], [[0.04 ]]]
相關用法
- Python mxnet.symbol.linalg.gemm用法及代碼示例
- Python mxnet.symbol.linalg.gelqf用法及代碼示例
- Python mxnet.symbol.linalg.makediag用法及代碼示例
- Python mxnet.symbol.linalg.extracttrian用法及代碼示例
- Python mxnet.symbol.linalg.syevd用法及代碼示例
- Python mxnet.symbol.linalg.syrk用法及代碼示例
- Python mxnet.symbol.linalg.sumlogdiag用法及代碼示例
- Python mxnet.symbol.linalg.maketrian用法及代碼示例
- Python mxnet.symbol.linalg.potrf用法及代碼示例
- Python mxnet.symbol.linalg.inverse用法及代碼示例
- Python mxnet.symbol.linalg.extractdiag用法及代碼示例
- Python mxnet.symbol.linalg.trsm用法及代碼示例
- Python mxnet.symbol.linalg.trmm用法及代碼示例
- Python mxnet.symbol.linalg.det用法及代碼示例
- Python mxnet.symbol.linalg.potri用法及代碼示例
- Python mxnet.symbol.linalg.slogdet用法及代碼示例
- Python mxnet.symbol.linalg_potrf用法及代碼示例
- Python mxnet.symbol.linalg_gelqf用法及代碼示例
- Python mxnet.symbol.linalg_extracttrian用法及代碼示例
- Python mxnet.symbol.linalg_sumlogdiag用法及代碼示例
注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.symbol.linalg.gemm2。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。