当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch triangular_solve用法及代码示例


本文简要介绍python语言中 torch.triangular_solve 的用法。

用法:

torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)

参数

  • b(Tensor) -大小为 的多个右侧,其中 是多个批次维度的零

  • A(Tensor) -大小为 的输入三角系数矩阵,其中 是零个或多个批处理维度

  • upper(bool,可选的) -是否求解上三角方程组(默认)或下三角方程组。默认值:True

  • transpose(bool,可选的) - 在发送到求解器之前是否应转置。默认值:False

  • unitriangular(bool,可选的) - 是否为单位三角形。如果为 True,则 的对角线元素被假定为 1 并且未从 引用。默认值:False

关键字参数

out((Tensor,Tensor),可选的) -将输出写入的两个张量的元组。如果 None 则忽略。默认值:None

返回

命名元组(solution, cloned_coefficient) 其中cloned_coefficient 的克隆,solution 的解(或方程组的任何变体,取决于关键字参数。)

求解具有三角系数矩阵 和多个右侧 的方程组。

特别是,求解 并假设 是具有默认关键字参数的上三角函数。

torch.triangular_solve(b, A) 可以接受 2D 输入 b, A 或作为 2D 矩阵批次的输入。如果输入是批处理,则返回批处理输出X

如果 A 的对角线包含零或非常接近零的元素和 unitriangular = False(默认),或者如果输入矩阵条件不佳,则结果可能包含 NaN s。

支持 float、double、cfloat 和 cdouble 数据类型的输入。

例子:

>>> A = torch.randn(2, 2).triu()
>>> A
tensor([[ 1.1527, -1.0753],
        [ 0.0000,  0.7986]])
>>> b = torch.randn(2, 3)
>>> b
tensor([[-0.0210,  2.3513, -1.5492],
        [ 1.5429,  0.7403, -1.0243]])
>>> torch.triangular_solve(b, A)
torch.return_types.triangular_solve(
solution=tensor([[ 1.7841,  2.9046, -2.5405],
        [ 1.9320,  0.9270, -1.2826]]),
cloned_coefficient=tensor([[ 1.1527, -1.0753],
        [ 0.0000,  0.7986]]))

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.triangular_solve。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。