This article collects typical usage examples of Python's torch.einsum method. If you have been wondering what torch.einsum does, how to call it, or what real-world uses look like, the curated code samples below may help. You can also explore further usage examples from the torch module where this method lives.
The following presents 15 code examples of torch.einsum, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
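Before the examples, a quick reminder of how the einsum index notation works may be useful. The snippet below is a minimal, self-contained sketch (tensor names and sizes are illustrative only, not taken from any example):

import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)

# Matrix multiplication: the shared index j is summed out
c = torch.einsum('ij,jk->ik', a, b)          # same as a @ b

# Trace: repeating an index within one operand walks its diagonal
t = torch.einsum('ii->', torch.randn(4, 4))  # same as torch.trace

# Outer product: no shared indices, all indices kept in the output
o = torch.einsum('i,j->ij', torch.randn(3), torch.randn(5))

With this in mind, the 15 examples below mostly fall into a few patterns: batched matrix products, row-wise reductions, and multi-operand bilinear forms.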
Example 1: forward
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def forward(self, x, rbf, sbf, idx_kj, idx_ji):
    rbf = self.lin_rbf(rbf)
    sbf = self.lin_sbf(sbf)

    x_ji = self.act(self.lin_ji(x))
    x_kj = self.act(self.lin_kj(x))
    x_kj = x_kj * rbf
    x_kj = torch.einsum('wj,wl,ijl->wi', sbf, x_kj[idx_kj], self.W)
    x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0))

    h = x_ji + x_kj
    for layer in self.layers_before_skip:
        h = layer(h)
    h = self.act(self.lin(h)) + x
    for layer in self.layers_after_skip:
        h = layer(h)
    return h
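The einsum in this block computes a bilinear form per triplet w: the weight tensor self.W mixes the spherical-basis features sbf[w] with the gathered messages x_kj[idx_kj][w]. A hedged shape-level sketch (dimension sizes are made up for illustration):

import torch

num_triplets, num_sph, hidden = 7, 6, 8
sbf = torch.randn(num_triplets, num_sph)   # 'wj'
msg = torch.randn(num_triplets, hidden)    # 'wl'
W = torch.randn(hidden, num_sph, hidden)   # 'ijl'

res = torch.einsum('wj,wl,ijl->wi', sbf, msg, W)
# Row-by-row reference: contract W with sbf[w] and msg[w]
ref = torch.stack([(W * sbf[w][None, :, None] * msg[w][None, None, :]).sum((1, 2))
                   for w in range(num_triplets)])
assert torch.allclose(res, ref, atol=1e-5)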
Example 2: _get_overlaps_tensor
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def _get_overlaps_tensor(self, L):
    """Transforms the input label matrix to a three-way overlaps tensor.

    Args:
        L: (np.array) An n x m array of LF output labels, in {0,...,k} if
            self.abstains, else in {1,...,k}, generated by m conditionally
            independent LFs on n data points.

    Outputs:
        O: (torch.Tensor) A (m, m, m, k, k, k) tensor of the label-specific
            empirical overlap rates; that is,
            O[i,j,k,y1,y2,y3] = P(\lambda_i = y1, \lambda_j = y2, \lambda_k = y3),
            where this quantity is computed empirically by this function,
            based on the label matrix L.
    """
    n, m = L.shape

    # Convert from an (n, m) matrix of ints to a (k_lf, n, m) indicator tensor
    LY = np.array([np.where(L == y, 1, 0) for y in range(self.k_0, self.k + 1)])

    # Form the three-way overlaps tensor
    O = np.einsum("abc,dbe,fbg->cegadf", LY, LY, LY) / n
    return torch.from_numpy(O).float()
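To see what 'abc,dbe,fbg->cegadf' produces, here is a small sketch (assuming k = 2 labels and no abstains) that checks one entry of O against a direct empirical count:

import numpy as np

n, m, k = 100, 4, 2
L = np.random.randint(1, k + 1, size=(n, m))
LY = np.array([np.where(L == y, 1, 0) for y in range(1, k + 1)])  # (k, n, m)
O = np.einsum("abc,dbe,fbg->cegadf", LY, LY, LY) / n

# O[i,j,l,y1-1,y2-1,y3-1] should equal the empirical P(L_i=y1, L_j=y2, L_l=y3)
i, j, l, y1, y2, y3 = 0, 1, 2, 1, 2, 1
p = np.mean((L[:, i] == y1) & (L[:, j] == y2) & (L[:, l] == y3))
assert np.isclose(O[i, j, l, y1 - 1, y2 - 1, y3 - 1], p)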
Example 3: bdd_message_func
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def bdd_message_func(self, edges):
    """Message function for the block-diagonal-decomposition regularizer."""
    if edges.src['h'].dtype == th.int64 and len(edges.src['h'].shape) == 1:
        raise TypeError('Block decomposition does not allow integer ID feature.')

    # Calculate msg @ W_r before putting msg onto the edges
    if self.low_mem:
        etypes = th.unique(edges.data['type'])
        msg = th.empty((edges.src['h'].shape[0], self.out_feat),
                       device=edges.src['h'].device)
        for etype in etypes:
            loc = edges.data['type'] == etype
            w = self.weight[etype].view(self.num_bases, self.submat_in, self.submat_out)
            src = edges.src['h'][loc].view(-1, self.num_bases, self.submat_in)
            sub_msg = th.einsum('abc,bcd->abd', src, w)
            sub_msg = sub_msg.reshape(-1, self.out_feat)
            msg[loc] = sub_msg
    else:
        weight = self.weight.index_select(0, edges.data['type']).view(
            -1, self.submat_in, self.submat_out)
        node = edges.src['h'].view(-1, 1, self.submat_in)
        msg = th.bmm(node, weight).view(-1, self.out_feat)
    if 'norm' in edges.data:
        msg = msg * edges.data['norm']
    return {'msg': msg}
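The 'abc,bcd->abd' contraction applies, for every node a, the b-th block-diagonal weight (submat_in x submat_out) to the b-th slice of its features. A small equivalence check against broadcasted matmul (sizes are arbitrary):

import torch

n, blocks, cin, cout = 5, 3, 4, 2
src = torch.randn(n, blocks, cin)
w = torch.randn(blocks, cin, cout)

out = torch.einsum('abc,bcd->abd', src, w)
ref = (src.unsqueeze(2) @ w).squeeze(2)  # (n, blocks, 1, cin) @ (blocks, cin, cout)
assert torch.allclose(out, ref, atol=1e-5)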
Example 4: msg_func
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def msg_func(edges):
    """Send messages along edges.

    Parameters
    ----------
    edges : EdgeBatch
        A batch of edges.

    Returns
    -------
    dict mapping 'm' to Float32 tensor of shape (E, K * T)
        Messages computed. E for the number of edges, K for the number of
        radial filters and T for the number of features to use
        (types of atomic number in the paper).
    """
    return {'m': th.einsum(
        'ij,ik->ijk', edges.src['hv'], edges.data['he']).view(len(edges), -1)}
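'ij,ik->ijk' forms, for each edge i, the outer product of the node feature vector and the edge feature vector. The same result can be written with broadcasting (generic sizes for illustration):

import torch

E, d1, d2 = 6, 3, 4
a = torch.randn(E, d1)
b = torch.randn(E, d2)

m = torch.einsum('ij,ik->ijk', a, b)   # (E, d1, d2): per-row outer product
ref = a.unsqueeze(2) * b.unsqueeze(1)
assert torch.allclose(m, ref)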
Example 5: forward
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def forward(ctx, input, kernel, kernel_size, stride=1, padding=0, dilation=1):
    (bs, ch), in_sz = input.shape[:2], input.shape[2:]
    if kernel.size(1) > 1 and kernel.size(1) != ch:
        raise ValueError('Incompatible input and kernel sizes.')
    ctx.input_size = in_sz
    ctx.kernel_size = _pair(kernel_size)
    ctx.kernel_ch = kernel.size(1)
    ctx.dilation = _pair(dilation)
    ctx.padding = _pair(padding)
    ctx.stride = _pair(stride)
    ctx.save_for_backward(input if ctx.needs_input_grad[1] else None,
                          kernel if ctx.needs_input_grad[0] else None)
    ctx._backend = type2backend[input.type()]

    cols = F.unfold(input, ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride)
    output = cols.view(bs, ch, *kernel.shape[2:]) * kernel
    output = torch.einsum('ijklmn->ijmn', (output,))
    return output.clone()  # TODO check whether a .clone() is needed here
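Since this einsum has a single operand, nothing is multiplied: 'ijklmn->ijmn' simply sums out the two kernel axes k and l, which is equivalent to a plain .sum over those dimensions:

import torch

x = torch.randn(2, 3, 5, 5, 4, 4)   # (batch, ch, kH, kW, outH, outW)
assert torch.allclose(torch.einsum('ijklmn->ijmn', x), x.sum(dim=(2, 3)))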
Example 6: pacconv2d
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def pacconv2d(input, kernel, weight, bias=None, stride=1, padding=0, dilation=1,
              shared_filters=False, native_impl=False):
    kernel_size = tuple(weight.shape[-2:])
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)

    if native_impl:
        # im2col on input
        im_cols = nd2col(input, kernel_size, stride=stride, padding=padding, dilation=dilation)

        # main computation
        if shared_filters:
            output = torch.einsum('ijklmn,zykl->ijmn', (im_cols * kernel, weight))
        else:
            output = torch.einsum('ijklmn,ojkl->iomn', (im_cols * kernel, weight))
        if bias is not None:
            output += bias.view(1, -1, 1, 1)
    else:
        output = PacConv2dFn.apply(input, kernel, weight, bias, stride, padding,
                                   dilation, shared_filters)
    return output
Example 7: forward
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def forward(self, x, y):
    """
    The computation logic of MatchingTensor.

    :param x: the left input tensor, shape (b, l, d).
    :param y: the right input tensor, shape (b, r, e).
    """
    if self._normalize:
        x = F.normalize(x, p=2, dim=-1)
        y = F.normalize(y, p=2, dim=-1)

    # output = [b, c, l, r]
    output = torch.einsum(
        'bld,cde,bre->bclr',
        x, self.interaction_matrix, y
    )
    return output
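'bld,cde,bre->bclr' is a batch of bilinear matches: for each interaction channel c, entry (l, r) of the output is x_l^T M_c y_r. A shape sketch checking one entry against the explicit bilinear form (sizes are arbitrary):

import torch

b, l, r, d, e, c = 2, 5, 7, 4, 3, 6
x = torch.randn(b, l, d)
M = torch.randn(c, d, e)   # stands in for self.interaction_matrix
y = torch.randn(b, r, e)

out = torch.einsum('bld,cde,bre->bclr', x, M, y)   # (b, c, l, r)
bi, ci, li, ri = 0, 2, 1, 3
ref = x[bi, li] @ M[ci] @ y[bi, ri]
assert torch.allclose(out[bi, ci, li, ri], ref, atol=1e-5)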
Example 8: kfac_mat_prod
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def kfac_mat_prod(factors):
    """Return the function v ↦ (A ⊗ B ⊗ ...)v for `factors = [A, B, ...]`."""
    assert all_tensors_of_order(order=2, tensors=factors)

    shapes = [list(f.size()) for f in factors]
    _, col_dims = zip(*shapes)
    num_factors = len(shapes)
    equation = kfac_mat_prod_einsum_equation(num_factors)

    @kfacmp_unsqueeze_if_missing_dim(mat_dim=2)
    def kfacmp(mat):
        assert is_matrix(mat)
        _, mat_cols = mat.shape
        mat_reshaped = mat.view(*(col_dims), mat_cols)
        return einsum(equation, mat_reshaped, *factors).contiguous().view(-1, mat_cols)

    return kfacmp
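The equation builder and decorators above are BackPACK utilities; for intuition, here is a hedged two-factor sketch of the same trick, computing (A ⊗ B) @ mat without materializing the Kronecker product (torch.kron is used only for the check):

import torch

A = torch.randn(3, 4)          # (rows_A, cols_A)
B = torch.randn(5, 2)          # (rows_B, cols_B)
mat = torch.randn(4 * 2, 7)    # (cols_A * cols_B, mat_cols)

mat_reshaped = mat.view(4, 2, 7)
out = torch.einsum('ax,by,xyk->abk', A, B, mat_reshaped).reshape(-1, 7)

assert torch.allclose(out, torch.kron(A, B) @ mat, atol=1e-4)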
Example 9: _sqrt_hessian_sampled
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
    self._check_2nd_order_parameters(module)

    M = mc_samples
    C = module.input0.shape[1]

    probs = self._get_probs(module)
    V_dim = 0
    probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1)

    multi = multinomial(probs, M, replacement=True)
    classes = one_hot(multi, num_classes=C)
    classes = einsum("nvc->vnc", classes).float()

    sqrt_mc_h = (probs_unsqueezed - classes) / sqrt(M)

    if module.reduction == "mean":
        N = module.input0.shape[0]
        sqrt_mc_h /= sqrt(N)

    return sqrt_mc_h
Example 10: _make_hessian_mat_prod
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def _make_hessian_mat_prod(self, module, g_inp, g_out):
    """Multiplication of the input Hessian with a matrix."""
    self._check_2nd_order_parameters(module)

    probs = self._get_probs(module)

    def hessian_mat_prod(mat):
        Hmat = einsum("bi,cbi->cbi", (probs, mat)) - einsum(
            "bi,bj,cbj->cbi", (probs, probs, mat)
        )
        if module.reduction == "mean":
            N = module.input0.shape[0]
            Hmat /= N
        return Hmat

    return hessian_mat_prod
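For softmax cross-entropy, the Hessian of the loss w.r.t. the logits of one sample is diag(p) - p p^T; the two einsums apply exactly these two terms to every column of mat without ever building the Hessian. A per-sample check (assuming probs of shape (N, C) and mat of shape (V, N, C)):

import torch

N, C, V = 4, 3, 2
probs = torch.softmax(torch.randn(N, C), dim=1)
mat = torch.randn(V, N, C)

Hmat = torch.einsum('bi,cbi->cbi', probs, mat) - torch.einsum(
    'bi,bj,cbj->cbi', probs, probs, mat)

b, v = 1, 0
H = torch.diag(probs[b]) - torch.outer(probs[b], probs[b])
assert torch.allclose(Hmat[v, b], H @ mat[v, b], atol=1e-5)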
Example 11: projx
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def projx(self, x: torch.Tensor) -> torch.Tensor:
    U, _, V = linalg.batch_linalg.svd(x)
    return torch.einsum("...ik,...jk->...ij", U, V)
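'...ik,...jk->...ij' is a batched U @ V^T, i.e. the polar factor that projects x onto the manifold of orthonormal frames. An equivalence and orthogonality check, with geoopt's batched SVD replaced by torch.linalg.svd for this sketch:

import torch

x = torch.randn(4, 5, 3)   # batch of tall matrices
U, _, Vh = torch.linalg.svd(x, full_matrices=False)
V = Vh.transpose(-1, -2)

q = torch.einsum('...ik,...jk->...ij', U, V)
assert torch.allclose(q, U @ Vh, atol=1e-5)

# The projection has orthonormal columns: q^T q = I
eye = torch.eye(3).expand(4, 3, 3)
assert torch.allclose(q.transpose(-1, -2) @ q, eye, atol=1e-4)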
Example 12: forward
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def forward(cls, ctx, X, target, alpha, proj_args):
    """
    X (FloatTensor): n x num_classes
    target (LongTensor): n, the indices of the target classes
    """
    assert X.shape[0] == target.shape[0]

    p_star = cls.project(X, alpha, **proj_args)
    loss = cls.omega(p_star, alpha)

    p_star.scatter_add_(1, target.unsqueeze(1), torch.full_like(p_star, -1))
    loss += torch.einsum("ij,ij->i", p_star, X)

    ctx.save_for_backward(p_star)
    return loss
Example 13: _rank3_trace
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def _rank3_trace(x):
    return torch.einsum('ijj->i', x)
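Repeating an index within one operand ('jj') extracts the diagonal, so 'ijj->i' is a batched trace, equivalent to torch.diagonal plus a sum:

import torch

x = torch.randn(5, 3, 3)
t = torch.einsum('ijj->i', x)
assert torch.allclose(t, torch.diagonal(x, dim1=-2, dim2=-1).sum(-1))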
Example 14: forward
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def forward(self, pos_edge_index, neg_edge_index):
    # `data` is a module-level dataset object in the original example script
    x = F.relu(self.conv1(data.x, data.train_pos_edge_index))
    x = self.conv2(x, data.train_pos_edge_index)

    total_edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
    x_j = torch.index_select(x, 0, total_edge_index[0])
    x_i = torch.index_select(x, 0, total_edge_index[1])
    return torch.einsum("ef,ef->e", x_i, x_j)
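'ef,ef->e' (like 'ij,ij->i' in Example 12) is a row-wise dot product: each candidate edge is scored by the inner product of its two endpoint embeddings, the same as an elementwise product followed by a sum:

import torch

x_i = torch.randn(10, 16)
x_j = torch.randn(10, 16)
s = torch.einsum('ef,ef->e', x_i, x_j)
assert torch.allclose(s, (x_i * x_j).sum(dim=-1))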
Example 15: relative_logits_1d
# Required import: import torch [as alias]
# Or: from torch import einsum [as alias]
def relative_logits_1d(self, q, rel_k, H, W, Nh, case):
    rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
    rel_logits = torch.reshape(rel_logits, (-1, Nh * H, W, 2 * W - 1))
    rel_logits = self.rel_to_abs(rel_logits)

    rel_logits = torch.reshape(rel_logits, (-1, Nh, H, W, W))
    rel_logits = torch.unsqueeze(rel_logits, dim=3)
    rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1))

    if case == "w":
        rel_logits = torch.transpose(rel_logits, 3, 4)
    elif case == "h":
        rel_logits = torch.transpose(rel_logits, 2, 4).transpose(4, 5).transpose(3, 5)
    rel_logits = torch.reshape(rel_logits, (-1, Nh, H * W, H * W))
    return rel_logits
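The opening einsum 'bhxyd,md->bhxym' contracts the query depth d against the relative-position embeddings; it is just a matmul with rel_k transposed, applied at every (b, h, x, y) position:

import torch

B, Nh, H, W, d = 2, 4, 5, 5, 8
q = torch.randn(B, Nh, H, W, d)
rel_k = torch.randn(2 * W - 1, d)

logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
assert torch.allclose(logits, q @ rel_k.t(), atol=1e-5)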