用法:
mxnet.ndarray.op.linalg_gemm(A=None, B=None, C=None, transpose_a=_Null, transpose_b=_Null, alpha=_Null, beta=_Null, axis=_Null, out=None, name=None, **kwargs)
- A:(
NDArray
) - 输入矩阵的张量 - B:(
NDArray
) - 输入矩阵的张量 - C:(
NDArray
) - 输入矩阵的张量 - transpose_a:(
boolean
,
optional
,
default=0
) - 与第一个输入 (A) 的转置相乘。 - transpose_b:(
boolean
,
optional
,
default=0
) - 与第二个输入 (B) 的转置相乘。 - alpha:(
double
,
optional
,
default=1
) - 标量因子乘以 A*B。 - beta:(
double
,
optional
,
default=1
) - 标量因子乘以 C。 - axis:(
int
,
optional
,
default='-2'
) - 对应于矩阵行的轴。 - out:(
NDArray
,
optional
) - 输出 NDArray 来保存结果。
- A:(
out:- 此函数的输出。
NDArray 或 NDArray 列表
参数:
返回:
返回类型:
执行一般矩阵乘法和累加。输入是张量
A
,B
,C
,每个维度n >= 2
并且在前导n-2
维度上具有相同的形状。如果
n=2
,则执行 BLAS3 函数gemm
:out
=alpha
*op
(A
) *op
(B
) +beta
*C
这里,
alpha
和beta
是标量参数,而op()
是单位或矩阵转置(取决于transpose_a
,transpose_b
)。如果
n>2
,gemm
对一批矩阵单独执行。矩阵的列索引由张量的最后一个维度给出,行索引由axis
参数指定的轴给出。默认情况下,尾随二维将用于矩阵编码。对于非默认轴参数,执行的操作相当于一系列 swapaxes/gemm/swapaxes 调用。例如,让
A
、B
、C
为 5 维张量。那么 gemm(A
,B
,C
, axis=1) 在没有额外交换轴操作的开销的情况下等价于以下内容:A1 = swapaxes(A, dim1=1, dim2=3) B1 = swapaxes(B, dim1=1, dim2=3) C = swapaxes(C, dim1=1, dim2=3) C = gemm(A1, B1, C) C = swapaxis(C, dim1=1, dim2=3)
当输入数据为 float32 类型且环境变量 MXNET_CUDA_ALLOW_TENSOR_CORE 和 MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION 设置为 1 时,此运算符将尝试使用 pseudo-float16 精度(float32 math with float16 I /O) 精度,以便在合适的 NVIDIA GPU 上使用张量核心。这有时可以显著加快速度。
注意:
该运算符仅支持 float32 和 float64 数据类型。
例子:
Single matrix multiply-add A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0) = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]] Batch matrix multiply-add A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] C = [[[10.0]], [[0.01]]] gemm(A, B, C, transpose_b=True, alpha=2.0 , beta=10.0) = [[[104.0]], [[0.14]]]
相关用法
- Python mxnet.ndarray.op.linalg_gemm2用法及代码示例
- Python mxnet.ndarray.op.linalg_gelqf用法及代码示例
- Python mxnet.ndarray.op.linalg_det用法及代码示例
- Python mxnet.ndarray.op.linalg_potrf用法及代码示例
- Python mxnet.ndarray.op.linalg_potri用法及代码示例
- Python mxnet.ndarray.op.linalg_slogdet用法及代码示例
- Python mxnet.ndarray.op.linalg_trsm用法及代码示例
- Python mxnet.ndarray.op.linalg_sumlogdiag用法及代码示例
- Python mxnet.ndarray.op.linalg_extractdiag用法及代码示例
- Python mxnet.ndarray.op.linalg_extracttrian用法及代码示例
- Python mxnet.ndarray.op.linalg_inverse用法及代码示例
- Python mxnet.ndarray.op.linalg_trmm用法及代码示例
- Python mxnet.ndarray.op.linalg_makediag用法及代码示例
- Python mxnet.ndarray.op.linalg_maketrian用法及代码示例
- Python mxnet.ndarray.op.linalg_syrk用法及代码示例
- Python mxnet.ndarray.op.log_softmax用法及代码示例
- Python mxnet.ndarray.op.uniform用法及代码示例
- Python mxnet.ndarray.op.sample_negative_binomial用法及代码示例
- Python mxnet.ndarray.op.khatri_rao用法及代码示例
- Python mxnet.ndarray.op.unravel_index用法及代码示例
注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.ndarray.op.linalg_gemm。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。