当前位置: 首页>>代码示例>>Python>>正文


Python torch.baddbmm方法代码示例

本文整理汇总了Python中torch.baddbmm方法的典型用法代码示例。如果您正苦于以下问题：Python torch.baddbmm方法的具体用法？Python torch.baddbmm怎么用？Python torch.baddbmm使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的其他用法示例。


在下文中一共展示了torch.baddbmm方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: backward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def backward(self, grad_output):
    """Gradient of the grid generator with respect to ``self.input1``.

    ``grad_output`` is flattened to (-1, height*width, 2) and the grid stored
    on ``self.batchgrid`` to (-1, height*width, 3).  The result is the
    per-sample product ``grad^T @ grid`` (shape (-1, 2, 3)), accumulated onto
    a zero tensor with the same size/dtype/device as ``self.input1``.
    """
    n_pix = self.height * self.width
    # (-1, 2, n_pix): gradient per output coordinate, pixels on the last axis
    grad_t = grad_output.view(-1, n_pix, 2).transpose(1, 2)
    # (-1, n_pix, 3): grid stored on the module
    grid = self.batchgrid.view(-1, n_pix, 3)
    zeros = self.input1.new_zeros(self.input1.size())
    return torch.baddbmm(zeros, grad_t, grid)
开发者ID:ucbdrive,项目名称:3d-vehicle-tracking,代码行数:19,代码来源:gridgen.py

示例2: VarLSTMCell

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def VarLSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    """Run one step of an LSTM cell with variational dropout masks.

    All four gates are computed together as a batched matmul over a leading
    gate axis of size 4, unpacked in the order input, forget, cell, output.

    Args:
        input: 2-D input for this step; expanded to 4 gate copies, or
            multiplied by ``noise_in`` after ``unsqueeze(0)``.
        hidden: tuple ``(hx, cx)`` with the previous hidden and cell state.
        w_ih: stacked per-gate input-to-hidden weights (batch-multiplied).
        w_hh: stacked per-gate hidden-to-hidden weights (batch-multiplied).
        b_ih, b_hh: optional stacked per-gate biases, broadcast over the batch
            via ``unsqueeze(1)``; ``None`` means no bias term.
        noise_in, noise_hidden: optional dropout masks; ``None`` disables them.

    Returns:
        tuple ``(hy, cy)``: the new hidden state and cell state.
    """
    input = input.expand(4, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = hx.expand(4, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    # The biases default to None, so fall back to a plain bmm when one is
    # missing instead of calling .unsqueeze() on None.
    gi = torch.bmm(input, w_ih) if b_ih is None else torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.bmm(hx, w_hh) if b_hh is None else torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    gates = gi + gh

    ingate, forgetgate, cellgate, outgate = gates

    # torch.sigmoid / torch.tanh: F.sigmoid / F.tanh are deprecated aliases.
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy
开发者ID:thomas0809,项目名称:GraphIE,代码行数:21,代码来源:variational_rnn.py

示例3: SkipConnectLSTMCell

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def SkipConnectLSTMCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    """Run one LSTM step whose recurrent input also includes a skip connection.

    ``hidden_skip`` is concatenated to the previous hidden state along dim 1
    before the hidden-to-hidden matmul.  The four gates (input, forget, cell,
    output) are computed together as a batched matmul over a gate axis of 4.

    Args:
        input: 2-D input for this step; expanded to 4 gate copies, or
            multiplied by ``noise_in`` after ``unsqueeze(0)``.
        hidden: tuple ``(hx, cx)`` with the previous hidden and cell state.
        hidden_skip: extra hidden state concatenated to ``hx`` on dim 1.
        w_ih, w_hh: stacked per-gate weights; ``w_hh`` must match the width
            of the concatenated ``[hx, hidden_skip]``.
        b_ih, b_hh: optional stacked per-gate biases; ``None`` means no bias.
        noise_in, noise_hidden: optional dropout masks; ``None`` disables them.

    Returns:
        tuple ``(hy, cy)``: the new hidden state and cell state.
    """
    input = input.expand(4, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = torch.cat([hx, hidden_skip], dim=1)
    hx = hx.expand(4, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    # The biases default to None, so fall back to a plain bmm when one is
    # missing instead of calling .unsqueeze() on None.
    gi = torch.bmm(input, w_ih) if b_ih is None else torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.bmm(hx, w_hh) if b_hh is None else torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    gates = gi + gh

    ingate, forgetgate, cellgate, outgate = gates

    # torch.sigmoid / torch.tanh: F.sigmoid / F.tanh are deprecated aliases.
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy
开发者ID:thomas0809,项目名称:GraphIE,代码行数:22,代码来源:skipconnect_rnn.py

示例4: SkipConnectGRUCell

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def SkipConnectGRUCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    """Run one GRU step whose recurrent input also includes a skip connection.

    ``hidden_skip`` is concatenated to ``hidden`` along dim 1 for the
    hidden-to-hidden matmul.  The three gates (reset, update/input, new) are
    computed as batched matmuls over a gate axis of size 3.  The interpolation
    in the final step uses ``hidden`` alone.

    Args:
        input: 2-D input for this step; expanded to 3 gate copies, or
            multiplied by ``noise_in`` after ``unsqueeze(0)``.
        hidden: previous hidden state.
        hidden_skip: extra hidden state concatenated to ``hidden`` on dim 1.
        w_ih, w_hh: stacked per-gate weights; ``w_hh`` must match the width
            of the concatenated ``[hidden, hidden_skip]``.
        b_ih, b_hh: optional stacked per-gate biases; ``None`` means no bias.
        noise_in, noise_hidden: optional dropout masks; ``None`` disables them.

    Returns:
        the new hidden state ``hy``.
    """
    input = input.expand(3, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in
    hx = torch.cat([hidden, hidden_skip], dim=1)
    hx = hx.expand(3, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    # The biases default to None, so fall back to a plain bmm when one is
    # missing instead of calling .unsqueeze() on None.
    gi = torch.bmm(input, w_ih) if b_ih is None else torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.bmm(hx, w_hh) if b_hh is None else torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    i_r, i_i, i_n = gi
    h_r, h_i, h_n = gh

    # torch.sigmoid / torch.tanh: F.sigmoid / F.tanh are deprecated aliases.
    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy
开发者ID:thomas0809,项目名称:GraphIE,代码行数:18,代码来源:skipconnect_rnn.py

示例5: batch_transform_xyz

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def batch_transform_xyz(xyz_tensor, R, t, get_Jacobian=True):
    '''
    Rigidly transform a batch of point maps: p' = t + R * p.

    :param xyz_tensor: points, B * 3 * H * W
    :param R: rotation matrix B * 3 * 3
    :param t: translation vector B * 3
    :param get_Jacobian: if True, also return the Jacobian J
    :return: transformed points B * 3 * H * W, plus J when get_Jacobian is True
    '''
    B, C, H, W = xyz_tensor.size()
    # baddbmm broadcasts its additive input, so a (B, 3, 1) view of t is
    # enough; no need to materialise an H*W copy with repeat().
    t_tensor = t.contiguous().view(B, 3, 1)
    p_tensor = xyz_tensor.contiguous().view(B, C, H * W)
    # p' = t + R*p as a single batched matmul
    xyz_t_tensor = torch.baddbmm(t_tensor, R, p_tensor)

    if get_Jacobian:
        # NOTE(review): J_r's shape depends on batch_skew_symmetric_matrix
        # (defined elsewhere); the cat on dim 1 below needs its last dim == 3.
        J_r = R.bmm(batch_skew_symmetric_matrix(-1 * p_tensor.permute(0, 2, 1)))
        # Create the identity on R's device/dtype; a default CPU float32 eye
        # cannot be concatenated with J_r for CUDA or float64 inputs.
        J_t = -1 * torch.eye(3, dtype=R.dtype, device=R.device).view(1, 3, 3).expand(B, 3, 3)
        J = torch.cat((J_r, J_t), 1)
        return xyz_t_tensor.view(B, C, H, W), J
    return xyz_t_tensor.view(B, C, H, W)
开发者ID:lvzhaoyang,项目名称:DeeperInverseCompositionalAlgorithm,代码行数:24,代码来源:geometry.py

示例6: var_lstm_cell

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def var_lstm_cell(input: Tensor, hidden: Tuple[Tensor, Tensor], w_ih: Tensor, w_hh: Tensor,
                  b_ih: Tensor = None, b_hh: Tensor = None, noise_in: Tensor = None, noise_hidden: Tensor = None) \
        -> Tuple[Tensor, Tensor]:
    """Run one step of an LSTM cell with variational dropout masks.

    The four gates (input, forget, cell, output) are computed together as a
    batched matmul over a leading gate axis of size 4.

    Args:
        input: 2-D input for this step; expanded to 4 gate copies, or
            multiplied by ``noise_in`` after ``unsqueeze(0)``.
        hidden: tuple ``(hx, cx)`` with the previous hidden and cell state.
        w_ih, w_hh: stacked per-gate weights (batch-multiplied).
        b_ih, b_hh: optional stacked per-gate biases, broadcast over the batch
            via ``unsqueeze(1)``; ``None`` means no bias term.
        noise_in, noise_hidden: optional dropout masks; ``None`` disables them.

    Returns:
        tuple ``(hy, cy)``: the new hidden state and cell state.
    """
    input = input.expand(4, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = hx.expand(4, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    # The biases default to None, so fall back to a plain bmm when one is
    # missing instead of calling .unsqueeze() on None.
    gi = torch.bmm(input, w_ih) if b_ih is None else torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.bmm(hx, w_hh) if b_hh is None else torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    gates = torch.add(gi, gh)

    ingate, forgetgate, cellgate, outgate = gates

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = torch.add(torch.mul(forgetgate, cx), torch.mul(ingate, cellgate))
    hy = torch.mul(outgate, torch.tanh(cy))

    return hy, cy
开发者ID:yahshibu,项目名称:nested-ner-tacl2020-transformers,代码行数:23,代码来源:variational_rnn.py

示例7: apply_classification_weights

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def apply_classification_weights(self, features, cls_weights):
    """
    Score query features against per-episode classification weights with a
    scaled cosine similarity plus a learned bias:
    (B x n x nFeat, B x nKnovel x nFeat) -> B x n x nKnovel

    :param features: features of the query set, B x n x nFeat.
    :type features: torch.FloatTensor
    :param cls_weights: generated classification weights, B x nKnovel x nFeat.
    :type cls_weights: torch.FloatTensor
    :return: classification scores ``scale_cls * (cosine_similarity + bias)``
    :rtype: torch.FloatTensor
    """
    # L2-normalise along the feature axis so the batched matmul below yields
    # cosine similarities.
    features = F.normalize(features, p=2, dim=-1, eps=1e-12)
    cls_weights = F.normalize(cls_weights, p=2, dim=-1, eps=1e-12)

    # Keyword form of baddbmm (beta = alpha = 1 are the defaults); the
    # positional baddbmm(beta, input, alpha, batch1, batch2) overload is
    # deprecated.  The (1, 1, 1) bias view broadcasts over every score.
    cls_scores = self.scale_cls * torch.baddbmm(self.bias.view(1, 1, 1),
                                                features, cls_weights.transpose(1, 2))
    return cls_scores
开发者ID:amzn,项目名称:xfer,代码行数:20,代码来源:sib.py

示例8: VarLSTMCell

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def VarLSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    """Run one step of an LSTM cell with variational dropout masks.

    The four gates (input, forget, cell, output) are computed together as a
    batched matmul over a leading gate axis of size 4.

    Args:
        input: 2-D input for this step; expanded to 4 gate copies, or
            multiplied by ``noise_in`` after ``unsqueeze(0)``.
        hidden: tuple ``(hx, cx)`` with the previous hidden and cell state.
        w_ih, w_hh: stacked per-gate weights (batch-multiplied).
        b_ih, b_hh: optional stacked per-gate biases, broadcast over the batch
            via ``unsqueeze(1)``; ``None`` means no bias term.
        noise_in, noise_hidden: optional dropout masks; ``None`` disables them.

    Returns:
        tuple ``(hy, cy)``: the new hidden state and cell state.
    """
    input = input.expand(4, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = hx.expand(4, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    # The biases default to None, so fall back to a plain bmm when one is
    # missing instead of calling .unsqueeze() on None.
    gi = torch.bmm(input, w_ih) if b_ih is None else torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.bmm(hx, w_hh) if b_hh is None else torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    gates = gi + gh

    ingate, forgetgate, cellgate, outgate = gates

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy
开发者ID:XuezheMax,项目名称:NeuroNLP2,代码行数:21,代码来源:variational_rnn.py

示例9: SkipConnectLSTMCell

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def SkipConnectLSTMCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    """Run one LSTM step whose recurrent input also includes a skip connection.

    ``hidden_skip`` is concatenated to the previous hidden state along dim 1
    before the hidden-to-hidden matmul.  The four gates (input, forget, cell,
    output) are computed together as a batched matmul over a gate axis of 4.

    Args:
        input: 2-D input for this step; expanded to 4 gate copies, or
            multiplied by ``noise_in`` after ``unsqueeze(0)``.
        hidden: tuple ``(hx, cx)`` with the previous hidden and cell state.
        hidden_skip: extra hidden state concatenated to ``hx`` on dim 1.
        w_ih, w_hh: stacked per-gate weights; ``w_hh`` must match the width
            of the concatenated ``[hx, hidden_skip]``.
        b_ih, b_hh: optional stacked per-gate biases; ``None`` means no bias.
        noise_in, noise_hidden: optional dropout masks; ``None`` disables them.

    Returns:
        tuple ``(hy, cy)``: the new hidden state and cell state.
    """
    input = input.expand(4, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = torch.cat([hx, hidden_skip], dim=1)
    hx = hx.expand(4, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    # The biases default to None, so fall back to a plain bmm when one is
    # missing instead of calling .unsqueeze() on None.
    gi = torch.bmm(input, w_ih) if b_ih is None else torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.bmm(hx, w_hh) if b_hh is None else torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    gates = gi + gh

    ingate, forgetgate, cellgate, outgate = gates

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy
开发者ID:XuezheMax,项目名称:NeuroNLP2,代码行数:22,代码来源:skipconnect_rnn.py

示例10: SkipConnectGRUCell

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def SkipConnectGRUCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    """Run one GRU step whose recurrent input also includes a skip connection.

    ``hidden_skip`` is concatenated to ``hidden`` along dim 1 for the
    hidden-to-hidden matmul.  The three gates (reset, update/input, new) are
    computed as batched matmuls over a gate axis of size 3.  The interpolation
    in the final step uses ``hidden`` alone.

    Args:
        input: 2-D input for this step; expanded to 3 gate copies, or
            multiplied by ``noise_in`` after ``unsqueeze(0)``.
        hidden: previous hidden state.
        hidden_skip: extra hidden state concatenated to ``hidden`` on dim 1.
        w_ih, w_hh: stacked per-gate weights; ``w_hh`` must match the width
            of the concatenated ``[hidden, hidden_skip]``.
        b_ih, b_hh: optional stacked per-gate biases; ``None`` means no bias.
        noise_in, noise_hidden: optional dropout masks; ``None`` disables them.

    Returns:
        the new hidden state ``hy``.
    """
    input = input.expand(3, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in
    hx = torch.cat([hidden, hidden_skip], dim=1)
    hx = hx.expand(3, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    # The biases default to None, so fall back to a plain bmm when one is
    # missing instead of calling .unsqueeze() on None.
    gi = torch.bmm(input, w_ih) if b_ih is None else torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.bmm(hx, w_hh) if b_hh is None else torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    i_r, i_i, i_n = gi
    h_r, h_i, h_n = gh

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy
开发者ID:XuezheMax,项目名称:NeuroNLP2,代码行数:18,代码来源:skipconnect_rnn.py

示例11: backward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def backward(self, grad_output):
    """Gradient of the grid generator with respect to ``self.input1``.

    Computes, per sample, ``grad_output^T @ batchgrid`` after flattening
    ``grad_output`` to (-1, height*width, 2) and ``self.batchgrid`` to
    (-1, height*width, 3).  The product is added to a zero tensor with the
    same size/dtype/device as ``self.input1``.
    """
    hw = self.height * self.width
    flat_grad = grad_output.view(-1, hw, 2)
    flat_grid = self.batchgrid.view(-1, hw, 3)
    base = self.input1.new_zeros(self.input1.size())
    return torch.baddbmm(base, flat_grad.transpose(1, 2), flat_grid)
开发者ID:guoruoqian,项目名称:cascade-rcnn_Pytorch,代码行数:12,代码来源:gridgen.py

示例12: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def forward(self, input, indices=None):
    """Score ``input`` against all weight rows, or only the rows in ``indices``.

    Shape (indices path):
        - input: :math:`(N, E)`
        - indices: :math:`(N, K)`, e.g. target plus noise ids, :math:`K = 1 + N_r`
        - output: :math:`(N, K)`; each entry is ``input[n] . weight[idx] + bias[idx]``

    Without ``indices`` this defers to the parent linear layer's forward.
    """
    if indices is None:
        return super(IndexLinear, self).forward(input)
    n_rows, n_cols = indices.size(0), indices.size(1)
    flat_idx = indices.view(-1)
    # Gather rows with index_select rather than weight[...]; NOTE(review): the
    # original comment says backprop through the [] operator was incorrect in
    # the PyTorch version this targeted.
    input = input.unsqueeze(1)  # (N, 1, E)
    target_batch = self.weight.index_select(0, flat_idx).view(n_rows, n_cols, -1).transpose(1, 2)  # (N, E, K)
    bias = self.bias.index_select(0, flat_idx).view(n_rows, 1, n_cols)  # (N, 1, K)
    # Keyword-free baddbmm (beta = alpha = 1 are the defaults); the positional
    # baddbmm(beta, input, alpha, batch1, batch2) overload is deprecated.
    out = torch.baddbmm(bias, input, target_batch)  # (N, 1, K)
    # Squeeze only the singleton middle axis: a bare squeeze() would also drop
    # the batch axis when N == 1 (or the score axis when K == 1).
    return out.squeeze(1)
开发者ID:chenyuntc,项目名称:PyTorchText,代码行数:16,代码来源:nce.py

示例13: batched_l2_dist

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def batched_l2_dist(a, b):
    """Pairwise Euclidean distances between two batches of row vectors.

    Expands ||x - y||^2 = ||y||^2 - 2 x.y + ||x||^2 so the cross term is one
    batched matmul.

    Args:
        a: 3-D tensor (batch, n, d).
        b: 3-D tensor (batch, m, d).

    Returns:
        (batch, n, m) distances.  Squared distances are clamped to at least
        1e-30 before the square root.
    """
    sq_norm_a = a.norm(dim=-1).pow(2)   # (batch, n)
    sq_norm_b = b.norm(dim=-1).pow(2)   # (batch, m)
    b_t = b.transpose(-2, -1)           # (batch, d, m)
    # ||b_j||^2 - 2 <a_i, b_j>, with ||b_j||^2 broadcast over rows i
    dist_sq = th.baddbmm(sq_norm_b.unsqueeze(-2), a, b_t, alpha=-2)
    # + ||a_i||^2, broadcast over columns j
    dist_sq.add_(sq_norm_a.unsqueeze(-1))
    return dist_sq.clamp_min_(1e-30).sqrt_()
开发者ID:dmlc,项目名称:dgl,代码行数:11,代码来源:score_fun.py

示例14: VarGRUCell

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def VarGRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    """Run one step of a GRU cell with variational dropout masks.

    The three gates (reset, update/input, new) are computed as batched matmuls
    over a leading gate axis of size 3.

    Args:
        input: 2-D input for this step; expanded to 3 gate copies, or
            multiplied by ``noise_in`` after ``unsqueeze(0)``.
        hidden: previous hidden state.
        w_ih, w_hh: stacked per-gate weights (batch-multiplied).
        b_ih, b_hh: optional stacked per-gate biases, broadcast over the batch
            via ``unsqueeze(1)``; ``None`` means no bias term.
        noise_in, noise_hidden: optional dropout masks; ``None`` disables them.

    Returns:
        the new hidden state ``hy``.
    """
    input = input.expand(3, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in
    hx = hidden.expand(3, *hidden.size()) if noise_hidden is None else hidden.unsqueeze(0) * noise_hidden

    # The biases default to None, so fall back to a plain bmm when one is
    # missing instead of calling .unsqueeze() on None.
    gi = torch.bmm(input, w_ih) if b_ih is None else torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.bmm(hx, w_hh) if b_hh is None else torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    i_r, i_i, i_n = gi
    h_r, h_i, h_n = gh

    # torch.sigmoid / torch.tanh: F.sigmoid / F.tanh are deprecated aliases.
    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy
开发者ID:thomas0809,项目名称:GraphIE,代码行数:17,代码来源:variational_rnn.py

示例15: apply_classification_weights

# 需要导入模块: import torch [as 别名]
# 或者: from torch import baddbmm [as 别名]
def apply_classification_weights(self, features, cls_weights):
    """Applies the classification weight vectors to the feature vectors.

    Args:
        features: A 3D tensor of shape
            [batch_size x num_test_examples x num_channels] with the feature
            vectors (of `num_channels` length) of each example on each
            training episode in the batch. `batch_size` is the number of
            training episodes in the batch and `num_test_examples` is the
            number of test examples of each training episode.
        cls_weights: A 3D tensor of shape [batch_size x nK x num_channels]
            that includes the classification weight vectors
            (of `num_channels` length) of the `nK` categories used on
            each training episode in the batch. `nK` is the number of
            categories (e.g., the number of base categories plus the number
            of novel categories) used on each training episode.

    Return:
        cls_scores: A 3D tensor with shape
            [batch_size x num_test_examples x nK] equal to
            `scale_cls * (features @ cls_weights^T + bias)`; with the
            'cosine' classifier both inputs are L2-normalised first.
    """
    if self.classifier_type == 'cosine':
        # L2-normalise along the channel axis so the matmul gives cosine similarity.
        features = F.normalize(features, p=2, dim=-1, eps=1e-12)
        cls_weights = F.normalize(cls_weights, p=2, dim=-1, eps=1e-12)

    # Keyword-free baddbmm (beta = alpha = 1 are the defaults); the positional
    # baddbmm(beta, input, alpha, batch1, batch2) overload is deprecated.
    # The (1, 1, 1) bias view broadcasts over every score.
    cls_scores = self.scale_cls * torch.baddbmm(
        self.bias.view(1, 1, 1), features, cls_weights.transpose(1, 2))
    return cls_scores
开发者ID:gidariss,项目名称:FewShotWithoutForgetting,代码行数:34,代码来源:ClassifierWithFewShotGenerationModule.py


注：本文中的torch.baddbmm方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。