当前位置: 首页>>代码示例>>Python>>正文


Python torch_sparse.spspmm方法代码示例

本文整理汇总了Python中torch_sparse.spspmm方法的典型用法代码示例。如果您正苦于以下问题:Python torch_sparse.spspmm方法的具体用法?Python torch_sparse.spspmm怎么用?Python torch_sparse.spspmm使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch_sparse的用法示例。


在下文中一共展示了torch_sparse.spspmm方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __call__

# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def __call__(self, data):
        edge_index, edge_attr = data.edge_index, data.edge_attr
        N = data.num_nodes

        value = edge_index.new_ones((edge_index.size(1), ), dtype=torch.float)

        index, value = spspmm(edge_index, value, edge_index, value, N, N, N)
        value.fill_(0)
        index, value = remove_self_loops(index, value)

        edge_index = torch.cat([edge_index, index], dim=1)
        if edge_attr is None:
            data.edge_index, _ = coalesce(edge_index, None, N, N)
        else:
            value = value.view(-1, *[1 for _ in range(edge_attr.dim() - 1)])
            value = value.expand(-1, *list(edge_attr.size())[1:])
            edge_attr = torch.cat([edge_attr, value], dim=0)
            data.edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
            data.edge_attr = edge_attr

        return data 
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:23,代码来源:two_hop.py

示例2: forward

# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def forward(self, A, H_=None):
        if self.first == True:
            result_A = self.conv1(A)
            result_B = self.conv2(A)                
            W = [(F.softmax(self.conv1.weight, dim=1)).detach(),(F.softmax(self.conv2.weight, dim=1)).detach()]
        else:
            result_A = H_
            result_B = self.conv1(A)
            W = [(F.softmax(self.conv1.weight, dim=1)).detach()]
        H = []
        for i in range(len(result_A)):
            a_edge, a_value = result_A[i]
            b_edge, b_value = result_B[i]
            
            edges, values = torch_sparse.spspmm(a_edge, a_value, b_edge, b_value, self.num_nodes, self.num_nodes, self.num_nodes)
            H.append((edges, values))
        return H, W 
开发者ID:THUDM,项目名称:cogdl,代码行数:19,代码来源:gtn.py

示例3: augment_adj

# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def augment_adj(self, edge_index, edge_weight, num_nodes):
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight 
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:12,代码来源:graph_unet.py

示例4: forward

# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def forward(self, phi_indices, phi_values, phi_inverse_indices, phi_inverse_values, features):
        """
        Forward propagation pass.
        :param phi_indices: Sparse wavelet matrix index pairs.
        :param phi_values: Sparse wavelet matrix values.
        :param phi_inverse_indices: Inverse wavelet matrix index pairs.
        :param phi_inverse_values: Inverse wavelet matrix values.
        :param features: Feature matrix.
        :return localized_features: Filtered feature matrix extracted.
        """
        rescaled_phi_indices, rescaled_phi_values = spspmm(phi_indices,
                                                           phi_values,
                                                           self.diagonal_weight_indices,
                                                           self.diagonal_weight_filter.view(-1),
                                                           self.ncount,
                                                           self.ncount,
                                                           self.ncount)

        phi_product_indices, phi_product_values = spspmm(rescaled_phi_indices,
                                                         rescaled_phi_values,
                                                         phi_inverse_indices,
                                                         phi_inverse_values,
                                                         self.ncount,
                                                         self.ncount,
                                                         self.ncount)

        filtered_features = torch.mm(features, self.weight_matrix)

        localized_features = spmm(phi_product_indices,
                                  phi_product_values,
                                  self.ncount,
                                  self.ncount,
                                  filtered_features)

        return localized_features 
开发者ID:benedekrozemberczki,项目名称:GraphWaveletNeuralNetwork,代码行数:37,代码来源:gwnn_layer.py

示例5: StAS

# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def StAS(index_A, value_A, index_S, value_S, device, N, kN):
    r"""StAS: a function which returns new edge weights for the pooled graph using the formula S^{T}AS"""

    index_A, value_A = coalesce(index_A, value_A, m=N, n=N)
    index_S, value_S = coalesce(index_S, value_S, m=N, n=kN)
    index_B, value_B = spspmm(index_A, value_A, index_S, value_S, N, N, kN)

    index_St, value_St = transpose(index_S, value_S, N, kN)
    index_B, value_B = coalesce(index_B, value_B, m=N, n=kN)
    # index_E, value_E = spspmm(index_St.cpu(), value_St.cpu(), index_B.cpu(), value_B.cpu(), kN, N, kN)
    index_E, value_E = spspmm(index_St, value_St, index_B, value_B, kN, N, kN)

    # return index_E.to(device), value_E.to(device)
    return index_E, value_E 
开发者ID:malllabiisc,项目名称:ASAP,代码行数:16,代码来源:asap_pool.py

示例6: test_spspmm

# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def test_spspmm(dtype, device):
    indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
    valueA = tensor([1, 2, 3, 4, 5], dtype, device)
    indexB = torch.tensor([[0, 2], [1, 0]], device=device)
    valueB = tensor([2, 4], dtype, device)

    indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
    assert indexC.tolist() == [[0, 1, 2], [0, 1, 1]]
    assert valueC.tolist() == [8, 6, 8] 
开发者ID:rusty1s,项目名称:pytorch_sparse,代码行数:11,代码来源:test_spspmm.py


注:本文中的torch_sparse.spspmm方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。