當前位置: 首頁>>代碼示例>>Python>>正文


Python nn.global_add_pool方法代碼示例

本文整理匯總了Python中torch_geometric.nn.global_add_pool方法的典型用法代碼示例。如果您正苦於以下問題:Python nn.global_add_pool方法的具體用法?Python nn.global_add_pool怎麼用?Python nn.global_add_pool使用的例子?那麼,這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch_geometric.nn的用法示例。


在下文中一共展示了nn.global_add_pool方法的10個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: forward

# 需要導入模塊: from torch_geometric import nn [as 別名]
# 或者: from torch_geometric.nn import global_add_pool [as 別名]
def forward(self, data):
    """Run the graph-level forward pass on a batched graph object.

    Reads ``x``, ``edge_index`` and ``batch`` from ``data``, optionally drops
    edges, applies every layer in ``self.convs`` followed by ``tanh``,
    concatenates the per-layer node states along dim 1, sums them per graph
    with ``global_add_pool`` and feeds the result through
    ``lin1`` -> ReLU -> dropout(p=0.5) -> ``lin2``.

    Args:
        data: object exposing ``x``, ``edge_index`` and ``batch`` attributes.

    Returns:
        ``x[:, 0]`` of the ``lin2`` output when ``self.regression`` is truthy,
        otherwise its log-softmax over the last dimension.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch
    if self.adj_dropout > 0:
        # The convolutions below only consume (x, edge_index), so no edge
        # attribute is passed in; dropout_adj still returns an
        # (edge_index, edge_attr) pair, whose second item is discarded.
        edge_index, _ = dropout_adj(
            edge_index,
            p=self.adj_dropout,
            force_undirected=self.force_undirected,
            num_nodes=len(x),
            training=self.training,
        )
    concat_states = []
    for conv in self.convs:
        x = torch.tanh(conv(x, edge_index))
        concat_states.append(x)
    # Keep the output of every layer, not just the last one.
    concat_states = torch.cat(concat_states, 1)
    x = global_add_pool(concat_states, batch)
    x = F.relu(self.lin1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = self.lin2(x)
    if self.regression:
        return x[:, 0]
    return F.log_softmax(x, dim=-1)
開發者ID:muhanzhang,項目名稱:IGMC,代碼行數:23,代碼來源:models.py

示例2: forward

# 需要導入模塊: from torch_geometric import nn [as 別名]
# 或者: from torch_geometric.nn import global_add_pool [as 別名]
def forward(self, data):
    """Predict one value per graph and return the attention loss and keep ratio.

    Args:
        data: object exposing ``x``, ``edge_index``, ``batch`` and ``attn``.

    Returns:
        A tuple ``(out, attn_loss, ratio)`` where ``out`` is the flattened
        output of ``self.lin``, ``attn_loss`` is the element-wise KL divergence
        between ``log(score + 1e-14)`` and ``data.attn[perm]`` reduced by
        ``scatter_mean`` with the pooled batch vector, and ``ratio`` is the
        number of rows after ``self.pool1`` divided by the number before.
    """
    feats = data.x
    edges = data.edge_index
    graph_ids = data.batch

    hidden = F.relu(self.conv1(feats, edges))

    # pool1 returns six values; the third one is not used here.
    pooled = self.pool1(hidden, edges, None, graph_ids, attn=feats)
    hidden, edges, _, graph_ids, perm, score = pooled
    ratio = hidden.size(0) / feats.size(0)

    hidden = self.conv2(hidden, edges)
    hidden = F.relu(hidden)
    hidden = global_add_pool(hidden, graph_ids)
    out = self.lin(hidden).view(-1)

    # The small epsilon keeps the log finite when a score is exactly zero.
    log_score = torch.log(score + 1e-14)
    target = data.attn[perm]
    per_node_loss = F.kl_div(log_score, target, reduction='none')
    attn_loss = scatter_mean(per_node_loss, graph_ids)

    return out, attn_loss, ratio
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:20,代碼來源:colors_topk_pool.py

示例3: forward

# 需要導入模塊: from torch_geometric import nn [as 別名]
# 或者: from torch_geometric.nn import global_add_pool [as 別名]
def forward(self, x, edge_index, batch):
    """Apply five conv -> ReLU -> batch-norm stages, pool and classify.

    Args:
        x: node feature input passed to ``self.conv1``.
        edge_index: connectivity passed to every convolution.
        batch: per-node index handed to ``global_add_pool``.

    Returns:
        Log-softmax over the last dimension of the ``self.fc2`` output.
    """
    stages = (
        (self.conv1, self.bn1),
        (self.conv2, self.bn2),
        (self.conv3, self.bn3),
        (self.conv4, self.bn4),
        (self.conv5, self.bn5),
    )
    for conv, norm in stages:
        x = norm(F.relu(conv(x, edge_index)))

    pooled = global_add_pool(x, batch)
    hidden = F.relu(self.fc1(pooled))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    logits = self.fc2(hidden)
    return F.log_softmax(logits, dim=-1)
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:18,代碼來源:mutag_gin.py

示例4: forward

# 需要導入模塊: from torch_geometric import nn [as 別名]
# 或者: from torch_geometric.nn import global_add_pool [as 別名]
def forward(self, data):
    """Sum ``data.x`` per graph (grouped by ``data.batch``) and apply ``self.mlp``."""
    graph_sums = global_add_pool(data.x, data.batch)
    return self.mlp(graph_sums)
開發者ID:diningphil,項目名稱:gnn-comparison,代碼行數:4,代碼來源:MolecularFingerprint.py

示例5: forward

# 需要導入模塊: from torch_geometric import nn [as 別名]
# 或者: from torch_geometric.nn import global_add_pool [as 別名]
def forward(self, data):
    """Embed each vertex, sum the embeddings per graph, then apply two layers.

    Args:
        data: object exposing ``x`` and ``batch``.

    Returns:
        Output of ``self.fc_global2`` (no activation applied afterwards).
    """
    vertex_emb = F.relu(self.fc_vertex(data.x))
    # Sum all vertex embeddings that belong to the same graph.
    graph_emb = global_add_pool(vertex_emb, data.batch)
    hidden = F.relu(self.fc_global1(graph_emb))
    return self.fc_global2(hidden)
開發者ID:diningphil,項目名稱:gnn-comparison,代碼行數:10,代碼來源:DeepMultisets.py

示例6: forward

# 需要導入模塊: from torch_geometric import nn [as 別名]
# 或者: from torch_geometric.nn import global_add_pool [as 別名]
def forward(self, x, edge_index, batch):
    """Run the conv/batch-norm stack, pool per graph and classify.

    Args:
        x: node feature input for the first convolution.
        edge_index: connectivity passed to every convolution.
        batch: per-node index handed to ``global_add_pool``.

    Returns:
        Log-softmax over the last dimension of the ``self.lin2`` output.
    """
    for conv, norm in zip(self.convs, self.batch_norms):
        x = conv(x, edge_index)
        x = norm(x)
        x = F.relu(x)

    pooled = global_add_pool(x, batch)
    hidden = self.batch_norm1(self.lin1(pooled))
    hidden = F.dropout(F.relu(hidden), p=0.5, training=self.training)
    logits = self.lin2(hidden)
    return F.log_softmax(logits, dim=-1)
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:10,代碼來源:gin.py

示例7: forward

# 需要導入模塊: from torch_geometric import nn [as 別名]
# 或者: from torch_geometric.nn import global_add_pool [as 別名]
def forward(self, x, edge_index, edge_attr, batch):
    """Embed nodes and edges, run the conv stack, pool per graph, apply the MLP.

    Args:
        x: node input; squeezed before being passed to ``self.node_emb``.
        edge_index: connectivity passed to every convolution.
        edge_attr: edge input passed to ``self.edge_emb``.
        batch: per-node index handed to ``global_add_pool``.

    Returns:
        Output of ``self.mlp`` applied to the pooled node states.
    """
    node_state = self.node_emb(x.squeeze())
    edge_state = self.edge_emb(edge_attr)

    for conv, norm in zip(self.convs, self.batch_norms):
        node_state = conv(node_state, edge_index, edge_state)
        node_state = F.relu(norm(node_state))

    graph_state = global_add_pool(node_state, batch)
    return self.mlp(graph_state)
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:11,代碼來源:pna.py

示例8: forward

# 需要導入模塊: from torch_geometric import nn [as 別名]
# 或者: from torch_geometric.nn import global_add_pool [as 別名]
def forward(self, batched_data):
    """Compute node representations using message passing with virtual nodes.

    One virtual-node embedding per graph is added to the node states before
    every layer. After each layer except the last, the virtual-node
    embedding is updated from the per-graph sum of the layer's input states.

    Args:
        batched_data: object exposing ``x``, ``edge_index``, ``edge_attr`` and
            ``batch``. The number of graphs is derived from ``batch[-1]``, so
            ``batch`` presumably holds sorted graph ids -- TODO confirm.

    Returns:
        ``h_list[-1]`` when ``self.JK == "last"``; the sum of
        ``h_list[0] .. h_list[num_layer - 1]`` when ``self.JK == "sum"``.

    Raises:
        ValueError: if ``self.JK`` is neither ``"last"`` nor ``"sum"``.
    """
    x = batched_data.x
    edge_index = batched_data.edge_index
    edge_attr = batched_data.edge_attr
    batch = batched_data.batch

    ### virtual node embeddings for graphs
    # Every graph looks up index 0; build the index tensor directly with the
    # dtype/device of edge_index instead of converting it afterwards.
    num_graphs = batch[-1].item() + 1
    virtualnode_embedding = self.virtualnode_embedding(
        torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device)
    )

    last_layer = self.num_layer - 1
    h_list = [self.atom_encoder(x)]
    for layer in range(self.num_layer):
        ### add message from virtual nodes to graph nodes
        h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

        ### Message passing among graph nodes
        h = self.convs[layer](h_list[layer], edge_index, edge_attr)
        h = self.batch_norms[layer](h)

        # No ReLU on the last layer.
        if layer != last_layer:
            h = F.relu(h)
        h = F.dropout(h, self.drop_ratio, training=self.training)

        if self.residual:
            h = h + h_list[layer]

        h_list.append(h)

        ### update the virtual nodes
        if layer < last_layer:
            ### add message from graph nodes to virtual nodes
            virtualnode_embedding_temp = (
                global_add_pool(h_list[layer], batch) + virtualnode_embedding
            )
            ### transform virtual nodes using MLP
            update = F.dropout(
                self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                self.drop_ratio,
                training=self.training,
            )
            if self.residual:
                virtualnode_embedding = virtualnode_embedding + update
            else:
                virtualnode_embedding = update

    ### Different implementations of Jk-concat
    if self.JK == "last":
        return h_list[-1]
    if self.JK == "sum":
        # NOTE(review): h_list holds num_layer + 1 entries, so this sum leaves
        # out the final layer's output h_list[-1] -- confirm this is intended.
        node_representation = 0
        for layer in range(self.num_layer):
            node_representation += h_list[layer]
        return node_representation

    # Fail explicitly instead of hitting an unbound local variable.
    raise ValueError(
        "Unsupported JK mode {!r}; expected 'last' or 'sum'".format(self.JK)
    )
開發者ID:snap-stanford,項目名稱:ogb,代碼行數:49,代碼來源:conv.py

示例9: forward

# 需要導入模塊: from torch_geometric import nn [as 別名]
# 或者: from torch_geometric.nn import global_add_pool [as 別名]
def forward(self, batched_data):
    """Compute node representations using message passing with virtual nodes.

    The initial node states come from ``self.node_encoder(x, node_depth)``
    with ``node_depth`` flattened to one dimension. One virtual-node embedding
    per graph is added to the node states before every layer. After each
    layer except the last, the virtual-node embedding is updated from the
    per-graph sum of the layer's input states.

    Args:
        batched_data: object exposing ``x``, ``edge_index``, ``edge_attr``,
            ``node_depth`` and ``batch``. The number of graphs is derived from
            ``batch[-1]``, so ``batch`` presumably holds sorted graph ids --
            TODO confirm.

    Returns:
        ``h_list[-1]`` when ``self.JK == "last"``; the sum of
        ``h_list[0] .. h_list[num_layer - 1]`` when ``self.JK == "sum"``.

    Raises:
        ValueError: if ``self.JK`` is neither ``"last"`` nor ``"sum"``.
    """
    x = batched_data.x
    edge_index = batched_data.edge_index
    edge_attr = batched_data.edge_attr
    node_depth = batched_data.node_depth
    batch = batched_data.batch

    ### virtual node embeddings for graphs
    # Every graph looks up index 0; build the index tensor directly with the
    # dtype/device of edge_index instead of converting it afterwards.
    num_graphs = batch[-1].item() + 1
    virtualnode_embedding = self.virtualnode_embedding(
        torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device)
    )

    last_layer = self.num_layer - 1
    h_list = [self.node_encoder(x, node_depth.view(-1,))]
    for layer in range(self.num_layer):
        ### add message from virtual nodes to graph nodes
        h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

        ### Message passing among graph nodes
        h = self.convs[layer](h_list[layer], edge_index, edge_attr)
        h = self.batch_norms[layer](h)

        # No ReLU on the last layer.
        if layer != last_layer:
            h = F.relu(h)
        h = F.dropout(h, self.drop_ratio, training=self.training)

        if self.residual:
            h = h + h_list[layer]

        h_list.append(h)

        ### update the virtual nodes
        if layer < last_layer:
            ### add message from graph nodes to virtual nodes
            virtualnode_embedding_temp = (
                global_add_pool(h_list[layer], batch) + virtualnode_embedding
            )
            ### transform virtual nodes using MLP
            update = F.dropout(
                self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                self.drop_ratio,
                training=self.training,
            )
            if self.residual:
                virtualnode_embedding = virtualnode_embedding + update
            else:
                virtualnode_embedding = update

    ### Different implementations of Jk-concat
    if self.JK == "last":
        return h_list[-1]
    if self.JK == "sum":
        # NOTE(review): h_list holds num_layer + 1 entries, so this sum leaves
        # out the final layer's output h_list[-1] -- confirm this is intended.
        node_representation = 0
        for layer in range(self.num_layer):
            node_representation += h_list[layer]
        return node_representation

    # Fail explicitly instead of hitting an unbound local variable.
    raise ValueError(
        "Unsupported JK mode {!r}; expected 'last' or 'sum'".format(self.JK)
    )
開發者ID:snap-stanford,項目名稱:ogb,代碼行數:49,代碼來源:conv.py

示例10: forward

# 需要導入模塊: from torch_geometric import nn [as 別名]
# 或者: from torch_geometric.nn import global_add_pool [as 別名]
def forward(self, batched_data):
    """Compute node representations using message passing with virtual nodes.

    The initial node states come from ``self.node_encoder(x)``. One
    virtual-node embedding per graph is added to the node states before every
    layer. After each layer except the last, the virtual-node embedding is
    updated from the per-graph sum of the layer's input states.

    Args:
        batched_data: object exposing ``x``, ``edge_index``, ``edge_attr`` and
            ``batch``. The number of graphs is derived from ``batch[-1]``, so
            ``batch`` presumably holds sorted graph ids -- TODO confirm.

    Returns:
        ``h_list[-1]`` when ``self.JK == "last"``; the sum of
        ``h_list[0] .. h_list[num_layer - 1]`` when ``self.JK == "sum"``.

    Raises:
        ValueError: if ``self.JK`` is neither ``"last"`` nor ``"sum"``.
    """
    x = batched_data.x
    edge_index = batched_data.edge_index
    edge_attr = batched_data.edge_attr
    batch = batched_data.batch

    ### virtual node embeddings for graphs
    # Every graph looks up index 0; build the index tensor directly with the
    # dtype/device of edge_index instead of converting it afterwards.
    num_graphs = batch[-1].item() + 1
    virtualnode_embedding = self.virtualnode_embedding(
        torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device)
    )

    last_layer = self.num_layer - 1
    h_list = [self.node_encoder(x)]
    for layer in range(self.num_layer):
        ### add message from virtual nodes to graph nodes
        h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

        ### Message passing among graph nodes
        h = self.convs[layer](h_list[layer], edge_index, edge_attr)
        h = self.batch_norms[layer](h)

        # No ReLU on the last layer.
        if layer != last_layer:
            h = F.relu(h)
        h = F.dropout(h, self.drop_ratio, training=self.training)

        if self.residual:
            h = h + h_list[layer]

        h_list.append(h)

        ### update the virtual nodes
        if layer < last_layer:
            ### add message from graph nodes to virtual nodes
            virtualnode_embedding_temp = (
                global_add_pool(h_list[layer], batch) + virtualnode_embedding
            )
            ### transform virtual nodes using MLP
            update = F.dropout(
                self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                self.drop_ratio,
                training=self.training,
            )
            if self.residual:
                virtualnode_embedding = virtualnode_embedding + update
            else:
                virtualnode_embedding = update

    ### Different implementations of Jk-concat
    if self.JK == "last":
        return h_list[-1]
    if self.JK == "sum":
        # NOTE(review): h_list holds num_layer + 1 entries, so this sum leaves
        # out the final layer's output h_list[-1] -- confirm this is intended.
        node_representation = 0
        for layer in range(self.num_layer):
            node_representation += h_list[layer]
        return node_representation

    # Fail explicitly instead of hitting an unbound local variable.
    raise ValueError(
        "Unsupported JK mode {!r}; expected 'last' or 'sum'".format(self.JK)
    )
開發者ID:snap-stanford,項目名稱:ogb,代碼行數:49,代碼來源:conv.py


注:本文中的torch_geometric.nn.global_add_pool方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。