当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.symbol.linalg.gemm2用法及代码示例


用法:

mxnet.symbol.linalg.gemm2(A=None, B=None, transpose_a=_Null, transpose_b=_Null, alpha=_Null, axis=_Null, name=None, attr=None, out=None, **kwargs)

参数

  • A(Symbol) - 输入矩阵的张量
  • B(Symbol) - 输入矩阵的张量
  • transpose_a(boolean, optional, default=0) - 与第一个输入 (A) 的转置相乘。
  • transpose_b(boolean, optional, default=0) - 与第二个输入 (B) 的转置相乘。
  • alpha(double, optional, default=1) - 标量因子乘以 A*B。
  • axis(int, optional, default='-2') - 对应于矩阵行索引的轴。
  • name(string, optional.) - 结果符号的名称。

返回

结果符号。

返回类型

Symbol

执行一般矩阵乘法。输入是张量 AB ,每个维度 n >= 2 并且在前导 n-2 维度上具有相同的形状。

如果 n=2 ,则执行 BLAS3 函数 gemm

out = alpha * op(A) * op(B)

这里 alpha 是一个标量参数,而 op() 是单位或矩阵转置(取决于 transpose_atranspose_b )。

如果 n>2 , gemm 对一批矩阵单独执行。矩阵的列索引由张量的最后一个维度给出,行索引由 axis 参数指定的轴给出。默认情况下,尾随二维将用于矩阵编码。

对于非默认轴参数,执行的操作相当于一系列 swapaxes/gemm/swapaxes 调用。例如让 A , B 为 5 维张量。然后 gemm( A , B , axis=1) 等价于以下内容,而没有额外的交换轴操作的开销:

A1 = swapaxes(A, dim1=1, dim2=3)
B1 = swapaxes(B, dim1=1, dim2=3)
C = gemm2(A1, B1)
C = swapaxis(C, dim1=1, dim2=3)

当输入数据为 float32 类型且环境变量 MXNET_CUDA_ALLOW_TENSOR_CORE 和 MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION 设置为 1 时,此运算符将尝试使用 pseudo-float16 精度(float32 math with float16 I /O) 精度,以便在合适的 NVIDIA GPU 上使用张量核心。这有时可以显著加快速度。

注意

该运算符仅支持 float32 和 float64 数据类型。

例子:

Single matrix multiply
A = [[1.0, 1.0], [1.0, 1.0]]
B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
gemm2(A, B, transpose_b=True, alpha=2.0)
         = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]]

Batch matrix multiply
A = [[[1.0, 1.0]], [[0.1, 0.1]]]
B = [[[1.0, 1.0]], [[0.1, 0.1]]]
gemm2(A, B, transpose_b=True, alpha=2.0)
        = [[[4.0]], [[0.04 ]]]

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.symbol.linalg.gemm2。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。