當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python mxnet.ndarray.op.linalg_gemm2用法及代碼示例


用法:

mxnet.ndarray.op.linalg_gemm2(A=None, B=None, transpose_a=_Null, transpose_b=_Null, alpha=_Null, axis=_Null, out=None, name=None, **kwargs)

參數

  • A(NDArray) - 輸入矩陣的張量
  • B(NDArray) - 輸入矩陣的張量
  • transpose_a(boolean, optional, default=0) - 與第一個輸入 (A) 的轉置相乘。
  • transpose_b(boolean, optional, default=0) - 與第二個輸入 (B) 的轉置相乘。
  • alpha(double, optional, default=1) - 標量因子乘以 A*B。
  • axis(int, optional, default='-2') - 對應於矩陣行索引的軸。
  • out(NDArray, optional) - 輸出 NDArray 來保存結果。

返回

out- 此函數的輸出。

返回類型

NDArray 或 NDArray 列表

執行一般矩陣乘法。輸入是張量 AB ,每個維度 n >= 2 並且在前導 n-2 維度上具有相同的形狀。

如果 n=2 ,則執行 BLAS3 函數 gemm

out = alpha * op(A) * op(B)

這裏 alpha 是一個標量參數,而 op() 是單位或矩陣轉置(取決於 transpose_atranspose_b )。

如果 n>2 , gemm 對一批矩陣單獨執行。矩陣的列索引由張量的最後一個維度給出,行索引由 axis 參數指定的軸給出。默認情況下,尾隨二維將用於矩陣編碼。

對於非默認軸參數,執行的操作相當於一係列 swapaxes/gemm/swapaxes 調用。例如讓 A , B 為 5 維張量。然後 gemm( A , B , axis=1) 等價於以下內容,而沒有額外的交換軸操作的開銷:

A1 = swapaxes(A, dim1=1, dim2=3)
B1 = swapaxes(B, dim1=1, dim2=3)
C = gemm2(A1, B1)
C = swapaxis(C, dim1=1, dim2=3)

當輸入數據為 float32 類型且環境變量 MXNET_CUDA_ALLOW_TENSOR_CORE 和 MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION 設置為 1 時,此運算符將嘗試使用 pseudo-float16 精度(float32 math with float16 I /O) 精度,以便在合適的 NVIDIA GPU 上使用張量核心。這有時可以顯著加快速度。

注意

該運算符僅支持 float32 和 float64 數據類型。

例子:

Single matrix multiply
A = [[1.0, 1.0], [1.0, 1.0]]
B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
gemm2(A, B, transpose_b=True, alpha=2.0)
         = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]]

Batch matrix multiply
A = [[[1.0, 1.0]], [[0.1, 0.1]]]
B = [[[1.0, 1.0]], [[0.1, 0.1]]]
gemm2(A, B, transpose_b=True, alpha=2.0)
        = [[[4.0]], [[0.04 ]]]

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.ndarray.op.linalg_gemm2。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。