用法:
mxnet.symbol.linalg.gemm2(A=None, B=None, transpose_a=_Null, transpose_b=_Null, alpha=_Null, axis=_Null, name=None, attr=None, out=None, **kwargs)
- A:(
Symbol
) - 输入矩阵的张量 - B:(
Symbol
) - 输入矩阵的张量 - transpose_a:(
boolean
,
optional
,
default=0
) - 与第一个输入 (A) 的转置相乘。 - transpose_b:(
boolean
,
optional
,
default=0
) - 与第二个输入 (B) 的转置相乘。 - alpha:(
double
,
optional
,
default=1
) - 标量因子乘以 A*B。 - axis:(
int
,
optional
,
default='-2'
) - 对应于矩阵行索引的轴。 - name:(
string
,
optional.
) - 结果符号的名称。
- A:(
结果符号。
参数:
返回:
返回类型:
执行一般矩阵乘法。输入是张量
A
,B
,每个维度n >= 2
并且在前导n-2
维度上具有相同的形状。如果
n=2
,则执行 BLAS3 函数gemm
:out
=alpha
*op
(A
) *op
(B
)这里
alpha
是一个标量参数,而op()
是单位或矩阵转置(取决于transpose_a
,transpose_b
)。如果
n>2
,gemm
对一批矩阵单独执行。矩阵的列索引由张量的最后一个维度给出,行索引由axis
参数指定的轴给出。默认情况下,尾随二维将用于矩阵编码。对于非默认轴参数,执行的操作相当于一系列 swapaxes/gemm/swapaxes 调用。例如让
A
,B
为 5 维张量。然后 gemm(A
,B
, axis=1) 等价于以下内容,而没有额外的交换轴操作的开销:A1 = swapaxes(A, dim1=1, dim2=3) B1 = swapaxes(B, dim1=1, dim2=3) C = gemm2(A1, B1) C = swapaxis(C, dim1=1, dim2=3)
当输入数据为 float32 类型且环境变量 MXNET_CUDA_ALLOW_TENSOR_CORE 和 MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION 设置为 1 时,此运算符将尝试使用 pseudo-float16 精度(float32 math with float16 I /O) 精度,以便在合适的 NVIDIA GPU 上使用张量核心。这有时可以显著加快速度。
注意:
该运算符仅支持 float32 和 float64 数据类型。
例子:
Single matrix multiply A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] gemm2(A, B, transpose_b=True, alpha=2.0) = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]] Batch matrix multiply A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] gemm2(A, B, transpose_b=True, alpha=2.0) = [[[4.0]], [[0.04 ]]]
相关用法
- Python mxnet.symbol.linalg.gemm用法及代码示例
- Python mxnet.symbol.linalg.gelqf用法及代码示例
- Python mxnet.symbol.linalg.makediag用法及代码示例
- Python mxnet.symbol.linalg.extracttrian用法及代码示例
- Python mxnet.symbol.linalg.syevd用法及代码示例
- Python mxnet.symbol.linalg.syrk用法及代码示例
- Python mxnet.symbol.linalg.sumlogdiag用法及代码示例
- Python mxnet.symbol.linalg.maketrian用法及代码示例
- Python mxnet.symbol.linalg.potrf用法及代码示例
- Python mxnet.symbol.linalg.inverse用法及代码示例
- Python mxnet.symbol.linalg.extractdiag用法及代码示例
- Python mxnet.symbol.linalg.trsm用法及代码示例
- Python mxnet.symbol.linalg.trmm用法及代码示例
- Python mxnet.symbol.linalg.det用法及代码示例
- Python mxnet.symbol.linalg.potri用法及代码示例
- Python mxnet.symbol.linalg.slogdet用法及代码示例
- Python mxnet.symbol.linalg_potrf用法及代码示例
- Python mxnet.symbol.linalg_gelqf用法及代码示例
- Python mxnet.symbol.linalg_extracttrian用法及代码示例
- Python mxnet.symbol.linalg_sumlogdiag用法及代码示例
注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.symbol.linalg.gemm2。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。