本文整理汇总了Python中torch_sparse.spspmm方法的典型用法代码示例。如果您正苦于以下问题：Python torch_sparse.spspmm方法的具体用法？Python torch_sparse.spspmm怎么用？Python torch_sparse.spspmm使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch_sparse的用法示例。
下文共展示了torch_sparse.spspmm方法的6个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __call__
# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def __call__(self, data):
    """Augment ``data`` with its two-hop edges and return it.

    Two-hop connectivity is found by multiplying the unit-weighted adjacency
    matrix by itself with ``spspmm``. Self-loops produced by that product are
    dropped, the new edges are appended to the original ones, and the result
    is merged with ``coalesce``.

    When ``data.edge_attr`` is set, every appended edge receives an all-zero
    attribute row. That row has the same trailing shape, dtype and device as
    the existing attributes. Assuming ``coalesce`` sums duplicates (the
    torch_sparse default reduction; confirm for the installed version), an
    edge that is both one-hop and two-hop keeps its original attribute.

    Args:
        data: graph object exposing ``edge_index``, ``edge_attr`` and
            ``num_nodes``. It is modified in place.

    Returns:
        The same ``data`` object with updated ``edge_index`` (and
        ``edge_attr`` when present).
    """
    edge_index, edge_attr = data.edge_index, data.edge_attr
    N = data.num_nodes

    # Unit weights: the product A @ A is used only for its sparsity pattern.
    value = edge_index.new_ones((edge_index.size(1),), dtype=torch.float)
    index, value = spspmm(edge_index, value, edge_index, value, N, N, N)
    index, _ = remove_self_loops(index, value)

    edge_index = torch.cat([edge_index, index], dim=1)
    if edge_attr is None:
        data.edge_index, _ = coalesce(edge_index, None, N, N)
    else:
        # Build the zeros from edge_attr itself so dtype and device match.
        # A float32 tensor here would make torch.cat fail on older PyTorch,
        # or silently promote integer attributes to float on newer PyTorch.
        two_hop_attr = edge_attr.new_zeros((index.size(1), *edge_attr.shape[1:]))
        edge_attr = torch.cat([edge_attr, two_hop_attr], dim=0)
        data.edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
        data.edge_attr = edge_attr

    return data
示例2: forward
# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def forward(self, A, H_=None):
    """Multiply paired sparse matrices produced by the layer's convolutions.

    If ``self.first`` is truthy, ``self.conv1(A)`` and ``self.conv2(A)`` give
    the left and right operands. Otherwise ``H_`` (the previous layer's
    output) is the left operand and ``self.conv1(A)`` the right one.

    Each operand is a sequence of ``(edge_index, edge_value)`` pairs. The
    i-th left pair is multiplied by the i-th right pair with
    ``torch_sparse.spspmm``, using ``self.num_nodes`` for all three matrix
    dimensions.

    Args:
        A: input passed to the conv modules (its structure is defined by
            them, not here).
        H_: list of ``(edge_index, edge_value)`` pairs from a previous layer.
            Required when ``self.first`` is falsy.

    Returns:
        ``(H, W)``. ``H`` is a list of ``(edges, values)`` products. ``W`` is
        a list of detached ``softmax(conv.weight, dim=1)`` tensors, one per
        conv used.

    Raises:
        TypeError: if ``self.first`` is falsy and ``H_`` is None.
        ValueError: if the left and right operand lists differ in length.
    """
    if self.first:
        result_A = self.conv1(A)
        result_B = self.conv2(A)
        convs = (self.conv1, self.conv2)
    else:
        if H_ is None:
            raise TypeError("H_ must be provided when self.first is False")
        result_A = H_
        result_B = self.conv1(A)
        convs = (self.conv1,)

    # Detached: these normalised weights are returned for inspection only.
    W = [F.softmax(conv.weight, dim=1).detach() for conv in convs]

    # Operands pair up one-to-one; reject a mismatch rather than let zip
    # truncate silently.
    if len(result_A) != len(result_B):
        raise ValueError(
            "operand lists differ in length: "
            f"{len(result_A)} vs {len(result_B)}"
        )

    n = self.num_nodes
    H = []
    for (a_edge, a_value), (b_edge, b_value) in zip(result_A, result_B):
        edges, values = torch_sparse.spspmm(a_edge, a_value, b_edge, b_value, n, n, n)
        H.append((edges, values))
    return H, W
示例3: augment_adj
# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def augment_adj(self, edge_index, edge_weight, num_nodes):
    """Return the squared adjacency (with self-loops added first), minus self-loops.

    Steps: add self-loops to the input edges, sort the edge list, multiply
    the resulting sparse matrix by itself with ``spspmm``, then strip the
    diagonal entries from the product.

    Args:
        edge_index: sparse COO indices of the input adjacency.
        edge_weight: values matching ``edge_index``.
        num_nodes: number of nodes, used for every matrix dimension.

    Returns:
        ``(edge_index, edge_weight)`` of the augmented adjacency.
    """
    n = num_nodes
    looped_idx, looped_w = add_self_loops(edge_index, edge_weight, num_nodes=n)
    looped_idx, looped_w = sort_edge_index(looped_idx, looped_w, n)

    # Square the self-loop-augmented matrix: (A + I) @ (A + I).
    squared_idx, squared_w = spspmm(
        looped_idx, looped_w,
        looped_idx, looped_w,
        n, n, n,
    )

    out_idx, out_w = remove_self_loops(squared_idx, squared_w)
    return out_idx, out_w
示例4: forward
# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def forward(self, phi_indices, phi_values, phi_inverse_indices, phi_inverse_values, features):
    """
    Apply the wavelet-domain filter to a feature matrix.

    Builds the sparse kernel phi @ F @ phi_inverse, where F is the sparse
    matrix given by self.diagonal_weight_indices and the flattened
    self.diagonal_weight_filter (presumably diagonal, per its name). The
    kernel is then applied to features @ self.weight_matrix. All sparse
    matrices are self.ncount x self.ncount.

    :param phi_indices: Sparse wavelet matrix index pairs.
    :param phi_values: Sparse wavelet matrix values.
    :param phi_inverse_indices: Inverse wavelet matrix index pairs.
    :param phi_inverse_values: Inverse wavelet matrix values.
    :param features: Feature matrix.
    :return localized_features: Filtered feature matrix extracted.
    """
    n = self.ncount

    # phi @ F
    scaled_idx, scaled_val = spspmm(
        phi_indices, phi_values,
        self.diagonal_weight_indices, self.diagonal_weight_filter.view(-1),
        n, n, n,
    )

    # (phi @ F) @ phi_inverse
    kernel_idx, kernel_val = spspmm(
        scaled_idx, scaled_val,
        phi_inverse_indices, phi_inverse_values,
        n, n, n,
    )

    projected = torch.mm(features, self.weight_matrix)
    return spmm(kernel_idx, kernel_val, n, n, projected)
示例5: StAS
# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def StAS(index_A, value_A, index_S, value_S, device, N, kN):
    r"""Return edge indices and weights of the pooled graph, :math:`S^{T} A S`.

    Args:
        index_A, value_A: sparse COO form of the adjacency ``A`` (``N x N``).
        index_S, value_S: sparse COO form of the assignment matrix ``S``
            (``N x kN``).
        device: unused. It is kept only so existing call sites stay
            compatible; the result stays wherever ``spspmm`` produces it.
        N: number of nodes in the input graph.
        kN: number of nodes in the pooled graph.

    Returns:
        ``(index_E, value_E)``: sparse COO form of ``S^T A S`` (``kN x kN``).
    """
    index_A, value_A = coalesce(index_A, value_A, m=N, n=N)
    index_S, value_S = coalesce(index_S, value_S, m=N, n=kN)

    # B = A @ S, shape (N x kN).
    index_B, value_B = spspmm(index_A, value_A, index_S, value_S, N, N, kN)
    index_B, value_B = coalesce(index_B, value_B, m=N, n=kN)

    # S^T, shape (kN x N).
    index_St, value_St = transpose(index_S, value_S, N, kN)

    # E = S^T @ B, shape (kN x kN).
    index_E, value_E = spspmm(index_St, value_St, index_B, value_B, kN, N, kN)
    return index_E, value_E
示例6: test_spspmm
# 需要导入模块: import torch_sparse [as 别名]
# 或者: from torch_sparse import spspmm [as 别名]
def test_spspmm(dtype, device):
    """Check spspmm on a small 3x3 @ 3x2 product against hand-computed output.

    A has entries (0,1)=1, (0,2)=2, (1,0)=3, (2,0)=4, (2,1)=5.
    B has entries (0,1)=2, (2,0)=4.
    So C = A @ B has entries (0,0)=2*4=8, (1,1)=3*2=6, (2,1)=4*2=8.
    """
    a_idx = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
    a_val = tensor([1, 2, 3, 4, 5], dtype, device)

    b_idx = torch.tensor([[0, 2], [1, 0]], device=device)
    b_val = tensor([2, 4], dtype, device)

    c_idx, c_val = spspmm(a_idx, a_val, b_idx, b_val, 3, 3, 2)

    expected_idx = [[0, 1, 2], [0, 1, 1]]
    expected_val = [8, 6, 8]
    assert c_idx.tolist() == expected_idx
    assert c_val.tolist() == expected_val