當前位置: 首頁>>代碼示例>>Python>>正文


Python torch.einsum方法代碼示例

本文整理匯總了Python中torch.einsum方法的典型用法代碼示例。如果您正苦於以下問題:Python torch.einsum方法的具體用法?Python torch.einsum怎麽用?Python torch.einsum使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch的用法示例。


在下文中一共展示了torch.einsum方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def forward(self, x, rbf, sbf, idx_kj, idx_ji):
        """Interaction-block forward pass (listed as coming from dimenet.py).

        Two message paths are built from ``x``:

        * a direct path ``act(lin_ji(x))``;
        * a gathered path ``act(lin_kj(x)) * lin_rbf(rbf)``. Its rows are picked
          with ``idx_kj``, mixed with the projected ``sbf`` through the bilinear
          weight ``self.W`` and scattered back onto rows ``idx_ji``.

        Their sum then goes through the pre-skip layers, a residual connection
        to ``x`` and the post-skip layers.

        Args:
            x: input embeddings; output keeps ``x.size(0)`` rows (residual add).
            rbf: radial basis input, projected by ``lin_rbf`` and multiplied
                elementwise into the ``lin_kj`` path.
            sbf: spherical basis input, projected by ``lin_sbf``; paired row by
                row (einsum axis ``w``) with ``x_kj[idx_kj]``.
            idx_kj: gather indices into the rows of the ``lin_kj`` path.
            idx_ji: scatter targets, one per gathered row. The reduction is
                whatever ``scatter`` defaults to (presumably sum — confirm).

        Returns:
            Tensor with the same shape as ``x``.
        """
        rbf_proj = self.lin_rbf(rbf)
        sbf_proj = self.lin_sbf(sbf)

        direct = self.act(self.lin_ji(x))
        gathered = self.act(self.lin_kj(x)) * rbf_proj
        # out[w, i] = sum_{j,l} sbf[w, j] * gathered[idx_kj][w, l] * W[i, j, l]
        gathered = torch.einsum('wj,wl,ijl->wi', sbf_proj, gathered[idx_kj], self.W)
        gathered = scatter(gathered, idx_ji, dim=0, dim_size=x.size(0))

        out = direct + gathered
        for block in self.layers_before_skip:
            out = block(out)
        # Residual connection back to the block input.
        out = self.act(self.lin(out)) + x
        for block in self.layers_after_skip:
            out = block(out)

        return out
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:20,代碼來源:dimenet.py

示例2: _get_overlaps_tensor

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def _get_overlaps_tensor(self, L):
        r"""Transforms the input label matrix to a three-way overlaps tensor.

        Args:
            L: (np.array) An n x m array of LF output labels generated by m LFs
                on n data points. Per the original docs, labels are in
                {0,...,k} if self.abstains, else in {1,...,k}. Only values in
                ``range(self.k_0, self.k + 1)`` are counted; any other value
                contributes to no entry.

        Outputs:
            O: (torch.Tensor, float32) A (m, m, m, k_lf, k_lf, k_lf) tensor of
            the label-specific empirical overlap rates. Here
            ``k_lf = self.k - self.k_0 + 1`` and label index ``a`` stands for
            the label value ``self.k_0 + a``. That is,

                O[i,j,k,y1,y2,y3] = P(\lf_i = y1, \lf_j = y2, \lf_k = y3)

            where this quantity is computed empirically by this function as
            the fraction of the n rows of L in which all three events hold.
        """
        # The unpack also enforces that L is 2-D; m itself is not needed.
        n, _ = L.shape

        # One-hot encode: LY[a, b, c] = 1 iff L[b, c] == self.k_0 + a,
        # i.e. a (k_lf, n, m) indicator tensor.
        LY = np.array([np.where(L == y, 1, 0) for y in range(self.k_0, self.k + 1)])

        # O[c, e, g, a, d, f] = (1/n) * sum_b LY[a,b,c] * LY[d,b,e] * LY[f,b,g]:
        # the LF axes come first, then the matching label axes.
        O = np.einsum("abc,dbe,fbg->cegadf", LY, LY, LY) / n
        return torch.from_numpy(O).float()
開發者ID:HazyResearch,項目名稱:metal,代碼行數:27,代碼來源:class_balance.py

示例3: bdd_message_func

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def bdd_message_func(self, edges):
        """Message function for block-diagonal-decomposition regularizer.

        Each relation's weight row is viewed as ``num_bases`` blocks of shape
        (submat_in, submat_out). The source feature is split into the same
        number of chunks, and chunk b is multiplied by block b.

        Args:
            edges: edge batch. This reads ``edges.src['h']`` (source features),
                ``edges.data['type']`` (relation id per edge) and, when present,
                ``edges.data['norm']`` (multiplied into the messages).

        Returns:
            dict ``{'msg': Tensor}`` of shape (num_edges, out_feat). The dtype is
            the same in both code paths.

        Raises:
            TypeError: if the source features are a 1-D int64 tensor (integer
                node IDs), which block decomposition cannot consume.
        """
        h_src = edges.src['h']
        if h_src.dtype == th.int64 and len(h_src.shape) == 1:
            raise TypeError('Block decomposition does not allow integer ID feature.')

        # calculate msg @ W_r before put msg into edge
        if self.low_mem:
            # Handle one relation type at a time instead of gathering a weight
            # copy per edge as the branch below does.
            etypes = th.unique(edges.data['type'])
            # Allocate in the feature dtype. A bare th.empty defaults to float32
            # and would silently downcast e.g. float64 messages, making this path
            # disagree with the bmm path below.
            msg = th.empty((h_src.shape[0], self.out_feat),
                           dtype=h_src.dtype, device=h_src.device)
            for etype in etypes:
                loc = edges.data['type'] == etype
                w = self.weight[etype].view(self.num_bases, self.submat_in, self.submat_out)
                src = h_src[loc].view(-1, self.num_bases, self.submat_in)
                # Chunk b of each source row times block b of the relation weight.
                sub_msg = th.einsum('abc,bcd->abd', src, w)
                sub_msg = sub_msg.reshape(-1, self.out_feat)
                msg[loc] = sub_msg
        else:
            weight = self.weight.index_select(0, edges.data['type']).view(
                -1, self.submat_in, self.submat_out)
            node = h_src.view(-1, 1, self.submat_in)
            msg = th.bmm(node, weight).view(-1, self.out_feat)
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}
開發者ID:dmlc,項目名稱:dgl,代碼行數:27,代碼來源:relgraphconv.py

示例4: msg_func

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def msg_func(edges):
    """Send messages along edges.

    Every message is the flattened outer product of the source node feature
    ``edges.src['hv']`` and the edge feature ``edges.data['he']``.

    Parameters
    ----------
    edges : EdgeBatch
        A batch of edges.

    Returns
    -------
    dict mapping 'm' to a tensor of shape (E, K * T)
        Messages computed. E is the number of edges. Per the einsum,
        ``hv`` has shape (E, K) and ``he`` has shape (E, T); the original
        docs describe K as the number of radial filters and T as the number
        of features (atomic-number types in the paper).
    """
    node_feat = edges.src['hv']
    edge_feat = edges.data['he']
    # outer[e, j, k] = node_feat[e, j] * edge_feat[e, k]
    outer = th.einsum('ij,ik->ijk', node_feat, edge_feat)
    return {'m': outer.view(len(edges), -1)}
開發者ID:dmlc,項目名稱:dgl,代碼行數:19,代碼來源:atomicconv.py

示例5: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def forward(ctx, input, kernel, kernel_size, stride=1, padding=0, dilation=1):
        """Forward pass of the pixel-adaptive filtering autograd function.

        ``input`` is unfolded into sliding patches. Each patch is multiplied
        elementwise by the per-location ``kernel``, and the products are summed
        over the kernel window.

        Args:
            ctx: autograd context; the layer hyper-parameters are stored on it
                and the tensors needed for backward are saved.
            input: batched input whose first two dims are (batch, channels).
            kernel: adaptive kernel. ``kernel.size(1)`` must be 1 or equal the
                input channel count, and ``kernel.shape[2:]`` is the
                (window..., output-spatial...) layout the patches are viewed into.
            kernel_size, stride, padding, dilation: int or pair, normalized
                with ``_pair``.

        Returns:
            Tensor indexed (batch, channel, out_h, out_w), per the einsum output.

        Raises:
            ValueError: if the kernel channel count is incompatible with input.
        """
        bs, ch = input.shape[:2]
        in_sz = input.shape[2:]
        kernel_ch = kernel.size(1)
        if kernel_ch > 1 and kernel_ch != ch:
            raise ValueError('Incompatible input and kernel sizes.')

        ctx.input_size = in_sz
        ctx.kernel_size = _pair(kernel_size)
        ctx.kernel_ch = kernel_ch
        ctx.dilation = _pair(dilation)
        ctx.padding = _pair(padding)
        ctx.stride = _pair(stride)
        # The output is a product of input patches and kernel. The input's
        # gradient needs the kernel and the kernel's gradient needs the input,
        # so each tensor is saved only when the *other* operand needs a gradient.
        ctx.save_for_backward(input if ctx.needs_input_grad[1] else None,
                              kernel if ctx.needs_input_grad[0] else None)
        ctx._backend = type2backend[input.type()]

        patches = F.unfold(input, ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride)
        weighted = patches.view(bs, ch, *kernel.shape[2:]) * kernel
        # Sum over the two kernel-window axes.
        summed = torch.einsum('ijklmn->ijmn', (weighted,))

        return summed.clone()  # TODO check whether a .clone() is needed here
開發者ID:openseg-group,項目名稱:openseg.pytorch,代碼行數:22,代碼來源:pac.py

示例6: pacconv2d

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def pacconv2d(input, kernel, weight, bias=None, stride=1, padding=0, dilation=1, shared_filters=False,
              native_impl=False):
    """Pixel-adaptive 2-D convolution.

    Args:
        input: input feature map.
        kernel: adaptive kernel, multiplied elementwise with the im2col patches.
        weight: filter weights; the last two dims give the spatial kernel size.
        bias: optional per-output-channel bias, broadcast as (1, -1, 1, 1) in
            the native path and forwarded to ``PacConv2dFn`` otherwise.
        stride, padding, dilation: int or pair, normalized with ``_pair``.
        shared_filters: when True, the weight's first two axes are summed out
            ('ijklmn,zykl->ijmn'), so each input channel keeps its own output
            channel. Otherwise a full channel-mixing contraction is used.
        native_impl: when True, compute here with im2col + einsum; otherwise
            delegate to ``PacConv2dFn.apply``.

    Returns:
        The convolution output.
    """
    kernel_size = tuple(weight.shape[-2:])
    stride, padding, dilation = _pair(stride), _pair(padding), _pair(dilation)

    if not native_impl:
        return PacConv2dFn.apply(input, kernel, weight, bias, stride, padding, dilation, shared_filters)

    # im2col on input, then weight each patch by the adaptive kernel.
    im_cols = nd2col(input, kernel_size, stride=stride, padding=padding, dilation=dilation)
    equation = 'ijklmn,zykl->ijmn' if shared_filters else 'ijklmn,ojkl->iomn'
    output = torch.einsum(equation, (im_cols * kernel, weight))

    if bias is not None:
        output += bias.view(1, -1, 1, 1)
    return output
開發者ID:openseg-group,項目名稱:openseg.pytorch,代碼行數:25,代碼來源:pac.py

示例7: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def forward(self, x, y):
        """
        The computation logic of MatchingTensor.

        For every interaction channel c, each left position l is scored
        against each right position r with a bilinear form:
        ``out[b, c, l, r] = sum_{d,e} x[b, l, d] * M[c, d, e] * y[b, r, e]``,
        where ``M = self.interaction_matrix``.

        :param x: left input, indexed (batch, left_len, dim) by the einsum.
        :param y: right input, indexed (batch, right_len, dim') by the einsum.
        :return: tensor of shape [b, c, l, r]. When ``self._normalize`` is set,
            both inputs are first L2-normalized along their last dim.
        """
        left, right = x, y
        if self._normalize:
            left = F.normalize(left, p=2, dim=-1)
            right = F.normalize(right, p=2, dim=-1)

        return torch.einsum(
            'bld,cde,bre->bclr',
            left, self.interaction_matrix, right
        )
開發者ID:NTMC-Community,項目名稱:MatchZoo-py,代碼行數:19,代碼來源:matching_tensor.py

示例8: kfac_mat_prod

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def kfac_mat_prod(factors):
    """Return function v ↦ (A ⊗ B ⊗ ...)v for `factors = [A, B, ...]`.

    Every factor must be a matrix (asserted). The returned ``kfacmp(mat)``
    takes a matrix whose rows span the product of the factors' column dims.
    It reshapes those rows into one axis per factor, contracts with the
    einsum equation from ``kfac_mat_prod_einsum_equation`` and flattens the
    result back to ``(-1, mat.shape[1])``. The decorator presumably also
    accepts a 1-D vector by adding the missing dimension — confirm against
    ``kfacmp_unsqueeze_if_missing_dim``.
    """
    assert all_tensors_of_order(order=2, tensors=factors)

    # One column dimension per factor; the unpack fails for an empty list.
    _, col_dims = zip(*(list(f.size()) for f in factors))
    equation = kfac_mat_prod_einsum_equation(len(factors))

    @kfacmp_unsqueeze_if_missing_dim(mat_dim=2)
    def kfacmp(mat):
        assert is_matrix(mat)
        _, num_cols = mat.shape
        as_tensor = mat.view(*col_dims, num_cols)
        product = einsum(equation, as_tensor, *factors)
        return product.contiguous().view(-1, num_cols)

    return kfacmp
開發者ID:f-dangel,項目名稱:backpack,代碼行數:20,代碼來源:kroneckers.py

示例9: _sqrt_hessian_sampled

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
        """Monte-Carlo sampled factor built from class probabilities.

        For each of ``mc_samples`` draws, a class index is sampled per row
        from ``probs`` (with replacement). The result is
        ``(probs - one_hot(sample)) / sqrt(mc_samples)``, additionally divided
        by ``sqrt(N)`` when ``module.reduction == "mean"``. Per the method name,
        this is a sampled square root of the loss Hessian.

        ``g_inp`` and ``g_out`` are not used here.

        Returns:
            Float tensor of shape (mc_samples, N, C). C is
            ``module.input0.shape[1]`` and N is the number of rows of ``probs``.
        """
        self._check_2nd_order_parameters(module)

        num_classes = module.input0.shape[1]

        probs = self._get_probs(module)
        # Stack one copy of probs per MC sample along a new leading axis.
        probs_per_sample = probs.unsqueeze(0).repeat(mc_samples, 1, 1)

        drawn = multinomial(probs, mc_samples, replacement=True)
        # (N, M, C) one-hot samples -> (M, N, C) to line up with probs_per_sample.
        drawn_one_hot = einsum("nvc->vnc", one_hot(drawn, num_classes=num_classes)).float()

        sqrt_mc_h = (probs_per_sample - drawn_one_hot) / sqrt(mc_samples)

        if module.reduction == "mean":
            sqrt_mc_h /= sqrt(module.input0.shape[0])

        return sqrt_mc_h
開發者ID:f-dangel,項目名稱:backpack,代碼行數:23,代碼來源:crossentropyloss.py

示例10: _make_hessian_mat_prod

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def _make_hessian_mat_prod(self, module, g_inp, g_out):
        """Multiplication of the input Hessian with a matrix.

        Returns a closure ``hessian_mat_prod(mat)`` computing

            Hmat[c, b, i] = p[b, i] * mat[c, b, i]
                            - p[b, i] * sum_j p[b, j] * mat[c, b, j]

        i.e. ``(diag(p_b) - p_b p_b^T)`` applied to every ``mat[c, b]``. Here
        ``p = self._get_probs(module)`` is evaluated once, now. When
        ``module.reduction == "mean"`` the result is divided by
        ``module.input0.shape[0]``; that check and value are read at call time.

        ``g_inp`` and ``g_out`` are not used here.
        """
        self._check_2nd_order_parameters(module)

        probs = self._get_probs(module)

        def hessian_mat_prod(mat):
            diag_part = einsum("bi,cbi->cbi", (probs, mat))
            outer_part = einsum("bi,bj,cbj->cbi", (probs, probs, mat))
            Hmat = diag_part - outer_part

            if module.reduction == "mean":
                Hmat /= module.input0.shape[0]

            return Hmat

        return hessian_mat_prod
開發者ID:f-dangel,項目名稱:backpack,代碼行數:20,代碼來源:crossentropyloss.py

示例11: projx

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def projx(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` by replacing its singular values with ones.

        Returns ``U @ V^T`` from the SVD of ``x`` (batched over leading dims).
        This assumes ``linalg.batch_linalg.svd`` follows the ``torch.svd``
        convention ``x = U diag(S) V^T``, returning ``V`` rather than ``V^T``;
        confirm against geoopt. Under that convention the result has
        orthonormal columns (Stiefel manifold, per the source file name).
        """
        left, _, right = linalg.batch_linalg.svd(x)
        # out[..., i, j] = sum_k left[..., i, k] * right[..., j, k]
        return torch.einsum("...ik,...jk->...ij", left, right)
開發者ID:geoopt,項目名稱:geoopt,代碼行數:5,代碼來源:stiefel.py

示例12: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def forward(cls, ctx, X, target, alpha, proj_args):
        """
        X (FloatTensor): n x num_classes
        target (LongTensor): n, the indices of the target classes

        Computes ``omega(p*) + <p* - onehot(target), X>`` per row, where
        ``p* = cls.project(X, alpha, **proj_args)``. The residual
        ``p* - onehot(target)`` is saved on ``ctx`` for the backward pass.
        ``p*`` is modified in place to become that residual.

        Returns:
            Tensor with one loss value per row of ``X``.
        """
        assert X.shape[0] == target.shape[0]

        p_star = cls.project(X, alpha, **proj_args)
        loss = cls.omega(p_star, alpha)

        # Add -1 at each row's target column: p_star becomes p* - onehot(target).
        minus_ones = torch.full_like(p_star, -1)
        p_star.scatter_add_(1, target.unsqueeze(1), minus_ones)

        # Row-wise dot product <p* - onehot(target), X>.
        loss += torch.einsum("ij,ij->i", p_star, X)
        ctx.save_for_backward(p_star)

        return loss
開發者ID:deep-spin,項目名稱:entmax,代碼行數:17,代碼來源:losses.py

示例13: _rank3_trace

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def _rank3_trace(x):
    """Return the trace of every matrix in a batch.

    ``x`` must be rank 3 with equal trailing dims, which the einsum enforces.
    The result holds one trace per leading index.
    """
    per_matrix_trace = torch.einsum('ijj->i', x)
    return per_matrix_trace
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:4,代碼來源:mincut_pool.py

示例14: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def forward(self, pos_edge_index, neg_edge_index, x=None, edge_index=None):
        """Score candidate edges with a two-layer GNN encoder and a dot-product decoder.

        The original version read its encoder inputs from the module-level
        global ``data``, so it could only ever run on that one graph. ``x`` and
        ``edge_index`` can now be passed explicitly. When they are omitted, the
        global ``data`` is still used, so existing calls behave exactly as before.

        Args:
            pos_edge_index: index tensor of positive candidate edges; row 0
                holds one endpoint and row 1 the other.
            neg_edge_index: index tensor of negative candidate edges, same layout.
            x: node features for the encoder. Defaults to ``data.x``.
            edge_index: message-passing edges for the encoder. Defaults to
                ``data.train_pos_edge_index``.

        Returns:
            1-D tensor with one score per candidate edge (positives first, then
            negatives): the dot product of the two endpoint embeddings.
        """
        if x is None:
            x = data.x
        if edge_index is None:
            edge_index = data.train_pos_edge_index

        z = F.relu(self.conv1(x, edge_index))
        z = self.conv2(z, edge_index)

        total_edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        z_j = torch.index_select(z, 0, total_edge_index[0])
        z_i = torch.index_select(z, 0, total_edge_index[1])
        # Row-wise dot product of the endpoint embeddings.
        return torch.einsum("ef,ef->e", z_i, z_j)
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:11,代碼來源:link_pred.py

示例15: relative_logits_1d

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import einsum [as 別名]
def relative_logits_1d(self, q, rel_k, H, W, Nh, case):
        """Relative-position attention logits along one spatial axis.

        ``q`` is indexed (batch, Nh, H, W, d) and ``rel_k`` (2W-1, d) by the
        einsum. Every query is scored against every relative offset, then
        ``self.rel_to_abs`` converts the (2W-1) relative axis into W absolute
        positions. The result is tiled H times and rearranged:

        * ``case == "w"``: one transpose (axes 3, 4);
        * ``case == "h"``: a chain of transposes;
        * any other value: no transpose.

        It is finally flattened to (-1, Nh, H*W, H*W).
        """
        logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
        logits = self.rel_to_abs(logits.reshape(-1, Nh * H, W, 2 * W - 1))

        # Back to per-head layout, then add and tile a new axis of size H.
        logits = logits.reshape(-1, Nh, H, W, W).unsqueeze(3).repeat(1, 1, 1, H, 1, 1)

        if case == "w":
            logits = logits.transpose(3, 4)
        elif case == "h":
            logits = logits.transpose(2, 4).transpose(4, 5).transpose(3, 5)

        return logits.reshape(-1, Nh, H * W, H * W)
開發者ID:leaderj1001,項目名稱:Attention-Augmented-Conv2d,代碼行數:17,代碼來源:attention_augmented_conv.py


注:本文中的torch.einsum方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。