用法:
mxnet.symbol.linalg_gemm(A=None, B=None, C=None, transpose_a=_Null, transpose_b=_Null, alpha=_Null, beta=_Null, axis=_Null, name=None, attr=None, out=None, **kwargs)- A:(
Symbol) - 输入矩阵的张量 - B:(
Symbol) - 输入矩阵的张量 - C:(
Symbol) - 输入矩阵的张量 - transpose_a:(
boolean,optional,default=0) - 与第一个输入 (A) 的转置相乘。 - transpose_b:(
boolean,optional,default=0) - 与第二个输入 (B) 的转置相乘。 - alpha:(
double,optional,default=1) - 标量因子乘以 A*B。 - beta:(
double,optional,default=1) - 标量因子乘以 C。 - axis:(
int,optional,default='-2') - 对应于矩阵行的轴。 - name:(
string,optional.) - 结果符号的名称。
- A:(
结果符号。
参数:
返回:
返回类型:
执行一般矩阵乘法和累加。输入是张量
A,B,C,每个维度n >= 2并且在前导n-2维度上具有相同的形状。如果
n=2,则执行 BLAS3 函数gemm:out=alpha*op(A) *op(B) +beta*C这里,
alpha和beta是标量参数,而op()是单位或矩阵转置(取决于transpose_a,transpose_b)。如果
n>2,gemm对一批矩阵单独执行。矩阵的列索引由张量的最后一个维度给出,行索引由axis参数指定的轴给出。默认情况下,尾随二维将用于矩阵编码。对于非默认轴参数,执行的操作相当于一系列 swapaxes/gemm/swapaxes 调用。例如,让
A、B、C为 5 维张量。那么 gemm(A,B,C, axis=1) 在没有额外交换轴操作的开销的情况下等价于以下内容:A1 = swapaxes(A, dim1=1, dim2=3) B1 = swapaxes(B, dim1=1, dim2=3) C = swapaxes(C, dim1=1, dim2=3) C = gemm(A1, B1, C) C = swapaxis(C, dim1=1, dim2=3)当输入数据为 float32 类型且环境变量 MXNET_CUDA_ALLOW_TENSOR_CORE 和 MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION 设置为 1 时,此运算符将尝试使用 pseudo-float16 精度(float32 math with float16 I /O) 精度,以便在合适的 NVIDIA GPU 上使用张量核心。这有时可以显著加快速度。
注意:
该运算符仅支持 float32 和 float64 数据类型。
例子:
Single matrix multiply-add A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0) = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]] Batch matrix multiply-add A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] C = [[[10.0]], [[0.01]]] gemm(A, B, C, transpose_b=True, alpha=2.0 , beta=10.0) = [[[104.0]], [[0.14]]]
相关用法
- Python mxnet.symbol.linalg_gemm2用法及代码示例
- Python mxnet.symbol.linalg_gelqf用法及代码示例
- Python mxnet.symbol.linalg_potrf用法及代码示例
- Python mxnet.symbol.linalg_extracttrian用法及代码示例
- Python mxnet.symbol.linalg_sumlogdiag用法及代码示例
- Python mxnet.symbol.linalg_potri用法及代码示例
- Python mxnet.symbol.linalg_extractdiag用法及代码示例
- Python mxnet.symbol.linalg_syrk用法及代码示例
- Python mxnet.symbol.linalg_makediag用法及代码示例
- Python mxnet.symbol.linalg_det用法及代码示例
- Python mxnet.symbol.linalg_slogdet用法及代码示例
- Python mxnet.symbol.linalg_maketrian用法及代码示例
- Python mxnet.symbol.linalg_inverse用法及代码示例
- Python mxnet.symbol.linalg_trmm用法及代码示例
- Python mxnet.symbol.linalg_trsm用法及代码示例
- Python mxnet.symbol.linalg.makediag用法及代码示例
- Python mxnet.symbol.linalg.extracttrian用法及代码示例
- Python mxnet.symbol.linalg.syevd用法及代码示例
- Python mxnet.symbol.linalg.syrk用法及代码示例
- Python mxnet.symbol.linalg.sumlogdiag用法及代码示例
注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.symbol.linalg_gemm。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。
