当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch multi_dot用法及代码示例


本文简要介绍python语言中 torch.linalg.multi_dot 的用法。

用法:

torch.linalg.multi_dot(tensors, *, out=None)

参数

tensors(Sequence[Tensor]) -两个或多个张量相乘。第一个和最后一个张量可以是 1D 或 2D。每个其他张量必须是 2D 的。

关键字参数

out(Tensor,可选的) -输出张量。如果 None 则忽略。默认值:None

通过重新排序乘法来有效地将两个或多个矩阵相乘,以便执行最少的算术运算。

支持 float、double、cfloat 和 cdouble dtypes 的输入。此函数不支持批量输入。

tensors 中的每个张量都必须是 2D,但第一个和最后一个可能是 1D 的除外。如果第一个张量是形状为 (n,) 的一维向量,则将其视为形状为 (1, n) 的行向量,类似地,如果最后一个张量是形状为 (n,) 的一维向量,则将其视为形状为列向量(n, 1)

如果第一个和最后一个张量是矩阵,则输出将是一个矩阵。但是,如果其中任何一个是一维向量,则输出将是一维向量。

numpy.linalg.multi_dot 的区别:

  • numpy.linalg.multi_dot 不同,第一个和最后一个张量必须是 1D 或 2D,而 NumPy 允许它们是 nD

警告

此函数不广播。

注意

此函数通过在计算最佳矩阵乘法顺序后链接 torch.mm() 调用来实现。

注意

将两个矩阵与形状 (a, b)(b, c) 相乘的成本是 a * b * c 。给定矩阵 A , B , C 形状分别为 (10, 100) , (100, 5) , (5, 50) ,我们可以计算不同乘法顺序的成本如下:

在这种情况下,先将 AB 相乘,然后再乘以 C 会快 10 倍。

例子:

>>> from torch.linalg import multi_dot

>>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])])
tensor(8)
>>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])])
tensor([8])
>>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])])
tensor([[8]])

>>> A = torch.arange(2 * 3).view(2, 3)
>>> B = torch.arange(3 * 2).view(3, 2)
>>> C = torch.arange(2 * 2).view(2, 2)
>>> multi_dot((A, B, C))
tensor([[ 26,  49],
        [ 80, 148]])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.linalg.multi_dot。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。