This page collects typical usage examples of the Python method torch_geometric.nn.global_add_pool. If you are wondering what nn.global_add_pool does, how to call it, or what it looks like in real code, the curated examples below should help. You can also explore the containing module, torch_geometric.nn, for further usage examples.
Ten code examples of nn.global_add_pool are shown below, ordered by popularity.
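Before diving into the examples, here is a minimal sketch of what global_add_pool computes: it sums node features per graph, using the batch vector to map each node to its graph. The tensor values below are made up purely for illustration.

import torch
from torch_geometric.nn import global_add_pool

# Two graphs in one mini-batch: nodes 0-1 belong to graph 0, node 2 to graph 1.
x = torch.tensor([[1., 2.],
                  [3., 4.],
                  [5., 6.]])
batch = torch.tensor([0, 0, 1])

out = global_add_pool(x, batch)
print(out)  # tensor([[4., 6.], [5., 6.]]) -- one summed row per graph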
Example 1: forward
# Required module import: from torch_geometric import nn [as alias]
# Alternatively: from torch_geometric.nn import global_add_pool [as alias]
def forward(self, data):
    # `edge_type` must come from the Data object; in the original listing it
    # was used before being defined.
    x, edge_index, edge_type, batch = (
        data.x, data.edge_index, data.edge_type, data.batch)
    if self.adj_dropout > 0:
        edge_index, edge_type = dropout_adj(
            edge_index, edge_type, p=self.adj_dropout,
            force_undirected=self.force_undirected, num_nodes=len(x),
            training=self.training
        )
    concat_states = []
    for conv in self.convs:
        x = torch.tanh(conv(x, edge_index))
        concat_states.append(x)
    concat_states = torch.cat(concat_states, 1)
    x = global_add_pool(concat_states, batch)
    x = F.relu(self.lin1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = self.lin2(x)
    if self.regression:
        return x[:, 0]
    else:
        return F.log_softmax(x, dim=-1)
Example 2: forward
# Required module import: from torch_geometric import nn [as alias]
# Alternatively: from torch_geometric.nn import global_add_pool [as alias]
def forward(self, data):
    x, edge_index, batch = data.x, data.edge_index, data.batch
    out = F.relu(self.conv1(x, edge_index))
    out, edge_index, _, batch, perm, score = self.pool1(
        out, edge_index, None, batch, attn=x)
    ratio = out.size(0) / x.size(0)
    out = F.relu(self.conv2(out, edge_index))
    out = global_add_pool(out, batch)
    out = self.lin(out).view(-1)
    # Attention supervision: KL divergence between the pooling scores and the
    # ground-truth attention, averaged per graph.
    # Requires: from torch_scatter import scatter_mean
    attn_loss = F.kl_div(torch.log(score + 1e-14), data.attn[perm],
                         reduction='none')
    attn_loss = scatter_mean(attn_loss, batch)
    return out, attn_loss, ratio
Example 3: forward
# Required module import: from torch_geometric import nn [as alias]
# Alternatively: from torch_geometric.nn import global_add_pool [as alias]
def forward(self, x, edge_index, batch):
    x = F.relu(self.conv1(x, edge_index))
    x = self.bn1(x)
    x = F.relu(self.conv2(x, edge_index))
    x = self.bn2(x)
    x = F.relu(self.conv3(x, edge_index))
    x = self.bn3(x)
    x = F.relu(self.conv4(x, edge_index))
    x = self.bn4(x)
    x = F.relu(self.conv5(x, edge_index))
    x = self.bn5(x)
    x = global_add_pool(x, batch)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=-1)
Example 4: forward
# Required module import: from torch_geometric import nn [as alias]
# Alternatively: from torch_geometric.nn import global_add_pool [as alias]
def forward(self, data):
    return self.mlp(global_add_pool(data.x, data.batch))
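Example 4 shows only the forward pass; the enclosing module is not reproduced on this page. Below is a minimal sketch of what such a module could look like. The class name and layer sizes are hypothetical; only the self.mlp attribute is assumed by the example itself.

import torch
from torch_geometric.nn import global_add_pool

class SumPoolMLP(torch.nn.Module):  # hypothetical name, not from the source repo
    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, num_classes),
        )

    def forward(self, data):
        # Sum-pool node features per graph, then classify the graph vector.
        return self.mlp(global_add_pool(data.x, data.batch))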
Example 5: forward
# Required module import: from torch_geometric import nn [as alias]
# Alternatively: from torch_geometric.nn import global_add_pool [as alias]
def forward(self, data):
    x, batch = data.x, data.batch
    x = F.relu(self.fc_vertex(x))
    x = global_add_pool(x, batch)  # sums all vertex embeddings belonging to the same graph
    x = F.relu(self.fc_global1(x))
    x = self.fc_global2(x)
    return x
Example 6: forward
# Required module import: from torch_geometric import nn [as alias]
# Alternatively: from torch_geometric.nn import global_add_pool [as alias]
def forward(self, x, edge_index, batch):
    for conv, batch_norm in zip(self.convs, self.batch_norms):
        x = F.relu(batch_norm(conv(x, edge_index)))
    x = global_add_pool(x, batch)
    x = F.relu(self.batch_norm1(self.lin1(x)))
    x = F.dropout(x, p=0.5, training=self.training)
    x = self.lin2(x)
    return F.log_softmax(x, dim=-1)
Example 7: forward
# Required module import: from torch_geometric import nn [as alias]
# Alternatively: from torch_geometric.nn import global_add_pool [as alias]
def forward(self, x, edge_index, edge_attr, batch):
    x = self.node_emb(x.squeeze())
    edge_attr = self.edge_emb(edge_attr)
    for conv, batch_norm in zip(self.convs, self.batch_norms):
        x = F.relu(batch_norm(conv(x, edge_index, edge_attr)))
    x = global_add_pool(x, batch)
    return self.mlp(x)
Example 8: forward
# Required module import: from torch_geometric import nn [as alias]
# Alternatively: from torch_geometric.nn import global_add_pool [as alias]
def forward(self, batched_data):
    x, edge_index, edge_attr, batch = (
        batched_data.x, batched_data.edge_index,
        batched_data.edge_attr, batched_data.batch)
    ### virtual node embeddings for graphs: a per-graph vector of zeros, cast
    ### to edge_index's integer dtype, indexes the single embedding row
    virtualnode_embedding = self.virtualnode_embedding(
        torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))
    h_list = [self.atom_encoder(x)]
    for layer in range(self.num_layer):
        ### add message from virtual nodes to graph nodes
        h_list[layer] = h_list[layer] + virtualnode_embedding[batch]
        ### message passing among graph nodes
        h = self.convs[layer](h_list[layer], edge_index, edge_attr)
        h = self.batch_norms[layer](h)
        if layer == self.num_layer - 1:
            # no ReLU for the last layer
            h = F.dropout(h, self.drop_ratio, training=self.training)
        else:
            h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
        if self.residual:
            h = h + h_list[layer]
        h_list.append(h)
        ### update the virtual nodes
        if layer < self.num_layer - 1:
            ### add message from graph nodes to virtual nodes
            virtualnode_embedding_temp = (
                global_add_pool(h_list[layer], batch) + virtualnode_embedding)
            ### transform virtual nodes using MLP
            if self.residual:
                virtualnode_embedding = virtualnode_embedding + F.dropout(
                    self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                    self.drop_ratio, training=self.training)
            else:
                virtualnode_embedding = F.dropout(
                    self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                    self.drop_ratio, training=self.training)
    ### different implementations of JK-concat
    if self.JK == "last":
        node_representation = h_list[-1]
    elif self.JK == "sum":
        node_representation = 0
        for layer in range(self.num_layer):
            node_representation += h_list[layer]
    return node_representation
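Examples 8 through 10 all rely on virtual-node machinery set up in the constructor, which is not shown on this page. The following is a sketch of the relevant attributes, modeled on the OGB reference implementation; the class name is hypothetical, and the exact MLP depth and encoder choices may differ in the original repositories.

import torch

class VirtualNodeGNN(torch.nn.Module):  # hypothetical wrapper, for illustration only
    def __init__(self, num_layer, emb_dim):
        super().__init__()
        self.num_layer = num_layer
        # A single embedding row, initialized to zero; forward() indexes it
        # with a per-graph vector of zeros so every graph starts from the same
        # learnable virtual-node state.
        self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
        # One MLP per layer except the last, used to update the virtual node.
        self.mlp_virtualnode_list = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Linear(emb_dim, emb_dim),
                torch.nn.BatchNorm1d(emb_dim),
                torch.nn.ReLU(),
            )
            for _ in range(num_layer - 1)
        ])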
Example 9: forward
# Required module import: from torch_geometric import nn [as alias]
# Alternatively: from torch_geometric.nn import global_add_pool [as alias]
def forward(self, batched_data):
    x, edge_index, edge_attr, node_depth, batch = (
        batched_data.x, batched_data.edge_index, batched_data.edge_attr,
        batched_data.node_depth, batched_data.batch)
    ### virtual node embeddings for graphs
    virtualnode_embedding = self.virtualnode_embedding(
        torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))
    h_list = [self.node_encoder(x, node_depth.view(-1,))]
    for layer in range(self.num_layer):
        ### add message from virtual nodes to graph nodes
        h_list[layer] = h_list[layer] + virtualnode_embedding[batch]
        ### message passing among graph nodes
        h = self.convs[layer](h_list[layer], edge_index, edge_attr)
        h = self.batch_norms[layer](h)
        if layer == self.num_layer - 1:
            # no ReLU for the last layer
            h = F.dropout(h, self.drop_ratio, training=self.training)
        else:
            h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
        if self.residual:
            h = h + h_list[layer]
        h_list.append(h)
        ### update the virtual nodes
        if layer < self.num_layer - 1:
            ### add message from graph nodes to virtual nodes
            virtualnode_embedding_temp = (
                global_add_pool(h_list[layer], batch) + virtualnode_embedding)
            ### transform virtual nodes using MLP
            if self.residual:
                virtualnode_embedding = virtualnode_embedding + F.dropout(
                    self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                    self.drop_ratio, training=self.training)
            else:
                virtualnode_embedding = F.dropout(
                    self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                    self.drop_ratio, training=self.training)
    ### different implementations of JK-concat
    if self.JK == "last":
        node_representation = h_list[-1]
    elif self.JK == "sum":
        node_representation = 0
        for layer in range(self.num_layer):
            node_representation += h_list[layer]
    return node_representation
Example 10: forward
# Required module import: from torch_geometric import nn [as alias]
# Alternatively: from torch_geometric.nn import global_add_pool [as alias]
def forward(self, batched_data):
    x, edge_index, edge_attr, batch = (
        batched_data.x, batched_data.edge_index,
        batched_data.edge_attr, batched_data.batch)
    ### virtual node embeddings for graphs
    virtualnode_embedding = self.virtualnode_embedding(
        torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))
    h_list = [self.node_encoder(x)]
    for layer in range(self.num_layer):
        ### add message from virtual nodes to graph nodes
        h_list[layer] = h_list[layer] + virtualnode_embedding[batch]
        ### message passing among graph nodes
        h = self.convs[layer](h_list[layer], edge_index, edge_attr)
        h = self.batch_norms[layer](h)
        if layer == self.num_layer - 1:
            # no ReLU for the last layer
            h = F.dropout(h, self.drop_ratio, training=self.training)
        else:
            h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
        if self.residual:
            h = h + h_list[layer]
        h_list.append(h)
        ### update the virtual nodes
        if layer < self.num_layer - 1:
            ### add message from graph nodes to virtual nodes
            virtualnode_embedding_temp = (
                global_add_pool(h_list[layer], batch) + virtualnode_embedding)
            ### transform virtual nodes using MLP
            if self.residual:
                virtualnode_embedding = virtualnode_embedding + F.dropout(
                    self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                    self.drop_ratio, training=self.training)
            else:
                virtualnode_embedding = F.dropout(
                    self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                    self.drop_ratio, training=self.training)
    ### different implementations of JK-concat
    if self.JK == "last":
        node_representation = h_list[-1]
    elif self.JK == "sum":
        node_representation = 0
        for layer in range(self.num_layer):
            node_representation += h_list[layer]
    return node_representation
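Finally, every forward method above expects mini-batched graphs with a batch vector. Here is a sketch of a typical training loop, using a small stand-in model and the MUTAG dataset as placeholders; any of the graph-level models above would slot in the same way. Note that in older PyG versions DataLoader lives in torch_geometric.data rather than torch_geometric.loader.

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_add_pool

# A small stand-in model, for illustration only.
class TinyGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.conv = GCNConv(in_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        x = F.relu(self.conv(data.x, data.edge_index))
        x = global_add_pool(x, data.batch)  # graph-level readout, as in the examples
        return F.log_softmax(self.lin(x), dim=-1)

dataset = TUDataset(root='data/TUDataset', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = TinyGNN(dataset.num_features, 64, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for data in loader:
    optimizer.zero_grad()
    loss = F.nll_loss(model(data), data.y)  # matches log_softmax outputs
    loss.backward()
    optimizer.step()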