当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch mm用法及代码示例


本文简要介绍python语言中 torch.sparse.mm 的用法。

用法:

torch.sparse.mm(mat1, mat2)

参数

  • mat1(SparseTensor) -第一个要相乘的稀疏矩阵

  • mat2(Tensor) -要相乘的第二个矩阵,可以是稀疏的或密集的

执行稀疏矩阵 mat1 和(稀疏或跨步)矩阵 mat2 的矩阵乘法。与 torch.mm() 类似,如果 mat1 张量,mat2 张量,则 out 将是 张量。 mat1 需要有 sparse_dim = 2 。此函数还支持两个矩阵的后向。请注意,mat1 的梯度是一个合并的稀疏张量。

形状:

该函数的输出张量格式如下: - 稀疏 x 稀疏 -> 稀疏 - 稀疏 x 密集 -> 密集

例子:

>>> a = torch.randn(2, 3).to_sparse().requires_grad_(True)
>>> a
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                       [0, 1, 2, 0, 1, 2]]),
       values=tensor([ 1.5901,  0.0183, -0.6146,  1.8061, -0.0112,  0.6302]),
       size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True)

>>> b = torch.randn(3, 2, requires_grad=True)
>>> b
tensor([[-0.6479,  0.7874],
        [-1.2056,  0.5641],
        [-1.1716, -0.9923]], requires_grad=True)

>>> y = torch.sparse.mm(a, b)
>>> y
tensor([[-0.3323,  1.8723],
        [-1.8951,  0.7904]], grad_fn=<SparseAddmmBackward>)
>>> y.sum().backward()
>>> a.grad
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                       [0, 1, 2, 0, 1, 2]]),
       values=tensor([ 0.1394, -0.6415, -2.1639,  0.1394, -0.6415, -2.1639]),
       size=(2, 3), nnz=6, layout=torch.sparse_coo)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.sparse.mm。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。