This article collects typical usage examples of the Python method torch_scatter.scatter. If you have been struggling with questions such as: What exactly does torch_scatter.scatter do? How is it used? What do working examples look like? Then the curated code examples below may help. You can also explore further usage examples from the torch_scatter module in which this method is defined.
The following shows 15 code examples of the torch_scatter.scatter method, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
Example 1: global_mean_pool
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def global_mean_pool(x, batch, size: Optional[int] = None):
    r"""Returns batch-wise graph-level outputs by averaging node features
    across the node dimension, so that for a single graph
    :math:`\mathcal{G}_i` its output is computed by

    .. math::
        \mathbf{r}_i = \frac{1}{N_i} \sum_{n=1}^{N_i} \mathbf{x}_n

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,
            B-1\}}^N`, which assigns each node to a specific example.
        size (int, optional): Batch size :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    size = int(batch.max().item() + 1) if size is None else size
    return scatter(x, batch, dim=0, dim_size=size, reduce='mean')
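As a quick sanity check, here is a minimal standalone sketch of the same scatter-mean pooling on toy data (the tensors x and batch below are made up for illustration):

import torch
from torch_scatter import scatter

# Two graphs in one mini-batch: nodes 0-1 belong to graph 0, node 2 to graph 1.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = torch.tensor([0, 0, 1])

out = scatter(x, batch, dim=0, dim_size=2, reduce='mean')
# out == [[2.0, 3.0], [5.0, 6.0]]  (per-graph feature means)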
Example 2: global_max_pool
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def global_max_pool(x, batch, size: Optional[int] = None):
    r"""Returns batch-wise graph-level outputs by taking the channel-wise
    maximum across the node dimension, so that for a single graph
    :math:`\mathcal{G}_i` its output is computed by

    .. math::
        \mathbf{r}_i = \mathrm{max}_{n=1}^{N_i} \, \mathbf{x}_n

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,
            B-1\}}^N`, which assigns each node to a specific example.
        size (int, optional): Batch size :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    size = int(batch.max().item() + 1) if size is None else size
    return scatter(x, batch, dim=0, dim_size=size, reduce='max')
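And the max-pooling counterpart on toy data (again, x and batch are illustrative values, not from any real dataset):

import torch
from torch_scatter import scatter

x = torch.tensor([[1.0, 6.0], [3.0, 4.0], [5.0, 2.0]])
batch = torch.tensor([0, 0, 1])

out = scatter(x, batch, dim=0, dim_size=2, reduce='max')
# out == [[3.0, 6.0], [5.0, 2.0]]  (channel-wise maxima per graph)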
Example 3: forward
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def forward(self, x, rbf, sbf, idx_kj, idx_ji):
    # Transform the radial (rbf) and spherical (sbf) basis representations.
    rbf = self.lin_rbf(rbf)
    sbf = self.lin_sbf(sbf)

    x_ji = self.act(self.lin_ji(x))
    x_kj = self.act(self.lin_kj(x))
    x_kj = x_kj * rbf
    # Combine triplet messages with the spherical basis via the bilinear weight W.
    x_kj = torch.einsum('wj,wl,ijl->wi', sbf, x_kj[idx_kj], self.W)
    # Aggregate the per-triplet messages back onto their parent edges.
    x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0))

    h = x_ji + x_kj
    for layer in self.layers_before_skip:
        h = layer(h)
    h = self.act(self.lin(h)) + x  # residual/skip connection
    for layer in self.layers_after_skip:
        h = layer(h)
    return h
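The pivotal scatter call above folds per-triplet messages (gathered via idx_kj) back onto their parent edges (idx_ji). A toy sketch of just that aggregation step, with made-up shapes:

import torch
from torch_scatter import scatter

num_edges, num_triplets, hidden = 4, 6, 8
x_kj = torch.randn(num_triplets, hidden)            # one message per triplet
idx_ji = torch.randint(0, num_edges, (num_triplets, ))

out = scatter(x_kj, idx_ji, dim=0, dim_size=num_edges)  # default reduce: sum
print(out.shape)  # torch.Size([4, 8]) -- one aggregated message per edge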
Example 4: aggregate
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def aggregate(self, inputs: Tensor, index: Tensor,
              ptr: Optional[Tensor] = None,
              dim_size: Optional[int] = None) -> Tensor:
    r"""Aggregates messages from neighbors as
    :math:`\square_{j \in \mathcal{N}(i)}`.

    Takes in the output of message computation as first argument and any
    argument which was initially passed to :meth:`propagate`.

    By default, this function will delegate its call to scatter functions
    that support "add", "mean" and "max" operations as specified in
    :meth:`__init__` by the :obj:`aggr` argument.
    """
    if ptr is not None:
        ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
        return segment_csr(inputs, ptr, reduce=self.aggr)
    else:
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                       reduce=self.aggr)
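For sorted indices, the COO (scatter) and CSR (segment_csr) branches above compute the same reduction. A minimal sketch on toy data, assuming torch_scatter is installed:

import torch
from torch_scatter import scatter, segment_csr

inputs = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
index = torch.tensor([0, 0, 1, 1])   # COO-style indices (sorted)
ptr = torch.tensor([0, 2, 4])        # the same grouping in CSR form

out_coo = scatter(inputs, index, dim=0, dim_size=2, reduce='sum')
out_csr = segment_csr(inputs, ptr, reduce='sum')
assert torch.equal(out_coo, out_csr)  # both are [[3.0], [7.0]]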
Example 5: test_broadcasting
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def test_broadcasting(reduce, device):
    B, C, H, W = (4, 3, 8, 8)

    # A 1-D index broadcasts along the scatter dimension.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (H, )).to(device, torch.long)
    out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
    assert out.size() == (B, C, H, W)

    # An index with singleton dimensions broadcasts against src.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
    out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
    assert out.size() == (B, C, H, W)
Example 6: test_zero_elements
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def test_zero_elements(reduce, dtype, device):
    x = torch.randn(0, 0, 0, 16, dtype=dtype, device=device,
                    requires_grad=True)
    index = tensor([], torch.long, device)    # `tensor` is the test-suite helper
    indptr = tensor([], torch.long, device)

    out = scatter(x, index, dim=0, dim_size=0, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = segment_coo(x, index, dim_size=0, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = gather_coo(x, index)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = segment_csr(x, indptr, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = gather_csr(x, indptr)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)
Example 7: test_forward
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def test_forward(test, reduce, dtype):
    device = torch.device('cuda:1')
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    out = torch_scatter.scatter(src, index, dim, reduce=reduce)
    assert torch.all(out == expected)

    out = torch_scatter.segment_coo(src, index, reduce=reduce)
    assert torch.all(out == expected)

    out = torch_scatter.segment_csr(src, indptr, reduce=reduce)
    assert torch.all(out == expected)
Example 8: extract_node_features
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def extract_node_features(self, aggr='add'):
    file_path = 'init_node_features_{}.pt'.format(aggr)
    if os.path.isfile(file_path):
        print('{} exists'.format(file_path))
    else:
        if aggr in ['add', 'mean', 'max']:
            # Aggregate edge attributes onto their source nodes.
            node_features = scatter(self.edge_attr,
                                    self.edge_index[0],
                                    dim=0,
                                    dim_size=self.total_no_of_nodes,
                                    reduce=aggr)
        else:
            raise Exception('Unknown Aggr Method')
        torch.save(node_features, file_path)
        print('Extracted node features are saved to file {}'.format(file_path))
    return file_path
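The heart of this example is scattering edge attributes onto their source nodes. A stripped-down sketch with made-up tensors (edge_index, edge_attr, and the node count are illustrative, not from a real dataset):

import torch
from torch_scatter import scatter

edge_index = torch.tensor([[0, 0, 1, 2],    # source nodes
                           [1, 2, 2, 0]])   # target nodes
edge_attr = torch.randn(4, 8)               # one 8-dim feature per edge
num_nodes = 3

# Sum the attributes of all outgoing edges of each node.
node_features = scatter(edge_attr, edge_index[0], dim=0,
                        dim_size=num_nodes, reduce='add')
print(node_features.shape)  # torch.Size([3, 8])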
Example 9: softmax
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def softmax(src: Tensor, index: Tensor, ptr: Optional[Tensor] = None,
            num_nodes: Optional[int] = None) -> Tensor:
    r"""Computes a sparsely evaluated softmax.

    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        ptr (LongTensor, optional): If given, computes the softmax based on
            sorted inputs in CSR representation. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    if ptr is None:
        N = maybe_num_nodes(index, num_nodes)
        # Subtract the per-group maximum for numerical stability.
        out = src - scatter(src, index, dim=0, dim_size=N, reduce='max')[index]
        out = out.exp()
        out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index]
        return out / (out_sum + 1e-16)
    else:
        out = src - gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        out = out.exp()
        out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
        return out / (out_sum + 1e-16)
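Assuming the softmax function above (and its imports) is in scope, a quick sanity check on toy values:

import torch

src = torch.tensor([1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 1])

out = softmax(src, index)
# Group {0, 1}: softmax over [1.0, 2.0] -> [0.2689, 0.7311]
# Group {2}: a single element -> [1.0]; each group sums to 1.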
Example 10: _max_pool_x
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def _max_pool_x(cluster, x, size: Optional[int] = None):
    return scatter(x, cluster, dim=0, dim_size=size, reduce='max')
Example 11: _avg_pool_x
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def _avg_pool_x(cluster, x, size: Optional[int] = None):
    return scatter(x, cluster, dim=0, dim_size=size, reduce='mean')
Example 12: forward
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def forward(self, z, pos, batch=None):
    assert z.dim() == 1 and z.dtype == torch.long
    batch = torch.zeros_like(z) if batch is None else batch

    h = self.embedding(z)

    edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
    row, col = edge_index
    edge_weight = (pos[row] - pos[col]).norm(dim=-1)
    edge_attr = self.distance_expansion(edge_weight)

    for interaction in self.interactions:
        h = h + interaction(h, edge_index, edge_weight, edge_attr)

    h = self.lin1(h)
    h = self.act(h)
    h = self.lin2(h)

    if self.dipole:
        # Get center of mass.
        mass = self.atomic_mass[z].view(-1, 1)
        c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
        h = h * (pos - c[batch])

    if not self.dipole and self.mean is not None and self.std is not None:
        h = h * self.std + self.mean

    if not self.dipole and self.atomref is not None:
        h = h + self.atomref(z)

    out = scatter(h, batch, dim=0, reduce=self.readout)

    if self.dipole:
        out = torch.norm(out, dim=-1, keepdim=True)

    if self.scale is not None:
        out = self.scale * out

    return out
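The dipole branch computes a per-graph center of mass with two scatter calls. A toy sketch of just that step (masses and positions are made up):

import torch
from torch_scatter import scatter

pos = torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
mass = torch.tensor([[1.0], [3.0], [2.0]])
batch = torch.tensor([0, 0, 1])

c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
# c[0] == (1*[0,0,0] + 3*[2,0,0]) / 4 == [1.5, 0.0, 0.0]
# c[1] == [1.0, 1.0, 1.0]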
Example 13: aggregate
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def aggregate(self, inputs: Tensor, index: Tensor,
              dim_size: Optional[int] = None) -> Tensor:
    # Concatenate mean- and max-aggregated messages along the feature dim.
    out_mean = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                       reduce='mean')
    out_max = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                      reduce='max')
    return torch.cat([out_mean, out_max], dim=-1)
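Note that concatenating the two reductions doubles the feature dimension. A standalone sketch of the same mean-and-max pattern on toy data, with node_dim taken to be 0:

import torch
from torch_scatter import scatter

inputs = torch.tensor([[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]])
index = torch.tensor([0, 0, 1])

out_mean = scatter(inputs, index, dim=0, dim_size=2, reduce='mean')
out_max = scatter(inputs, index, dim=0, dim_size=2, reduce='max')
out = torch.cat([out_mean, out_max], dim=-1)
print(out.shape)  # torch.Size([2, 4]) -- feature width doubled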
Example 14: aggregate
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def aggregate(self, inputs: Tensor, edge_type: Tensor, index: Tensor,
              dim_size: Optional[int] = None) -> Tensor:
    # Compute normalization separately for each `edge_type`.
    if self.aggr == 'mean':
        norm = F.one_hot(edge_type, self.num_relations).to(torch.float)
        norm = scatter(norm, index, dim=0, dim_size=dim_size)[index]
        norm = torch.gather(norm, 1, edge_type.view(-1, 1))
        norm = 1. / norm.clamp_(1.)
        inputs = norm * inputs
    return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size)
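The one-hot scatter trick above counts, for each edge, how many same-relation edges share its target node. A hedged standalone sketch with toy sizes (2 relations, 3 nodes):

import torch
import torch.nn.functional as F
from torch_scatter import scatter

num_relations, num_nodes = 2, 3
edge_type = torch.tensor([0, 0, 1])   # relation type of each edge
index = torch.tensor([1, 1, 2])       # target node of each edge

norm = F.one_hot(edge_type, num_relations).to(torch.float)
norm = scatter(norm, index, dim=0, dim_size=num_nodes)[index]
norm = torch.gather(norm, 1, edge_type.view(-1, 1))
# norm == [[2.], [2.], [1.]]: per edge, the number of edges of the same
# relation arriving at its target node.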
Example 15: test_backward
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter [as alias]
def test_backward(test, reduce, device):
    src = tensor(test['src'], torch.double, device)
    src.requires_grad_()
    index = tensor(test['index'], torch.long, device)
    dim = test['dim']

    assert gradcheck(torch_scatter.scatter,
                     (src, index, dim, None, None, reduce))
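gradcheck requires double precision and requires_grad inputs, which is why the source tensor above is created as torch.double. A self-contained variant on toy data:

import torch
from torch.autograd import gradcheck
import torch_scatter

src = torch.randn(4, 2, dtype=torch.double, requires_grad=True)
index = torch.tensor([0, 0, 1, 1])

# Positional arguments mirror scatter(src, index, dim, out, dim_size, reduce).
assert gradcheck(torch_scatter.scatter,
                 (src, index, 0, None, None, 'mean'))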