用法:
mxnet.symbol.linalg_gemm(A=None, B=None, C=None, transpose_a=_Null, transpose_b=_Null, alpha=_Null, beta=_Null, axis=_Null, name=None, attr=None, out=None, **kwargs)
- A:(
Symbol
) - 輸入矩陣的張量 - B:(
Symbol
) - 輸入矩陣的張量 - C:(
Symbol
) - 輸入矩陣的張量 - transpose_a:(
boolean
,
optional
,
default=0
) - 與第一個輸入 (A) 的轉置相乘。 - transpose_b:(
boolean
,
optional
,
default=0
) - 與第二個輸入 (B) 的轉置相乘。 - alpha:(
double
,
optional
,
default=1
) - 標量因子乘以 A*B。 - beta:(
double
,
optional
,
default=1
) - 標量因子乘以 C。 - axis:(
int
,
optional
,
default='-2'
) - 對應於矩陣行的軸。 - name:(
string
,
optional.
) - 結果符號的名稱。
- A:(
結果符號。
參數:
返回:
返回類型:
執行一般矩陣乘法和累加。輸入是張量
A
,B
,C
,每個維度n >= 2
並且在前導n-2
維度上具有相同的形狀。如果
n=2
,則執行 BLAS3 函數gemm
:out
=alpha
*op
(A
) *op
(B
) +beta
*C
這裏,
alpha
和beta
是標量參數,而op()
是單位或矩陣轉置(取決於transpose_a
,transpose_b
)。如果
n>2
,gemm
對一批矩陣單獨執行。矩陣的列索引由張量的最後一個維度給出,行索引由axis
參數指定的軸給出。默認情況下,尾隨二維將用於矩陣編碼。對於非默認軸參數,執行的操作相當於一係列 swapaxes/gemm/swapaxes 調用。例如,讓
A
、B
、C
為 5 維張量。那麽 gemm(A
,B
,C
, axis=1) 在沒有額外交換軸操作的開銷的情況下等價於以下內容:A1 = swapaxes(A, dim1=1, dim2=3) B1 = swapaxes(B, dim1=1, dim2=3) C = swapaxes(C, dim1=1, dim2=3) C = gemm(A1, B1, C) C = swapaxis(C, dim1=1, dim2=3)
當輸入數據為 float32 類型且環境變量 MXNET_CUDA_ALLOW_TENSOR_CORE 和 MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION 設置為 1 時,此運算符將嘗試使用 pseudo-float16 精度(float32 math with float16 I /O) 精度,以便在合適的 NVIDIA GPU 上使用張量核心。這有時可以顯著加快速度。
注意:
該運算符僅支持 float32 和 float64 數據類型。
例子:
Single matrix multiply-add A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0) = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]] Batch matrix multiply-add A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] C = [[[10.0]], [[0.01]]] gemm(A, B, C, transpose_b=True, alpha=2.0 , beta=10.0) = [[[104.0]], [[0.14]]]
相關用法
- Python mxnet.symbol.linalg_gemm2用法及代碼示例
- Python mxnet.symbol.linalg_gelqf用法及代碼示例
- Python mxnet.symbol.linalg_potrf用法及代碼示例
- Python mxnet.symbol.linalg_extracttrian用法及代碼示例
- Python mxnet.symbol.linalg_sumlogdiag用法及代碼示例
- Python mxnet.symbol.linalg_potri用法及代碼示例
- Python mxnet.symbol.linalg_extractdiag用法及代碼示例
- Python mxnet.symbol.linalg_syrk用法及代碼示例
- Python mxnet.symbol.linalg_makediag用法及代碼示例
- Python mxnet.symbol.linalg_det用法及代碼示例
- Python mxnet.symbol.linalg_slogdet用法及代碼示例
- Python mxnet.symbol.linalg_maketrian用法及代碼示例
- Python mxnet.symbol.linalg_inverse用法及代碼示例
- Python mxnet.symbol.linalg_trmm用法及代碼示例
- Python mxnet.symbol.linalg_trsm用法及代碼示例
- Python mxnet.symbol.linalg.makediag用法及代碼示例
- Python mxnet.symbol.linalg.extracttrian用法及代碼示例
- Python mxnet.symbol.linalg.syevd用法及代碼示例
- Python mxnet.symbol.linalg.syrk用法及代碼示例
- Python mxnet.symbol.linalg.sumlogdiag用法及代碼示例
注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.symbol.linalg_gemm。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。