This article collects typical usage examples of the Python method dgl.batch. If you have been wondering what dgl.batch does, how to call it, and what real code using it looks like, the curated examples below may help. You can also explore other methods in the dgl module.
The following lists 15 code examples of dgl.batch, ordered by popularity by default.
Example 1: collate_fn
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def collate_fn(batch):
    '''
    collate_fn for dataset batching;
    transforms ndata to tensors (on GPU if available)
    '''
    graphs, labels = map(list, zip(*batch))
    # cuda = torch.cuda.is_available()
    # batch the graphs and cast node features to PyTorch tensors
    for graph in graphs:
        for (key, value) in graph.ndata.items():
            graph.ndata[key] = torch.FloatTensor(value)
    batched_graphs = dgl.batch(graphs)
    # cast the labels to a PyTorch tensor
    batched_labels = torch.LongTensor(np.array(labels))
    return batched_graphs, batched_labels
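A minimal usage sketch of this collate function with a PyTorch DataLoader (the `dataset` object, batch size, and loop body are assumptions, not part of the original example):

import torch
from torch.utils.data import DataLoader

# dataset is assumed to yield (DGLGraph, label) pairs
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
for batched_graphs, batched_labels in loader:
    pass  # feed the batched graph and labels to a model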
Example 2: forward
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def forward(self, doc_ids, is_20ng=None):
    sub_graphs = [self.seq_to_graph(doc) for doc in doc_ids]
    batch_graph = dgl.batch(sub_graphs)
    # edge-weighted message passing with a max reduction
    batch_graph.update_all(
        message_func=dgl.function.src_mul_edge('h', 'w', 'weighted_message'),
        reduce_func=dgl.function.max('weighted_message', 'h')
    )
    # per-graph readout over the batched graph
    h1 = dgl.sum_nodes(batch_graph, feat='h')
    drop1 = self.dropout(h1)
    act1 = self.activation(drop1)
    l = self.Linear(act1)
    return l
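dgl.function.src_mul_edge is the legacy name of this message built-in; recent DGL releases expose the same operation as u_mul_e. A sketch of the equivalent update under the newer API (the exact version cut-off is an assumption; check your DGL release):

import dgl
import dgl.function as fn

batch_graph.update_all(fn.u_mul_e('h', 'w', 'weighted_message'),
                       fn.max('weighted_message', 'h'))
h1 = dgl.sum_nodes(batch_graph, feat='h')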
Example 3: batcher
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def batcher(batch):
    # coarsening_levels, g_arr (one pre-coarsened DGLGraph per level), and
    # perm (the node permutation aligning pixels with the coarsened graphs)
    # are module-level globals in the original example.
    g_batch = [[] for _ in range(coarsening_levels + 1)]
    x_batch = []
    y_batch = []
    for x, y in batch:
        # pad the flattened 28 x 28 image up to 928 nodes, then permute
        x = torch.cat([x.view(-1), x.new_zeros(928 - 28 ** 2)], 0)
        x = x[perm]
        x_batch.append(x)
        y_batch.append(y)
        for i in range(coarsening_levels + 1):
            g_batch[i].append(g_arr[i])
    x_batch = torch.cat(x_batch).unsqueeze(-1)
    y_batch = torch.LongTensor(y_batch)
    g_batch = [dgl.batch(g) for g in g_batch]
    return g_batch, x_batch, y_batch
Example 4: forward
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def forward(self, pos, centroids, feat=None):
    dev = pos.device
    group_idx = self.frnn(pos, centroids)
    B, N, _ = pos.shape
    glist = []
    for i in range(B):
        center = torch.zeros(N).to(dev)
        center[centroids[i]] = 1
        src = group_idx[i].contiguous().view(-1)
        dst = centroids[i].view(-1, 1).repeat(1, self.n_neighbor).view(-1)
        # relabel node ids so each per-sample graph is compact
        unified = torch.cat([src, dst])
        uniq, inv_idx = torch.unique(unified, return_inverse=True)
        src_idx = inv_idx[:src.shape[0]]
        dst_idx = inv_idx[src.shape[0]:]
        g = dgl.DGLGraph((src_idx.cpu(), dst_idx.cpu()), readonly=True)
        g.ndata['pos'] = pos[i][uniq]
        g.ndata['center'] = center[uniq]
        if feat is not None:
            g.ndata['feat'] = feat[i][uniq]
        glist.append(g)
    bg = dgl.batch(glist)
    return bg
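dgl.DGLGraph((src, dst), readonly=True) is the legacy constructor; in recent DGL releases the idiomatic equivalent is dgl.graph (a sketch under that assumption; the version cut-off is not stated in the original):

import dgl

g = dgl.graph((src_idx.cpu(), dst_idx.cpu()))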
Example 5: write_txt
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def write_txt(batch, seqs, w_file, args):
    # convert the predicted token ids back to real text
    ret = []
    for b, seq in enumerate(seqs):
        txt = []
        for token in seq:
            if token >= len(args.text_vocab):
                # ids beyond the vocabulary size refer to copied entities
                ent_text = batch['raw_ent_text'][b][token - len(args.text_vocab)]
                ent_text = filter(lambda x: x != '<PAD>', ent_text)
                txt.extend(ent_text)
            else:
                if int(token) not in [args.text_vocab(x) for x in ['<PAD>', '<BOS>', '<EOS>']]:
                    txt.append(args.text_vocab(int(token)))
                if int(token) == args.text_vocab('<EOS>'):
                    break
        w_file.write(' '.join([str(x) for x in txt]) + '\n')
        ret.append([' '.join([str(x) for x in txt])])
    return ret
Example 6: batch_fn
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def batch_fn(self, batch_ex):
    batch_title, batch_ent_text, batch_ent_type, batch_rel, batch_text, batch_tgt_text, batch_graph = \
        [], [], [], [], [], [], []
    batch_raw_ent_text = []
    for ex in batch_ex:
        ex_data = ex.get_tensor(self.ent_vocab, self.rel_vocab, self.text_vocab,
                                self.ent_text_vocab, self.title_vocab)
        batch_title.append(ex_data['title'])
        batch_ent_text.append(ex_data['ent_text'])
        batch_ent_type.append(ex_data['ent_type'])
        batch_rel.append(ex_data['rel'])
        batch_text.append(ex_data['text'])
        batch_tgt_text.append(ex_data['tgt_text'])
        batch_graph.append(ex_data['graph'])
        batch_raw_ent_text.append(ex_data['raw_ent_text'])
    batch_title = pad(batch_title, out_type='tensor')
    batch_ent_text, ent_len = pad(batch_ent_text, out_type='tensor', flatten=True)
    batch_ent_type = pad(batch_ent_type, out_type='tensor')
    batch_rel = pad(batch_rel, out_type='tensor')
    batch_text = pad(batch_text, out_type='tensor')
    batch_tgt_text = pad(batch_tgt_text, out_type='tensor')
    batch_graph = dgl.batch(batch_graph)
    # in recent DGL versions Graph.to returns the moved graph rather than
    # mutating in place, so the result should be assigned back
    batch_graph = batch_graph.to(self.device)
    return {'title': batch_title.to(self.device), 'ent_text': batch_ent_text.to(self.device),
            'ent_len': ent_len, 'ent_type': batch_ent_type.to(self.device),
            'rel': batch_rel.to(self.device), 'text': batch_text.to(self.device),
            'tgt_text': batch_tgt_text.to(self.device), 'graph': batch_graph,
            'raw_ent_text': batch_raw_ent_text}
Example 7: test_batch_unbatch
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def test_batch_unbatch():
    t1 = tree1()
    t2 = tree2()
    bg = dgl.batch([t1, t2])
    assert bg.number_of_nodes() == 10
    assert bg.number_of_edges() == 8
    assert bg.batch_size == 2
    assert bg.batch_num_nodes == [5, 5]
    assert bg.batch_num_edges == [4, 4]
    # unbatching must recover the original graphs and their features
    tt1, tt2 = dgl.unbatch(bg)
    assert F.allclose(t1.ndata['h'], tt1.ndata['h'])
    assert F.allclose(t1.edata['h'], tt1.edata['h'])
    assert F.allclose(t2.ndata['h'], tt2.ndata['h'])
    assert F.allclose(t2.edata['h'], tt2.edata['h'])
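tree1 and tree2 are fixtures defined elsewhere in the test module; the assertions imply each is a 5-node, 4-edge tree carrying scalar 'h' features on nodes and edges. A hypothetical minimal stand-in (the topology and feature values here are assumptions, not the actual fixtures):

import dgl
import torch

def make_tree():
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edges([1, 2, 3, 4], [0, 0, 1, 1])  # child -> parent edges
    g.ndata['h'] = torch.arange(5, dtype=torch.float32)
    g.edata['h'] = torch.ones(4)
    return g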
Example 8: test_batch_send_then_recv
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def test_batch_send_then_recv():
    t1 = tree1()
    t2 = tree2()
    bg = dgl.batch([t1, t2])
    bg.register_message_func(lambda edges: {'m': edges.src['h']})
    bg.register_reduce_func(lambda nodes: {'h': F.sum(nodes.mailbox['m'], 1)})
    u = [3, 4, 2 + 5, 0 + 5]  # source nodes; the second tree's ids are offset by 5
    v = [1, 1, 4 + 5, 4 + 5]  # destination nodes
    bg.send((u, v))
    bg.recv([1, 9])  # assuming recv takes in unique nodes
    t1, t2 = dgl.unbatch(bg)
    assert F.asnumpy(t1.ndata['h'][1]) == 7
    assert F.asnumpy(t2.ndata['h'][4]) == 2
Example 9: test_laplacian_lambda_max
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def test_laplacian_lambda_max():
    N = 20
    eps = 1e-6
    # test a single DGLGraph
    g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
    l_max = dgl.laplacian_lambda_max(g)
    assert l_max[0] < 2 + eps
    # test a batched DGLGraph: one lambda_max per component graph
    N_arr = [20, 30, 10, 12]
    bg = dgl.batch([
        dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
        for N in N_arr
    ])
    l_max_arr = dgl.laplacian_lambda_max(bg)
    assert len(l_max_arr) == len(N_arr)
    for l_max in l_max_arr:
        assert l_max < 2 + eps
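The bound in these assertions comes from spectral graph theory: the eigenvalues of the symmetric normalized Laplacian always lie in [0, 2]. A quick cross-check with networkx and numpy, independent of DGL (a sketch):

import networkx as nx
import numpy as np

g_nx = nx.erdos_renyi_graph(20, 0.3)
lap = nx.normalized_laplacian_matrix(g_nx).toarray()
assert np.linalg.eigvalsh(lap).max() <= 2 + 1e-6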
Example 10: test_broadcast_nodes
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def test_broadcast_nodes():
    # test #1: basic
    g0 = dgl.DGLGraph(nx.path_graph(10))
    feat0 = F.randn((1, 40))
    ground_truth = F.stack([feat0] * g0.number_of_nodes(), 0)
    assert F.allclose(dgl.broadcast_nodes(g0, feat0), ground_truth)
    # test #2: batched graph, one feature row broadcast per component graph
    g1 = dgl.DGLGraph(nx.path_graph(3))
    g2 = dgl.DGLGraph()
    g3 = dgl.DGLGraph(nx.path_graph(12))
    bg = dgl.batch([g0, g1, g2, g3])
    feat1 = F.randn((1, 40))
    feat2 = F.randn((1, 40))
    feat3 = F.randn((1, 40))
    ground_truth = F.cat(
        [feat0] * g0.number_of_nodes() +
        [feat1] * g1.number_of_nodes() +
        [feat2] * g2.number_of_nodes() +
        [feat3] * g3.number_of_nodes(), 0
    )
    assert F.allclose(dgl.broadcast_nodes(
        bg, F.cat([feat0, feat1, feat2, feat3], 0)
    ), ground_truth)
Example 11: test_set2set
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def test_set2set():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    s2s = nn.Set2Set(5, 3, 3)  # input dim 5, 3 iterations, 3 LSTM layers
    s2s = s2s.to(ctx)
    print(s2s)
    # test #1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2
    # test #2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11))
    g2 = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2
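These shape checks follow from how Set2Set builds its readout: the attention-weighted sum is concatenated with the LSTM query vector, so the output width is twice the input feature size, with one row per graph in the batch:

# one row per graph, width 2 * input_dim (2 * 5 = 10 here)
assert h1.shape == (bg.batch_size, 2 * 5)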
Example 12: test_glob_att_pool
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def test_glob_att_pool():
    # PyTorch backend (th is the torch module)
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)
    # test #1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2
    # test #2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2
Example 13: test_glob_att_pool
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def test_glob_att_pool():
    # MXNet/Gluon backend variant of the same test
    g = dgl.DGLGraph(nx.path_graph(10))
    ctx = F.ctx()
    gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
    gap.initialize(ctx=ctx)
    print(gap)
    # test #1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
    # test #2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
Example 14: test_glob_att_pool
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def test_glob_att_pool():
    # TensorFlow/Keras backend variant of the same test
    g = dgl.DGLGraph(nx.path_graph(10))
    gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))
    print(gap)
    # test #1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2
    # test #2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
Example 15: collate_csqa_graphs
# Required import: import dgl [as alias]
# Or: from dgl import batch [as alias]
def collate_csqa_graphs(samples):
    # The input `samples` is a list of
    # (statement, correct_label, graph_data) triples.
    statements, correct_labels, graph_data = map(list, zip(*samples))
    flat_graph_data = []
    for gd in graph_data:
        flat_graph_data.extend(gd)
    # for k, g in enumerate(flat_graph_data):
    #     g.ndata["gid"] = torch.Tensor([k] * len(g.nodes()))
    #     g.edata["gid"] = torch.Tensor([k] * len(g.edges()[0]))
    batched_graph = dgl.batch(flat_graph_data)
    sents_vecs = torch.stack(statements)
    return sents_vecs, torch.Tensor([[i] for i in correct_labels]), batched_graph
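Because each sample contributes several graphs (one per answer candidate), the collate function flattens them before batching; dgl.unbatch recovers the individual graphs when per-candidate scores are needed. A round-trip sketch (the downstream use is an assumption):

import dgl

graphs = dgl.unbatch(batched_graph)  # one graph per (question, candidate) pair
assert len(graphs) == batched_graph.batch_size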