This article collects typical usage examples of the Python method torch_scatter.scatter_max. If you have been wondering what exactly torch_scatter.scatter_max does, how to call it, or what real-world examples look like, the curated code examples below may help. You can also explore further usage examples from the containing module, torch_scatter.
Below are 12 code examples of torch_scatter.scatter_max, sorted by popularity by default. You can upvote the examples you like or find useful; your ratings help the system recommend better Python code examples.
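Before the examples, a minimal sketch of what scatter_max does may help (assuming the torch_scatter 2.x API, which returns a (values, argmax) tuple): it reduces src into groups given by index and keeps the per-group maximum.

import torch
from torch_scatter import scatter_max

src = torch.tensor([1.0, 5.0, 3.0, 2.0, 4.0])
index = torch.tensor([0, 0, 1, 1, 1])

out, argmax = scatter_max(src, index, dim=0)
print(out)     # tensor([5., 4.])  -> maximum per group
print(argmax)  # tensor([1, 4])    -> position of each maximum in src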
Example 1: __call__
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def __call__(self, data):
    row, col = data.edge_index
    N = data.num_nodes

    # Degree of every node, then the degree of each edge's target node.
    deg = degree(row, N, dtype=torch.float)
    deg_col = deg[col]

    min_deg, _ = scatter_min(deg_col, row, dim_size=N)
    min_deg[min_deg > 10000] = 0  # reset fill values of empty groups
    max_deg, _ = scatter_max(deg_col, row, dim_size=N)
    max_deg[max_deg < -10000] = 0  # reset fill values of empty groups
    mean_deg = scatter_mean(deg_col, row, dim_size=N)
    std_deg = scatter_std(deg_col, row, dim_size=N)
    x = torch.stack([deg, min_deg, max_deg, mean_deg, std_deg], dim=1)

    if data.x is not None:
        data.x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
        data.x = torch.cat([data.x, x], dim=-1)
    else:
        data.x = x

    return data
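A hypothetical usage sketch for the transform above, assuming the method belongs to a torch_geometric transform class (called LocalDegreeProfile here purely for illustration) and that degree, scatter_min, scatter_mean, and scatter_std are imported from torch_geometric.utils and torch_scatter:

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
data = Data(edge_index=edge_index, num_nodes=3)
data = LocalDegreeProfile()(data)  # hypothetical class wrapping __call__
print(data.x.shape)  # torch.Size([3, 5]): [deg, min, max, mean, std] per node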
Example 2: __call__
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def __call__(self, data):
    (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr

    # Relative Cartesian coordinates of each edge.
    cart = pos[col] - pos[row]
    cart = cart.view(-1, 1) if cart.dim() == 1 else cart

    # Rescale into [0, 1] by the largest absolute offset per source node.
    max_value, _ = scatter_max(cart.abs(), row, 0, dim_size=pos.size(0))
    max_value = max_value.max(dim=-1, keepdim=True)[0]
    cart = cart / (2 * max_value[row]) + 0.5

    if pseudo is not None and self.cat:
        pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
        data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)
    else:
        data.edge_attr = cart

    return data
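A hypothetical usage sketch, again assuming the method sits in a transform class (called Cartesian here for illustration) with a `cat` flag:

import torch
from torch_geometric.data import Data

pos = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
edge_index = torch.tensor([[0, 1], [1, 0]])
data = Data(edge_index=edge_index, pos=pos)
data = Cartesian()(data)  # hypothetical class wrapping __call__
print(data.edge_attr)     # relative offsets rescaled into [0, 1]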
Example 3: softmax1
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def softmax1(src, index, num_nodes=None):
    r"""Computes a sparsely evaluated softmax.

    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    num_nodes = maybe_num_nodes(index, num_nodes)

    # Subtract the per-group maximum for numerical stability.
    out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index]
    out = out.exp()
    out = out / (
        scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16)

    return out
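A quick sanity check for softmax1 (a sketch; maybe_num_nodes and scatter_add are assumed to come from torch_geometric.utils and torch_scatter):

import torch

src = torch.tensor([0.5, 1.5, -0.5, 2.0])
index = torch.tensor([0, 0, 1, 1])
out = softmax1(src, index)
print(out[:2].sum(), out[2:].sum())  # each group sums to ~1.0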
Example 4: softmax2
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def softmax2(src, index, num_nodes=None):
    r"""Computes a sparsely evaluated softmax.

    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    num_nodes = maybe_num_nodes(index, num_nodes)

    # Unlike softmax1, the per-group max subtraction is commented out here,
    # so this variant is not numerically stabilized.
    out = src  # - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index]
    out = out.exp()
    out = out / (
        scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16)

    return out
Example 5: scatter_softmax
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                    eps: float = 1e-12) -> torch.Tensor:
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    index = broadcast(index, src, dim)

    # Recenter by the per-group maximum for numerical stability.
    max_value_per_index = scatter_max(src, index, dim=dim)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = src - max_per_src_element
    recentered_scores_exp = recentered_scores.exp()

    sum_per_index = scatter_sum(recentered_scores_exp, index, dim)
    normalizing_constants = sum_per_index.add_(eps).gather(dim, index)

    return recentered_scores_exp.div(normalizing_constants)
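A sketch of calling scatter_softmax on a 2-D tensor (broadcast and scatter_sum are assumed to be the helpers shipped with torch_scatter):

import torch

src = torch.randn(6, 4)
index = torch.tensor([0, 0, 1, 1, 2, 2])
probs = scatter_softmax(src, index, dim=0)
print(probs[:2].sum(dim=0))  # ~tensor([1., 1., 1., 1.]) within each group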
Example 6: scatter_log_softmax
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                        eps: float = 1e-12) -> torch.Tensor:
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    # Recenter by the per-group maximum for numerical stability.
    max_value_per_index = scatter_max(src, index, dim=dim)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = src - max_per_src_element

    sum_per_index = scatter_sum(recentered_scores.exp(), index, dim)
    normalizing_constants = sum_per_index.add_(eps).log_().gather(dim, index)

    return recentered_scores.sub_(normalizing_constants)
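Under the same assumptions, scatter_log_softmax should agree with the logarithm of scatter_softmax up to the eps terms:

import torch

src = torch.randn(6)
index = torch.tensor([0, 0, 1, 1, 2, 2])
log_probs = scatter_log_softmax(src, index, dim=0)
print(torch.allclose(log_probs.exp(),
                     scatter_softmax(src, index, dim=0), atol=1e-5))  # True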
Example 7: topk
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def topk(x, ratio, batch, min_score=None, tol=1e-7):
    if min_score is not None:
        # Make sure that we do not drop all nodes in a graph.
        scores_max = scatter_max(x, batch)[0][batch] - tol
        scores_min = scores_max.clamp(max=min_score)

        perm = torch.nonzero(x > scores_min).view(-1)
    else:
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()

        cum_num_nodes = torch.cat(
            [num_nodes.new_zeros(1),
             num_nodes.cumsum(dim=0)[:-1]], dim=0)

        # Scatter node scores into a dense [batch_size, max_num_nodes] grid,
        # padding missing slots with the dtype's minimum, then sort per row.
        index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
        index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

        dense_x = x.new_full((batch_size * max_num_nodes, ),
                             torch.finfo(x.dtype).min)
        dense_x[index] = x
        dense_x = dense_x.view(batch_size, max_num_nodes)

        _, perm = dense_x.sort(dim=-1, descending=True)

        perm = perm + cum_num_nodes.view(-1, 1)
        perm = perm.view(-1)

        # Keep the top ceil(ratio * num_nodes) entries per graph.
        k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
        mask = [
            torch.arange(k[i], dtype=torch.long, device=x.device) +
            i * max_num_nodes for i in range(batch_size)
        ]
        mask = torch.cat(mask, dim=0)

        perm = perm[mask]

    return perm
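A sketch of selecting the top-50% scoring nodes per graph with topk (scatter_add assumed imported from torch_scatter):

import torch

x = torch.tensor([0.1, 0.9, 0.4, 0.8, 0.2, 0.7])
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs of 3 nodes each
perm = topk(x, ratio=0.5, batch=batch)
print(perm)  # tensor([1, 2, 3, 5]): the 2 best-scoring nodes of each graph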
Example 8: scatter_max
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    # Wrap the original scatter_max (aliased here as orig_smax) and keep
    # only the values, dropping the argmax indices.
    return orig_smax(src, index, dim, out, dim_size, fill_value)[0]
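Note the fill_value argument: this wrapper mirrors the torch_scatter 1.x signature, where scatter_max accepted a fill value for empty output slots. Assuming orig_smax is aliased to that 1.x function, a call might look like:

import torch

src = torch.tensor([1.0, 3.0, 2.0])
index = torch.tensor([0, 0, 1])
print(scatter_max(src, index))  # tensor([3., 2.]), argmax already dropped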
Example 9: decode
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def decode(self, y, ext_x, prev_states, prev_context, encoder_features, encoder_mask):
    # Forward one LSTM step.
    # y : [b]
    embedded = self.embedding(y.unsqueeze(1))
    lstm_inputs = self.reduce_layer(torch.cat([embedded, prev_context], 2))
    output, states = self.lstm(lstm_inputs, prev_states)

    context, energy = self.attention(output,
                                     encoder_features,
                                     encoder_mask)
    concat_input = torch.cat((output, context), 2).squeeze(1)
    logit_input = torch.tanh(self.concat_layer(concat_input))
    logit = self.logit_layer(logit_input)  # [b, |V|]

    if config.use_pointer:
        batch_size = y.size(0)
        num_oov = max(torch.max(ext_x - self.vocab_size + 1), 0)
        zeros = torch.zeros((batch_size, num_oov), device=config.device)
        extended_logit = torch.cat([logit, zeros], dim=1)
        out = torch.zeros_like(extended_logit) - INF
        out, _ = scatter_max(energy, ext_x, out=out)
        out = out.masked_fill(out == -INF, 0)
        logit = extended_logit + out
        logit = logit.masked_fill(logit == -INF, 0)

    # Force the probability of UNK to zero.
    logit[:, UNK_ID] = -INF

    return logit, states, context
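A toy illustration of the maxout-pointer step used above (and again in Example 11): scatter_max projects attention energies onto extended-vocabulary ids, keeping only the maximum energy among repeated source tokens. The out= keyword is part of the real torch_scatter API; the tensors here are made up:

import torch
from torch_scatter import scatter_max

energy = torch.tensor([[0.2, 0.7, 0.4]])  # attention over 3 source tokens
ext_x = torch.tensor([[5, 2, 5]])         # token id 5 appears twice
out = torch.full((1, 8), float('-inf'))
out, _ = scatter_max(energy, ext_x, out=out)
print(out[0, 5])  # tensor(0.4): the larger of the two energies for id 5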
Example 10: readout
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def readout(x, batch):
    # Graph-level readout: concatenate mean- and max-pooled node features.
    x_mean = scatter_mean(x, batch, dim=0)
    x_max, _ = scatter_max(x, batch, dim=0)
    return torch.cat((x_mean, x_max), dim=-1)
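A sketch of the mean/max readout over a batch of two graphs (scatter_mean assumed imported from torch_scatter):

import torch

x = torch.randn(5, 8)                  # 5 nodes with 8 features each
batch = torch.tensor([0, 0, 0, 1, 1])  # graph assignment per node
graph_emb = readout(x, batch)
print(graph_emb.shape)  # torch.Size([2, 16]): [mean | max] per graph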
Example 11: forward
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def forward(self, trg_seq, ext_src_seq, init_states, encoder_outputs, encoder_mask):
    # trg_seq : [b, t]
    # init_states : a tuple of [2, b, d]
    # encoder_outputs : [b, t, d]
    device = trg_seq.device
    batch_size, max_len = trg_seq.size()

    hidden_size = encoder_outputs.size(-1)
    memories = self.get_encoder_features(encoder_outputs)
    logits = []

    # Init decoder hidden states and context vector.
    prev_states = init_states
    prev_context = torch.zeros((batch_size, 1, hidden_size))
    prev_context = prev_context.to(device)
    for i in range(max_len):
        y_i = trg_seq[:, i].unsqueeze(1)  # [b, 1]
        embedded = self.embedding(y_i)  # [b, 1, d]
        lstm_inputs = self.reduce_layer(
            torch.cat([embedded, prev_context], 2))
        output, states = self.lstm(lstm_inputs, prev_states)
        # Encoder-decoder attention.
        context, energy = self.attention(output, memories, encoder_mask)
        concat_input = torch.cat((output, context), dim=2).squeeze(dim=1)
        logit_input = torch.tanh(self.concat_layer(concat_input))
        logit = self.logit_layer(logit_input)  # [b, |V|]

        # Maxout pointer network.
        if config.use_pointer:
            num_oov = max(torch.max(ext_src_seq - self.vocab_size + 1), 0)
            zeros = torch.zeros((batch_size, num_oov),
                                device=config.device)
            extended_logit = torch.cat([logit, zeros], dim=1)
            out = torch.zeros_like(extended_logit) - INF
            out, _ = scatter_max(energy, ext_src_seq, out=out)
            out = out.masked_fill(out == -INF, 0)
            logit = extended_logit + out
            logit = logit.masked_fill(logit == 0, -INF)

        logits.append(logit)
        # Update prev state and context.
        prev_states = states
        prev_context = context

    logits = torch.stack(logits, dim=1)  # [b, t, |V|]

    return logits
Example 12: forward
# Required import: import torch_scatter [as alias]
# Or: from torch_scatter import scatter_max [as alias]
def forward(self, x, edge_index, edge_weight=None, batch=None):
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    # NxF
    x = x.unsqueeze(-1) if x.dim() == 1 else x

    # Add self-loops.
    fill_value = 1
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index=edge_index, edge_weight=edge_weight,
        fill_value=fill_value, num_nodes=num_nodes.sum())

    N = x.size(0)  # total number of nodes in the batch

    # ExF
    x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index,
                                    edge_weight=edge_weight)
    x_pool_j = x_pool[edge_index[1]]
    x_j = x[edge_index[1]]

    # ---Master query formation---
    # NxF
    X_q, _ = scatter_max(x_pool_j, edge_index[0], dim=0)
    # NxF
    M_q = self.lin_q(X_q)
    # ExF
    M_q = M_q[edge_index[0].tolist()]

    score = self.gat_att(torch.cat((M_q, x_pool_j), dim=-1))
    score = F.leaky_relu(score, self.negative_slope)
    score = softmax(score, edge_index[0], num_nodes=num_nodes.sum())

    # Sample attention coefficients stochastically.
    score = F.dropout(score, p=self.dropout_att, training=self.training)

    # ExF
    v_j = x_j * score.view(-1, 1)

    # ---Aggregation---
    # NxF
    out = scatter_add(v_j, edge_index[0], dim=0)

    # ---Cluster selection---
    # Nx1
    fitness = torch.sigmoid(self.gnn_score(x=out, edge_index=edge_index)).view(-1)
    perm = topk(x=fitness, ratio=self.ratio, batch=batch)
    x = out[perm] * fitness[perm].view(-1, 1)

    # ---Maintaining graph connectivity---
    batch = batch[perm]
    edge_index, edge_weight = graph_connectivity(
        device=x.device,
        perm=perm,
        edge_index=edge_index,
        edge_weight=edge_weight,
        score=score,
        ratio=self.ratio,
        batch=batch,
        N=N)

    return x, edge_index, edge_weight, batch, perm
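A toy sketch of the "master query" step in isolation: scatter_max over edge_index[0] takes, for every source node, the maximum of its neighbors' pooled features:

import torch
from torch_scatter import scatter_max

edge_index = torch.tensor([[0, 0, 1],
                           [1, 2, 2]])          # source/target pairs
x_pool_j = torch.tensor([[1.0], [3.0], [2.0]])  # one feature per edge
X_q, _ = scatter_max(x_pool_j, edge_index[0], dim=0)
print(X_q)  # tensor([[3.], [2.]]): max neighbor feature per source node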