当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch sum用法及代码示例


本文简要介绍python语言中 torch.sparse.sum 的用法。

用法:

torch.sparse.sum(input, dim=None, dtype=None)

参数

  • input(Tensor) -输入稀疏张量

  • dim(int或者python的元组:ints) -要减少的维度或维度列表。默认值:减少所有暗淡。

  • dtype(torch.dtype, 可选的) -返回张量的所需数据类型。默认值:input 的数据类型。

返回给定维度 dim 中稀疏张量 input 的每一行的总和。如果dim 是维度列表,则对所有维度进行归约。当对所有 sparse_dim 求和时,此方法返回密集张量而不是稀疏张量。

所有求和的 dim 都被压缩(参见 torch.squeeze() ),导致输出张量的 diminput 的维度少。

在反向期间,只有inputnnz 位置处的梯度会向后传播。请注意,input 的梯度是合并的。

例子:

>>> nnz = 3
>>> dims = [5, 5, 2, 3]
>>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
                   torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
>>> V = torch.randn(nnz, dims[2], dims[3])
>>> size = torch.Size(dims)
>>> S = torch.sparse_coo_tensor(I, V, size)
>>> S
tensor(indices=tensor([[2, 0, 3],
                       [2, 4, 1]]),
       values=tensor([[[-0.6438, -1.6467,  1.4004],
                       [ 0.3411,  0.0918, -0.2312]],

                      [[ 0.5348,  0.0634, -2.0494],
                       [-0.7125, -1.0646,  2.1844]],

                      [[ 0.1276,  0.1874, -0.6334],
                       [-1.9682, -0.5340,  0.7483]]]),
       size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)

# when sum over only part of sparse_dims, return a sparse tensor
>>> torch.sparse.sum(S, [1, 3])
tensor(indices=tensor([[0, 2, 3]]),
       values=tensor([[-1.4512,  0.4073],
                      [-0.8901,  0.2017],
                      [-0.3183, -1.7539]]),
       size=(5, 2), nnz=3, layout=torch.sparse_coo)

# when sum over all sparse dim, return a dense tensor
# with summed dims squeezed
>>> torch.sparse.sum(S, [0, 1, 3])
tensor([-2.6596, -1.1450])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.sparse.sum。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。