当前位置: 首页>>代码示例>>Python>>正文


Python nn.global_add_pool方法代码示例

本文整理汇总了Python中torch_geometric.nn.global_add_pool方法的典型用法代码示例。如果您正苦于以下问题:Python nn.global_add_pool方法的具体用法?Python nn.global_add_pool怎么用?Python nn.global_add_pool使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch_geometric.nn的用法示例。


在下文中一共展示了nn.global_add_pool方法的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: forward

# 需要导入模块: from torch_geometric import nn [as 别名]
# 或者: from torch_geometric.nn import global_add_pool [as 别名]
def forward(self, data):
        """Run the GNN on a batch and return one prediction per graph.

        Optionally drops edges, applies every layer in ``self.convs`` with a
        tanh non-linearity, concatenates the per-layer node states along
        dimension 1, sum-pools them per graph with ``global_add_pool`` and
        feeds the result through ``lin1`` and ``lin2``.

        Args:
            data: Batch object providing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Column 0 of the ``lin2`` output when ``self.regression`` is truthy,
            otherwise the log-softmax of the ``lin2`` output over the last
            dimension.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if self.adj_dropout > 0:
            # The convolutions below consume only ``edge_index``, so there is
            # no edge attribute to thread through the dropout call: pass None
            # and discard the returned attribute. (Passing a never-assigned
            # local ``edge_type`` here raised UnboundLocalError.)
            edge_index, _ = dropout_adj(
                edge_index, None, p=self.adj_dropout,
                force_undirected=self.force_undirected, num_nodes=len(x),
                training=self.training
            )
        layer_states = []
        for conv in self.convs:
            x = torch.tanh(conv(x, edge_index))
            layer_states.append(x)
        # Keep the output of every layer, not just the last one.
        concat_states = torch.cat(layer_states, 1)
        x = global_add_pool(concat_states, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        if self.regression:
            return x[:, 0]
        return F.log_softmax(x, dim=-1)
开发者ID:muhanzhang,项目名称:IGMC,代码行数:23,代码来源:models.py

示例2: forward

# 需要导入模块: from torch_geometric import nn [as 别名]
# 或者: from torch_geometric.nn import global_add_pool [as 别名]
def forward(self, data):
        """Predict one scalar per graph and compute an attention loss.

        Args:
            data: Batch object providing ``x``, ``edge_index``, ``batch`` and
                ``attn``.

        Returns:
            A tuple ``(out, attn_loss, ratio)`` where ``out`` is the flattened
            output of ``self.lin``, ``attn_loss`` is the ``scatter_mean`` (over
            the post-pooling ``batch``) of the element-wise KL divergence
            between ``log(score + 1e-14)`` and ``data.attn[perm]``, and
            ``ratio`` is the row count after ``pool1`` divided by the row count
            of the input ``x``.
        """
        x = data.x
        edge_index = data.edge_index
        batch = data.batch

        hidden = F.relu(self.conv1(x, edge_index))

        # pool1 receives the raw input features as its ``attn`` argument.
        pooled = self.pool1(hidden, edge_index, None, batch, attn=x)
        hidden, edge_index, _, batch, perm, score = pooled
        ratio = hidden.size(0) / x.size(0)

        hidden = F.relu(self.conv2(hidden, edge_index))
        graph_sum = global_add_pool(hidden, batch)
        prediction = self.lin(graph_sum).view(-1)

        # Small epsilon keeps the log finite when a score is exactly zero.
        log_score = torch.log(score + 1e-14)
        kl_per_node = F.kl_div(log_score, data.attn[perm], reduction='none')
        attn_loss = scatter_mean(kl_per_node, batch)

        return prediction, attn_loss, ratio
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:20,代码来源:colors_topk_pool.py

示例3: forward

# 需要导入模块: from torch_geometric import nn [as 别名]
# 或者: from torch_geometric.nn import global_add_pool [as 别名]
def forward(self, x, edge_index, batch):
        """Classify graphs using five conv + batch-norm stages and sum pooling.

        Args:
            x: Node feature tensor passed to ``conv1``.
            edge_index: Edge index passed to every convolution.
            batch: Per-node graph assignment used by ``global_add_pool``.

        Returns:
            Log-softmax (over the last dimension) of the ``fc2`` output.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        # Each stage is conv -> ReLU -> batch norm, applied in order.
        for conv, norm in stages:
            x = norm(F.relu(conv(x, edge_index)))

        pooled = global_add_pool(x, batch)
        hidden = F.relu(self.fc1(pooled))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=-1)
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:18,代码来源:mutag_gin.py

示例4: forward

# 需要导入模块: from torch_geometric import nn [as 别名]
# 或者: from torch_geometric.nn import global_add_pool [as 别名]
def forward(self, data):
        """Sum-pool the raw node features per graph and apply ``self.mlp``.

        Args:
            data: Batch object providing ``x`` and ``batch``.

        Returns:
            The output of ``self.mlp`` on the pooled features.
        """
        graph_features = global_add_pool(data.x, data.batch)
        return self.mlp(graph_features)
开发者ID:diningphil,项目名称:gnn-comparison,代码行数:4,代码来源:MolecularFingerprint.py

示例5: forward

# 需要导入模块: from torch_geometric import nn [as 别名]
# 或者: from torch_geometric.nn import global_add_pool [as 别名]
def forward(self, data):
        """Embed vertices, sum them per graph, then apply two global layers.

        Args:
            data: Batch object providing ``x`` and ``batch``.

        Returns:
            The output of ``fc_global2`` (no activation applied to it).
        """
        vertex_emb = F.relu(self.fc_vertex(data.x))
        # Sums all vertex embeddings belonging to the same graph.
        graph_emb = global_add_pool(vertex_emb, data.batch)
        hidden = F.relu(self.fc_global1(graph_emb))
        return self.fc_global2(hidden)
开发者ID:diningphil,项目名称:gnn-comparison,代码行数:10,代码来源:DeepMultisets.py

示例6: forward

# 需要导入模块: from torch_geometric import nn [as 别名]
# 或者: from torch_geometric.nn import global_add_pool [as 别名]
def forward(self, x, edge_index, batch):
        """Classify graphs with stacked conv/batch-norm layers and sum pooling.

        Args:
            x: Node feature tensor passed to the first convolution.
            edge_index: Edge index passed to every convolution.
            batch: Per-node graph assignment used by ``global_add_pool``.

        Returns:
            Log-softmax (over the last dimension) of the ``lin2`` output.
        """
        # Layers are paired positionally: conv -> batch norm -> ReLU.
        for conv, batch_norm in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index)
            x = batch_norm(x)
            x = F.relu(x)

        graph_repr = global_add_pool(x, batch)
        hidden = F.relu(self.batch_norm1(self.lin1(graph_repr)))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        logits = self.lin2(hidden)
        return F.log_softmax(logits, dim=-1)
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:10,代码来源:gin.py

示例7: forward

# 需要导入模块: from torch_geometric import nn [as 别名]
# 或者: from torch_geometric.nn import global_add_pool [as 别名]
def forward(self, x, edge_index, edge_attr, batch):
        """Embed nodes and edges, run the conv stack, pool, and apply the MLP.

        Args:
            x: Node input passed (after ``squeeze()``) to ``self.node_emb``.
            edge_index: Edge index passed to every convolution.
            edge_attr: Edge input passed to ``self.edge_emb``.
            batch: Per-node graph assignment used by ``global_add_pool``.

        Returns:
            The output of ``self.mlp`` on the sum-pooled node states.
        """
        # NOTE(review): squeeze() removes every size-1 dimension, so a batch
        # containing a single node would also lose its node dimension --
        # confirm whether callers can produce that case.
        node_state = self.node_emb(x.squeeze())
        edge_state = self.edge_emb(edge_attr)

        # Layers are paired positionally: conv -> batch norm -> ReLU.
        for conv, batch_norm in zip(self.convs, self.batch_norms):
            node_state = conv(node_state, edge_index, edge_state)
            node_state = F.relu(batch_norm(node_state))

        graph_state = global_add_pool(node_state, batch)
        return self.mlp(graph_state)
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:11,代码来源:pna.py

示例8: forward

# 需要导入模块: from torch_geometric import nn [as 别名]
# 或者: from torch_geometric.nn import global_add_pool [as 别名]
def forward(self, batched_data):
        """Compute node representations using message passing and a virtual node.

        Every layer adds the current per-graph virtual-node embedding to the
        node states, applies ``convs[layer]`` and ``batch_norms[layer]``, then
        ReLU (skipped on the last layer) and dropout, plus an optional residual
        connection. Between layers the virtual-node embedding is updated from
        the sum-pooled node states through ``mlp_virtualnode_list[layer]``.

        Args:
            batched_data: Batch object providing ``x``, ``edge_index``,
                ``edge_attr`` and ``batch``.

        Returns:
            For ``self.JK == "last"`` the output of the final layer; for
            ``self.JK == "sum"`` the sum of every entry of ``h_list`` (the
            encoder output plus the output of each layer).

        Raises:
            ValueError: If ``self.JK`` is neither ``"last"`` nor ``"sum"``.
        """
        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr
        batch = batched_data.batch

        ### virtual node embeddings for graphs
        # The graph count is read from the last entry of ``batch``
        # (assumes ``batch`` is non-empty and sorted ascending -- TODO confirm).
        # Every graph looks up index 0 of the virtual-node embedding table.
        num_graphs = batch[-1].item() + 1
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device))

        h_list = [self.atom_encoder(x)]
        last_layer = self.num_layer - 1
        for layer in range(self.num_layer):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            ### Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            # No ReLU on the last layer; dropout is applied on every layer.
            if layer != last_layer:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < last_layer:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
                ### transform virtual nodes using MLP
                update = F.dropout(
                    self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                    self.drop_ratio, training=self.training)
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + update
                else:
                    virtualnode_embedding = update

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            # h_list holds num_layer + 1 entries; iterate over all of them so
            # the final layer's output is part of the sum (looping over
            # range(num_layer) silently dropped it).
            node_representation = 0
            for h in h_list:
                node_representation = node_representation + h
        else:
            raise ValueError(
                "Unsupported JK mode {!r}; expected 'last' or 'sum'".format(self.JK))

        return node_representation
开发者ID:snap-stanford,项目名称:ogb,代码行数:49,代码来源:conv.py

示例9: forward

# 需要导入模块: from torch_geometric import nn [as 别名]
# 或者: from torch_geometric.nn import global_add_pool [as 别名]
def forward(self, batched_data):
        """Compute node representations using message passing and a virtual node.

        The initial node states come from ``self.node_encoder``, which receives
        ``x`` and the flattened ``node_depth``. Every layer then adds the
        current per-graph virtual-node embedding to the node states, applies
        ``convs[layer]`` and ``batch_norms[layer]``, then ReLU (skipped on the
        last layer) and dropout, plus an optional residual connection. Between
        layers the virtual-node embedding is updated from the sum-pooled node
        states through ``mlp_virtualnode_list[layer]``.

        Args:
            batched_data: Batch object providing ``x``, ``edge_index``,
                ``edge_attr``, ``node_depth`` and ``batch``.

        Returns:
            For ``self.JK == "last"`` the output of the final layer; for
            ``self.JK == "sum"`` the sum of every entry of ``h_list`` (the
            encoder output plus the output of each layer).

        Raises:
            ValueError: If ``self.JK`` is neither ``"last"`` nor ``"sum"``.
        """
        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr
        node_depth = batched_data.node_depth
        batch = batched_data.batch

        ### virtual node embeddings for graphs
        # The graph count is read from the last entry of ``batch``
        # (assumes ``batch`` is non-empty and sorted ascending -- TODO confirm).
        # Every graph looks up index 0 of the virtual-node embedding table.
        num_graphs = batch[-1].item() + 1
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device))

        h_list = [self.node_encoder(x, node_depth.view(-1))]
        last_layer = self.num_layer - 1
        for layer in range(self.num_layer):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            ### Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            # No ReLU on the last layer; dropout is applied on every layer.
            if layer != last_layer:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < last_layer:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
                ### transform virtual nodes using MLP
                update = F.dropout(
                    self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                    self.drop_ratio, training=self.training)
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + update
                else:
                    virtualnode_embedding = update

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            # h_list holds num_layer + 1 entries; iterate over all of them so
            # the final layer's output is part of the sum (looping over
            # range(num_layer) silently dropped it).
            node_representation = 0
            for h in h_list:
                node_representation = node_representation + h
        else:
            raise ValueError(
                "Unsupported JK mode {!r}; expected 'last' or 'sum'".format(self.JK))

        return node_representation
开发者ID:snap-stanford,项目名称:ogb,代码行数:49,代码来源:conv.py

示例10: forward

# 需要导入模块: from torch_geometric import nn [as 别名]
# 或者: from torch_geometric.nn import global_add_pool [as 别名]
def forward(self, batched_data):
        """Compute node representations using message passing and a virtual node.

        The initial node states come from ``self.node_encoder(x)``. Every layer
        then adds the current per-graph virtual-node embedding to the node
        states, applies ``convs[layer]`` and ``batch_norms[layer]``, then ReLU
        (skipped on the last layer) and dropout, plus an optional residual
        connection. Between layers the virtual-node embedding is updated from
        the sum-pooled node states through ``mlp_virtualnode_list[layer]``.

        Args:
            batched_data: Batch object providing ``x``, ``edge_index``,
                ``edge_attr`` and ``batch``.

        Returns:
            For ``self.JK == "last"`` the output of the final layer; for
            ``self.JK == "sum"`` the sum of every entry of ``h_list`` (the
            encoder output plus the output of each layer).

        Raises:
            ValueError: If ``self.JK`` is neither ``"last"`` nor ``"sum"``.
        """
        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr
        batch = batched_data.batch

        ### virtual node embeddings for graphs
        # The graph count is read from the last entry of ``batch``
        # (assumes ``batch`` is non-empty and sorted ascending -- TODO confirm).
        # Every graph looks up index 0 of the virtual-node embedding table.
        num_graphs = batch[-1].item() + 1
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device))

        h_list = [self.node_encoder(x)]
        last_layer = self.num_layer - 1
        for layer in range(self.num_layer):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            ### Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            # No ReLU on the last layer; dropout is applied on every layer.
            if layer != last_layer:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < last_layer:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
                ### transform virtual nodes using MLP
                update = F.dropout(
                    self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                    self.drop_ratio, training=self.training)
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + update
                else:
                    virtualnode_embedding = update

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            # h_list holds num_layer + 1 entries; iterate over all of them so
            # the final layer's output is part of the sum (looping over
            # range(num_layer) silently dropped it).
            node_representation = 0
            for h in h_list:
                node_representation = node_representation + h
        else:
            raise ValueError(
                "Unsupported JK mode {!r}; expected 'last' or 'sum'".format(self.JK))

        return node_representation
开发者ID:snap-stanford,项目名称:ogb,代码行数:49,代码来源:conv.py


注:本文中的torch_geometric.nn.global_add_pool方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。