当前位置: 首页>>代码示例>>Python>>正文


Python dgl.batch方法代码示例

本文整理汇总了Python中dgl.batch方法的典型用法代码示例。如果您正苦于以下问题：Python dgl.batch方法的具体用法是什么？Python dgl.batch怎么用？有哪些Python dgl.batch的使用示例？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块dgl的用法示例。


在下文中一共展示了dgl.batch方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: collate_fn

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def collate_fn(batch):
    """Collate ``(graph, label)`` samples into a batched graph and a label tensor.

    Every node feature of every graph is replaced, in place, by a
    ``torch.FloatTensor`` built from its current value; the graphs are then
    merged with ``dgl.batch``.

    Returns:
        A ``(batched_graphs, batched_labels)`` tuple; the labels are packed
        into a ``torch.LongTensor``.
    """
    graphs, labels = (list(column) for column in zip(*batch))

    # Node data has to be float tensors before the graphs are merged.
    for g in graphs:
        for name, feature in g.ndata.items():
            g.ndata[name] = torch.FloatTensor(feature)
    merged = dgl.batch(graphs)

    label_tensor = torch.LongTensor(np.array(labels))
    return merged, label_tensor
开发者ID:dmlc,项目名称:dgl,代码行数:20,代码来源:train.py

示例2: forward

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def forward(self, doc_ids, is_20ng=None):
        """Build one graph per document, run message passing, and project the result.

        Args:
            doc_ids: iterable of documents; each one is turned into a graph by
                ``self.seq_to_graph``.
            is_20ng: accepted but not used by this method.

        Returns:
            The output of ``self.Linear`` applied to the per-graph sum of node
            feature ``'h'`` after dropout and activation (in that order).
        """
        batch_graph = dgl.batch([self.seq_to_graph(doc) for doc in doc_ids])

        # Messages are the source 'h' combined with edge 'w' via DGL's
        # src_mul_edge builtin; the max builtin reduces them back into 'h'.
        fn = dgl.function
        batch_graph.update_all(
            message_func=fn.src_mul_edge('h', 'w', 'weighted_message'),
            reduce_func=fn.max('weighted_message', 'h'),
        )

        pooled = dgl.sum_nodes(batch_graph, feat='h')
        return self.Linear(self.activation(self.dropout(pooled)))
开发者ID:HuangLianzhe,项目名称:TextLevelGCN,代码行数:20,代码来源:model.py

示例3: batcher

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def batcher(batch):
    """Collate ``(x, y)`` samples into per-level batched graphs, signals and labels.

    Relies on the module-level names ``coarsening_levels``, ``perm`` and
    ``g_arr``.

    Returns:
        ``(g_batch, x_batch, y_batch)`` where ``g_batch`` holds one batched
        graph per coarsening level, ``x_batch`` is the concatenated signals
        with a trailing unit dimension, and ``y_batch`` is a ``LongTensor``.
    """
    graphs_per_level = [[] for _ in range(coarsening_levels + 1)]
    signals = []
    labels = []
    for image, label in batch:
        # Flatten, append 928 - 28**2 zeros, then reorder with ``perm``
        # (presumably the padding/ordering produced by graph coarsening --
        # confirm against where ``perm`` is built).
        padded = torch.cat([image.view(-1), image.new_zeros(928 - 28 ** 2)], 0)
        signals.append(padded[perm])
        labels.append(label)
        # The same graph of each level is repeated once per sample.
        for level, bucket in enumerate(graphs_per_level):
            bucket.append(g_arr[level])

    x_batch = torch.cat(signals).unsqueeze(-1)
    y_batch = torch.LongTensor(labels)
    g_batch = [dgl.batch(bucket) for bucket in graphs_per_level]
    return g_batch, x_batch, y_batch
开发者ID:dmlc,项目名称:dgl,代码行数:18,代码来源:mnist.py

示例4: forward

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def forward(self, pos, centroids, feat=None):
        """Build one neighbor-to-centroid graph per sample and batch them.

        Args:
            pos: tensor whose shape unpacks as ``(B, N, _)``.
            centroids: per-sample indices into the ``N`` points.
            feat: optional per-sample features, indexed like ``pos``.

        Returns:
            A batched graph whose nodes carry ``'pos'``, ``'center'`` (1 for
            centroid nodes, 0 otherwise) and, when ``feat`` is given,
            ``'feat'``.
        """
        device = pos.device
        # NOTE(review): ``self.frnn`` presumably returns the neighbor indices
        # grouped around each centroid -- confirm against its definition.
        neighbor_idx = self.frnn(pos, centroids)
        batch_size, num_points, _ = pos.shape

        graphs = []
        for b in range(batch_size):
            is_center = torch.zeros((num_points)).to(device)
            is_center[centroids[b]] = 1

            src = neighbor_idx[b].contiguous().view(-1)
            # Each centroid is repeated once per neighbor so it lines up with ``src``.
            dst = centroids[b].view(-1, 1).repeat(1, self.n_neighbor).view(-1)

            # Relabel the point ids that actually occur to 0..len(uniq)-1.
            uniq, relabeled = torch.unique(torch.cat([src, dst]), return_inverse=True)
            n_src = src.shape[0]
            edges = (relabeled[:n_src].cpu(), relabeled[n_src:].cpu())

            graph = dgl.DGLGraph(edges, readonly=True)
            graph.ndata['pos'] = pos[b][uniq]
            graph.ndata['center'] = is_center[uniq]
            if feat is not None:
                graph.ndata['feat'] = feat[b][uniq]
            graphs.append(graph)

        return dgl.batch(graphs)
开发者ID:dmlc,项目名称:dgl,代码行数:26,代码来源:pointnet2.py

示例5: write_txt

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def write_txt(batch, seqs, w_file, args):
    """Convert predicted token-id sequences to text, write them, and return them.

    Args:
        batch: mapping with a ``'raw_ent_text'`` entry; ``batch['raw_ent_text'][b][k]``
            is the iterable of words of entity ``k`` for sequence ``b``.
        seqs: iterable of token-id sequences, one per example.
        w_file: object with a ``write`` method; one line is written per sequence.
        args: object whose ``text_vocab`` supports ``len()`` and is callable,
            mapping a token string to its id and an ``int`` id to its string.

    Returns:
        A list with one single-element list per sequence, holding the
        space-joined text that was written for that sequence.
    """
    vocab = args.text_vocab
    vocab_size = len(vocab)
    eos_id = vocab('<EOS>')
    # Looked up once up front instead of once per token: the ids are constant.
    skipped_ids = (vocab('<PAD>'), vocab('<BOS>'), eos_id)

    ret = []
    for b, seq in enumerate(seqs):
        txt = []
        for token in seq:
            if token >= vocab_size:
                # Ids past the vocabulary copy an entity's words, minus padding.
                ent_text = batch['raw_ent_text'][b][token - vocab_size]
                txt.extend(word for word in ent_text if word != '<PAD>')
            else:
                token_id = int(token)
                if token_id not in skipped_ids:
                    txt.append(vocab(token_id))
            # Stop decoding this sequence at the first end-of-sequence token.
            if int(token) == eos_id:
                break
        line = ' '.join(str(word) for word in txt)
        w_file.write(line + '\n')
        ret.append([line])
    return ret
开发者ID:dmlc,项目名称:dgl,代码行数:21,代码来源:utlis.py

示例6: batch_fn

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def batch_fn(self, batch_ex):
        """Collate examples into padded tensors and one batched graph on ``self.device``.

        Each example is converted with ``ex.get_tensor(...)`` using this
        object's vocabularies; the per-field lists are padded with ``pad`` and
        the graphs are merged with ``dgl.batch``.

        Args:
            batch_ex: iterable of examples exposing ``get_tensor``.

        Returns:
            A dict with keys ``'title'``, ``'ent_text'``, ``'ent_len'``,
            ``'ent_type'``, ``'rel'``, ``'text'``, ``'tgt_text'``, ``'graph'``
            and ``'raw_ent_text'``. Tensors and the graph are moved to
            ``self.device``; ``'ent_len'`` and ``'raw_ent_text'`` are passed
            through as produced.
        """
        batch_title, batch_ent_text, batch_ent_type, batch_rel = [], [], [], []
        batch_text, batch_tgt_text, batch_graph, batch_raw_ent_text = [], [], [], []
        for ex in batch_ex:
            ex_data = ex.get_tensor(self.ent_vocab, self.rel_vocab, self.text_vocab, self.ent_text_vocab, self.title_vocab)
            batch_title.append(ex_data['title'])
            batch_ent_text.append(ex_data['ent_text'])
            batch_ent_type.append(ex_data['ent_type'])
            batch_rel.append(ex_data['rel'])
            batch_text.append(ex_data['text'])
            batch_tgt_text.append(ex_data['tgt_text'])
            batch_graph.append(ex_data['graph'])
            batch_raw_ent_text.append(ex_data['raw_ent_text'])

        batch_title = pad(batch_title, out_type='tensor')
        # With flatten=True, ``pad`` also returns the per-example entity lengths.
        batch_ent_text, ent_len = pad(batch_ent_text, out_type='tensor', flatten=True)
        batch_ent_type = pad(batch_ent_type, out_type='tensor')
        batch_rel = pad(batch_rel, out_type='tensor')
        batch_text = pad(batch_text, out_type='tensor')
        batch_tgt_text = pad(batch_tgt_text, out_type='tensor')

        batch_graph = dgl.batch(batch_graph)
        # Where ``to`` returns the moved graph, keep that result; discarding it
        # would leave the returned graph on its original device. If ``to``
        # returns None (in-place move), keep the existing object.
        moved_graph = batch_graph.to(self.device)
        if moved_graph is not None:
            batch_graph = moved_graph

        return {
            'title': batch_title.to(self.device),
            'ent_text': batch_ent_text.to(self.device),
            'ent_len': ent_len,
            'ent_type': batch_ent_type.to(self.device),
            'rel': batch_rel.to(self.device),
            'text': batch_text.to(self.device),
            'tgt_text': batch_tgt_text.to(self.device),
            'graph': batch_graph,
            'raw_ent_text': batch_raw_ent_text,
        }
开发者ID:dmlc,项目名称:dgl,代码行数:27,代码来源:utlis.py

示例7: test_batch_unbatch

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def test_batch_unbatch():
    """Batching two trees and unbatching them must round-trip sizes and features."""
    first = tree1()
    second = tree2()

    batched = dgl.batch([first, second])
    assert batched.number_of_nodes() == 10
    assert batched.number_of_edges() == 8
    assert batched.batch_size == 2
    assert batched.batch_num_nodes == [5, 5]
    assert batched.batch_num_edges == [4, 4]

    restored_first, restored_second = dgl.unbatch(batched)
    # Node and edge feature 'h' must survive the round trip for each tree.
    for before, after in ((first, restored_first), (second, restored_second)):
        assert F.allclose(before.ndata['h'], after.ndata['h'])
        assert F.allclose(before.edata['h'], after.edata['h'])
开发者ID:dmlc,项目名称:dgl,代码行数:18,代码来源:test_batched_graph.py

示例8: test_batch_send_then_recv

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def test_batch_send_then_recv():
    """Send along chosen edges of a batched graph, receive, and check per-tree results."""
    first = tree1()
    second = tree2()
    batched = dgl.batch([first, second])

    def copy_src_h(edges):
        # Message: the source node's 'h'.
        return {'m': edges.src['h']}

    def sum_mailbox(nodes):
        # Reduce: sum the received messages along dimension 1 into 'h'.
        return {'h': F.sum(nodes.mailbox['m'], 1)}

    batched.register_message_func(copy_src_h)
    batched.register_reduce_func(sum_mailbox)

    # Ids of the second tree are written as "local id + 5"
    # (presumably 5 is the node count of the first tree -- confirm).
    offset = 5
    src = [3, 4, 2 + offset, 0 + offset]
    dst = [1, 1, 4 + offset, 4 + offset]

    batched.send((src, dst))
    batched.recv([1, 9])  # assuming recv takes in unique nodes

    out_first, out_second = dgl.unbatch(batched)
    assert F.asnumpy(out_first.ndata['h'][1]) == 7
    assert F.asnumpy(out_second.ndata['h'][4]) == 2
开发者ID:dmlc,项目名称:dgl,代码行数:18,代码来源:test_batched_graph.py

示例9: test_laplacian_lambda_max

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def test_laplacian_lambda_max():
    """``dgl.laplacian_lambda_max`` stays below 2 (plus tolerance) for single and batched graphs."""
    upper_bound = 2 + 1e-6

    # Single graph: the first returned value is checked.
    single = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
    single_result = dgl.laplacian_lambda_max(single)
    assert single_result[0] < upper_bound

    # Batched graph: one value per component graph is expected.
    sizes = [20, 30, 10, 12]
    components = [dgl.DGLGraph(nx.erdos_renyi_graph(n, 0.3)) for n in sizes]
    batched_result = dgl.laplacian_lambda_max(dgl.batch(components))
    assert len(batched_result) == len(sizes)
    for value in batched_result:
        assert value < upper_bound
开发者ID:dmlc,项目名称:dgl,代码行数:19,代码来源:test_transform.py

示例10: test_broadcast_nodes

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def test_broadcast_nodes():
    """``dgl.broadcast_nodes`` must repeat each graph's feature row once per node."""
    # Case 1: a single graph.
    path10 = dgl.DGLGraph(nx.path_graph(10))
    feat_path10 = F.randn((1, 40))
    expected = F.stack([feat_path10] * path10.number_of_nodes(), 0)
    assert F.allclose(dgl.broadcast_nodes(path10, feat_path10), expected)

    # Case 2: a batched graph that includes a graph built with no arguments.
    graphs = [
        path10,
        dgl.DGLGraph(nx.path_graph(3)),
        dgl.DGLGraph(),
        dgl.DGLGraph(nx.path_graph(12)),
    ]
    batched = dgl.batch(graphs)
    feats = [feat_path10] + [F.randn((1, 40)) for _ in range(3)]

    # Expected output: every graph's feature repeated once per node of that graph.
    repeated = []
    for graph, feat in zip(graphs, feats):
        repeated.extend([feat] * graph.number_of_nodes())
    expected = F.cat(repeated, 0)

    actual = dgl.broadcast_nodes(batched, F.cat(feats, 0))
    assert F.allclose(actual, expected)
开发者ID:dmlc,项目名称:dgl,代码行数:26,代码来源:test_readout.py

示例11: test_set2set

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def test_set2set():
    """``nn.Set2Set`` must return one row of width 10 per graph."""
    ctx = F.ctx()
    single = dgl.DGLGraph(nx.path_graph(10))

    # Per the original note: hidden size 5, 3 iters, 3 layers.
    pool = nn.Set2Set(5, 3, 3).to(ctx)
    print(pool)

    # Case 1: a single graph yields one output row.
    node_feats = F.randn((single.number_of_nodes(), 5))
    out = pool(single, node_feats)
    assert out.shape[0] == 1
    assert out.shape[1] == 10
    assert out.dim() == 2

    # Case 2: three batched graphs yield three output rows.
    others = [dgl.DGLGraph(nx.path_graph(11)), dgl.DGLGraph(nx.path_graph(5))]
    batched = dgl.batch([single] + others)
    node_feats = F.randn((batched.number_of_nodes(), 5))
    out = pool(batched, node_feats)
    assert out.shape[0] == 3
    assert out.shape[1] == 10
    assert out.dim() == 2
开发者ID:dmlc,项目名称:dgl,代码行数:22,代码来源:test_nn.py

示例12: test_glob_att_pool

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def test_glob_att_pool():
    """``nn.GlobalAttentionPooling`` (torch layers) must return one row of width 10 per graph."""
    ctx = F.ctx()
    single = dgl.DGLGraph(nx.path_graph(10))

    gate = th.nn.Linear(5, 1)
    projection = th.nn.Linear(5, 10)
    pool = nn.GlobalAttentionPooling(gate, projection).to(ctx)
    print(pool)

    def check_shape(out, num_graphs):
        assert out.shape[0] == num_graphs and out.shape[1] == 10 and out.dim() == 2

    # Case 1: a single graph.
    node_feats = F.randn((single.number_of_nodes(), 5))
    check_shape(pool(single, node_feats), 1)

    # Case 2: the same graph batched four times.
    batched = dgl.batch([single] * 4)
    node_feats = F.randn((batched.number_of_nodes(), 5))
    check_shape(pool(batched, node_feats), 4)
开发者ID:dmlc,项目名称:dgl,代码行数:20,代码来源:test_nn.py

示例13: test_glob_att_pool

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def test_glob_att_pool():
    """``nn.GlobalAttentionPooling`` (gluon layers) must return one row of width 10 per graph."""
    single = dgl.DGLGraph(nx.path_graph(10))
    ctx = F.ctx()

    gate = gluon.nn.Dense(1)
    projection = gluon.nn.Dense(10)
    pool = nn.GlobalAttentionPooling(gate, projection)
    pool.initialize(ctx=ctx)
    print(pool)

    def check_shape(out, num_graphs):
        assert out.shape[0] == num_graphs and out.shape[1] == 10 and out.ndim == 2

    # Case 1: a single graph.
    node_feats = F.randn((single.number_of_nodes(), 5))
    check_shape(pool(single, node_feats), 1)

    # Case 2: the same graph batched four times.
    batched = dgl.batch([single] * 4)
    node_feats = F.randn((batched.number_of_nodes(), 5))
    check_shape(pool(batched, node_feats), 4)
开发者ID:dmlc,项目名称:dgl,代码行数:19,代码来源:test_nn.py

示例14: test_glob_att_pool

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def test_glob_att_pool():
    """``nn.GlobalAttentionPooling`` (``layers.Dense``) must return one row of width 10 per graph."""
    single = dgl.DGLGraph(nx.path_graph(10))

    gate = layers.Dense(1)
    projection = layers.Dense(10)
    pool = nn.GlobalAttentionPooling(gate, projection)
    print(pool)

    def check_shape(out, num_graphs):
        assert out.shape[0] == num_graphs and out.shape[1] == 10 and out.ndim == 2

    # Case 1: a single graph.
    node_feats = F.randn((single.number_of_nodes(), 5))
    check_shape(pool(single, node_feats), 1)

    # Case 2: the same graph batched four times.
    batched = dgl.batch([single] * 4)
    node_feats = F.randn((batched.number_of_nodes(), 5))
    check_shape(pool(batched, node_feats), 4)
开发者ID:dmlc,项目名称:dgl,代码行数:18,代码来源:test_nn.py

示例15: collate_csqa_graphs

# 需要导入模块: import dgl [as 别名]
# 或者: from dgl import batch [as 别名]
def collate_csqa_graphs(samples):
    """Collate samples into stacked statements, a label column, and one batched graph.

    Args:
        samples: iterable of 3-item samples, unpacked here as
            ``(statement, correct_label, graphs)``, where ``graphs`` is an
            iterable of graphs belonging to that sample.

    Returns:
        ``(sents_vecs, labels, batched_graph)``: the statements stacked with
        ``torch.stack``, the labels as a ``torch.Tensor`` with one
        single-element row per sample, and all graphs of all samples merged
        with ``dgl.batch`` in sample order.
    """
    statements, correct_labels, graph_data = (list(field) for field in zip(*samples))

    # Flatten the per-sample graph groups into one list, keeping their order.
    flat_graph_data = [graph for group in graph_data for graph in group]

    batched_graph = dgl.batch(flat_graph_data)
    sents_vecs = torch.stack(statements)
    label_column = torch.Tensor([[label] for label in correct_labels])
    return sents_vecs, label_column, batched_graph
开发者ID:INK-USC,项目名称:KagNet,代码行数:18,代码来源:csqa_dataset.py


注:本文中的dgl.batch方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。