当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.symbol.linalg.trmm用法及代码示例


用法:

mxnet.symbol.linalg.trmm(A=None, B=None, transpose=_Null, rightside=_Null, lower=_Null, alpha=_Null, name=None, attr=None, out=None, **kwargs)

参数

  • A(Symbol) - 下三角矩阵的张量
  • B(Symbol) - 矩阵张量
  • transpose(boolean, optional, default=0) - 使用转置的三角矩阵
  • rightside(boolean, optional, default=0) - 将三角矩阵从右边乘以非三角矩阵。
  • lower(boolean, optional, default=1) - 如果三角矩阵是下三角矩阵,则为真,如果是上三角矩阵,则为假。
  • alpha(double, optional, default=1) - 应用于结果的标量因子。
  • name(string, optional.) - 结果符号的名称。

返回

结果符号。

返回类型

Symbol

执行与下三角矩阵的乘法。输入是张量 AB ,每个维度 n >= 2 并且在前导 n-2 维度上具有相同的形状。

如果 n=2A 必须是三角形的。操作符执行 BLAS3 函数 trmm

out = alpha * op(A) * B

如果 rightside=False ,或

out = alpha * B * op(A)

如果 rightside=True 。这里,alpha 是一个标量参数,op() 是单位或矩阵转置(取决于 transpose )。

如果 n>2trmm 对所有输入的尾随两个维度分别执行(批处理模式)。

注意

该运算符仅支持 float32 和 float64 数据类型。

例子:

Single triangular matrix multiply
A = [[1.0, 0], [1.0, 1.0]]
B = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
trmm(A, B, alpha=2.0) = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]

Batch triangular matrix multiply
A = [[[1.0, 0], [1.0, 1.0]], [[1.0, 0], [1.0, 1.0]]]
B = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]]
trmm(A, B, alpha=2.0) = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]],
                         [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.symbol.linalg.trmm。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。