本文简要介绍python语言中 torch.matmul
的用法。
用法:
torch.matmul(input, other, *, out=None) → Tensor
两个张量的矩阵乘积。
行为取决于张量的维度,如下所示:
如果两个张量都是一维的,则返回点积(标量)。
如果两个参数都是二维的,则返回 matrix-matrix 乘积。
如果第一个参数是一维的,第二个参数是二维的,为了矩阵乘法的目的,在它的维数前面加上一个 1。在矩阵相乘之后,前置维度被移除。
如果第一个参数是二维的,第二个参数是一维的,则返回 matrix-vector 乘积。
如果两个参数至少为一维且至少一个参数为 N 维(其中 N > 2),则返回批处理矩阵乘法。如果第一个参数是一维的,则将 1 添加到其维度,以便批量矩阵相乘并在之后删除。如果第二个参数是一维的,则将 1 附加到其维度以用于批量矩阵倍数并在之后删除。非矩阵(即批次)维度是广播的(因此必须是可广播的)。例如,如果
input
是 张量并且other
是 张量,则out
将是 张量。请注意,广播逻辑在确定输入是否可广播时仅查看批处理维度,而不是矩阵维度。例如,如果
input
是 张量,而other
是 张量,则即使最后两个维度(即矩阵维度)不同,这些输入对于广播也是有效的。out
将是一个 张量。
该运算符支持 TensorFloat32。
注意
此函数的一维点积版本不支持
out
参数。例子:
>>> # vector x vector >>> tensor1 = torch.randn(3) >>> tensor2 = torch.randn(3) >>> torch.matmul(tensor1, tensor2).size() torch.Size([]) >>> # matrix x vector >>> tensor1 = torch.randn(3, 4) >>> tensor2 = torch.randn(4) >>> torch.matmul(tensor1, tensor2).size() torch.Size([3]) >>> # batched matrix x broadcasted vector >>> tensor1 = torch.randn(10, 3, 4) >>> tensor2 = torch.randn(4) >>> torch.matmul(tensor1, tensor2).size() torch.Size([10, 3]) >>> # batched matrix x batched matrix >>> tensor1 = torch.randn(10, 3, 4) >>> tensor2 = torch.randn(10, 4, 5) >>> torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5]) >>> # batched matrix x broadcasted matrix >>> tensor1 = torch.randn(10, 3, 4) >>> tensor2 = torch.randn(4, 5) >>> torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5])
相关用法
- Python PyTorch matrix_rank用法及代码示例
- Python PyTorch matrix_exp用法及代码示例
- Python PyTorch matrix_power用法及代码示例
- Python PyTorch matrix_norm用法及代码示例
- Python PyTorch max用法及代码示例
- Python PyTorch maximum用法及代码示例
- Python PyTorch masked_select用法及代码示例
- Python PyTorch maskrcnn_resnet50_fpn用法及代码示例
- Python PyTorch make_tensor用法及代码示例
- Python PyTorch monitored_barrier用法及代码示例
- Python PyTorch mean用法及代码示例
- Python PyTorch multinomial用法及代码示例
- Python PyTorch meshgrid用法及代码示例
- Python PyTorch mm用法及代码示例
- Python PyTorch mv用法及代码示例
- Python PyTorch min用法及代码示例
- Python PyTorch msort用法及代码示例
- Python PyTorch mode用法及代码示例
- Python PyTorch movedim用法及代码示例
- Python PyTorch minimum用法及代码示例
- Python PyTorch multi_dot用法及代码示例
- Python PyTorch mul用法及代码示例
- Python PyTorch movielens_25m用法及代码示例
- Python PyTorch multigammaln用法及代码示例
- Python PyTorch movielens_20m用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.matmul。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。