当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.symbol.linalg.gemm用法及代码示例


用法:

mxnet.symbol.linalg.gemm(A=None, B=None, C=None, transpose_a=_Null, transpose_b=_Null, alpha=_Null, beta=_Null, axis=_Null, name=None, attr=None, out=None, **kwargs)

参数

  • A(Symbol) - 输入矩阵的张量
  • B(Symbol) - 输入矩阵的张量
  • C(Symbol) - 输入矩阵的张量
  • transpose_a(boolean, optional, default=0) - 与第一个输入 (A) 的转置相乘。
  • transpose_b(boolean, optional, default=0) - 与第二个输入 (B) 的转置相乘。
  • alpha(double, optional, default=1) - 标量因子乘以 A*B。
  • beta(double, optional, default=1) - 标量因子乘以 C。
  • axis(int, optional, default='-2') - 对应于矩阵行的轴。
  • name(string, optional.) - 结果符号的名称。

返回

结果符号。

返回类型

Symbol

执行一般矩阵乘法和累加。输入是张量 ABC ,每个维度 n >= 2 并且在前导 n-2 维度上具有相同的形状。

如果 n=2 ,则执行 BLAS3 函数 gemm

out = alpha * op(A) * op(B) + beta * C

这里,alphabeta 是标量参数,而 op() 是单位或矩阵转置(取决于 transpose_atranspose_b )。

如果 n>2 , gemm 对一批矩阵单独执行。矩阵的列索引由张量的最后一个维度给出,行索引由 axis 参数指定的轴给出。默认情况下,尾随二维将用于矩阵编码。

对于非默认轴参数,执行的操作相当于一系列 swapaxes/gemm/swapaxes 调用。例如,让 ABC 为 5 维张量。那么 gemm( A , B , C , axis=1) 在没有额外交换轴操作的开销的情况下等价于以下内容:

A1 = swapaxes(A, dim1=1, dim2=3)
B1 = swapaxes(B, dim1=1, dim2=3)
C = swapaxes(C, dim1=1, dim2=3)
C = gemm(A1, B1, C)
C = swapaxis(C, dim1=1, dim2=3)

当输入数据为 float32 类型且环境变量 MXNET_CUDA_ALLOW_TENSOR_CORE 和 MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION 设置为 1 时,此运算符将尝试使用 pseudo-float16 精度(float32 math with float16 I /O) 精度,以便在合适的 NVIDIA GPU 上使用张量核心。这有时可以显著加快速度。

注意

该运算符仅支持 float32 和 float64 数据类型。

例子:

Single matrix multiply-add
A = [[1.0, 1.0], [1.0, 1.0]]
B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0)
        = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]]

Batch matrix multiply-add
A = [[[1.0, 1.0]], [[0.1, 0.1]]]
B = [[[1.0, 1.0]], [[0.1, 0.1]]]
C = [[[10.0]], [[0.01]]]
gemm(A, B, C, transpose_b=True, alpha=2.0 , beta=10.0)
        = [[[104.0]], [[0.14]]]

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.symbol.linalg.gemm。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。