当前位置: 首页>>代码示例>>Python>>正文


Python torch.bmm方法代码示例

本文整理汇总了Python中torch.bmm方法的典型用法代码示例。如果您正苦于以下问题:Python torch.bmm方法的具体用法?Python torch.bmm怎么用?Python torch.bmm使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.bmm方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: m_ggnn

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def m_ggnn(self, h_v, h_w, e_vw, opt=None):
    """Compute per-neighbour messages using one learned matrix per edge label.

    For every neighbour slot ``w`` the neighbour state ``h_w[:, w, :]`` is
    multiplied by the matrix ``self.learn_args[0][i]`` whose label
    ``self.args['e_label'][i]`` equals the edge label in ``e_vw[:, w, :]``.
    Contributions of all labels are summed; the original code overwrote the
    slot on every label iteration, so only the last label ever produced a
    non-zero message.

    Args:
        h_v: Not read by this message function.
        h_w: Neighbour states, indexed as ``(batch, neighbour, feature)``.
        e_vw: Edge labels, indexed as ``(batch, neighbour, edge_feature)``.
            The indicator is expanded to the message shape, so the last
            dimension has to be 1 (or equal to ``self.args['out']``).
        opt: Not read by this message function.

    Returns:
        Tensor of shape ``(h_w.size(0), h_w.size(1), self.args['out'])``;
        slots whose edge entries are all zero stay zero.
    """
    batch_size, n_neighbours = h_w.size(0), h_w.size(1)
    m = Variable(torch.zeros(batch_size, n_neighbours, self.args['out']).type_as(h_w.data))

    for w in range(n_neighbours):
        edge = e_vw[:, w, :]
        # Skip slots without any edge. ``torch.nonzero(...).size()`` is a
        # non-empty tuple on current PyTorch even when nothing is non-zero,
        # so test the values explicitly.
        if not bool((edge != 0).any()):
            continue

        h_row = torch.unsqueeze(h_w[:, w, :], 1)  # (batch, 1, in)
        msg = None
        for i, el in enumerate(self.args['e_label']):
            weight = self.learn_args[0][i]  # (in, out)
            ind = (el == edge).type_as(weight)
            parameter_mat = weight[None, ...].expand(batch_size, weight.size(0), weight.size(1))
            # (batch, 1, in) x (batch, in, out) -> (batch, 1, out). Squeeze
            # only axis 1 so that a batch of size one keeps its batch axis.
            m_w = torch.bmm(h_row, parameter_mat).squeeze(1)
            term = ind.expand_as(m_w) * m_w
            msg = term if msg is None else msg + term

        if msg is not None:
            m[:, w, :] = msg
    return m
开发者ID:priba,项目名称:nmp_qc,代码行数:20,代码来源:MessageFunction.py

示例2: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def forward(self, encoding, lengths):
    """Pool a padded batch of encodings into one vector per sequence.

    Args:
        encoding: Tensor indexed as ``(batch, length, in_dim)``.
        lengths: Valid length of every sequence in the batch. The 'mean' and
            'attn' methods pack the batch with the default settings of
            ``pack_padded_sequence``, which expects decreasing lengths.

    Returns:
        A pair ``(pooled, extra)`` depending on ``self.method``:
        'mean' -> (``[bsz, in_dim]``, None);
        'max'  -> the result of ``encoding.max(1)`` (values, positions);
        'attn' -> (``[bsz, in_dim]``, attention weights ``[bsz, 1, len]``);
        'last' -> (``[bsz, in_dim]``, None).

    Raises:
        ValueError: If ``self.method`` is none of the methods above.
    """
    length_list = [int(n) for n in lengths]
    # Keep the lengths next to the data instead of forcing them onto CUDA
    # whenever a GPU merely exists; that broke CPU inputs on GPU machines.
    lengths = torch.as_tensor(length_list, dtype=torch.long, device=encoding.device)

    if self.method == 'mean':
        packed = nn.utils.rnn.pack_padded_sequence(encoding, length_list, batch_first=True)
        # The pack/pad round trip zeroes every position past a sequence's length.
        encoding = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True, padding_value=0)[0]
        return encoding.sum(1) / lengths.float().view(-1, 1), None

    if self.method == 'max':
        return encoding.max(1)  # [bsz, in_dim], [bsz, in_dim] (position)

    if self.method == 'attn':
        size = encoding.size()  # [bsz, len, in_dim]
        x_flat = encoding.contiguous().view(-1, size[2])  # [bsz*len, in_dim]
        hbar = self.tanh(self.ws1(x_flat))  # [bsz*len, attn_hid]
        alphas = self.ws2(hbar).view(size[0], size[1])  # [bsz, len]
        packed = nn.utils.rnn.pack_padded_sequence(alphas, length_list, batch_first=True)
        # total_length keeps the padded width equal to ``len`` even when the
        # longest sequence is shorter, so the reshape below cannot fail.
        alphas = nn.utils.rnn.pad_packed_sequence(
            packed, batch_first=True, padding_value=-1e8, total_length=size[1])[0]
        alphas = functional.softmax(alphas, dim=1)  # padded slots get weight 0
        alphas = alphas.view(size[0], 1, size[1])  # [bsz, 1, len]
        return torch.bmm(alphas, encoding).squeeze(1), alphas

    if self.method == 'last':
        # Pick row ``lengths[i] - 1`` of every sequence. Indexing keeps the
        # batch axis, whereas concatenating 1-D rows produced a flat vector.
        rows = torch.arange(encoding.size(0), device=encoding.device)
        return encoding[rows, lengths - 1], None

    raise ValueError('Unknown pooling method: {!r}'.format(self.method))
开发者ID:ExplorerFreda,项目名称:VSE-C,代码行数:25,代码来源:model.py

示例3: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def forward(self, query_embed, in_memory_embed, atten_mask=None):
    """Score every memory slot against the query.

    ``self.atten_type`` selects the scoring function:
    'simple' is a plain dot product, 'mul' projects the query through
    ``self.W`` first, and 'add' applies ``tanh`` to the sum of the projected
    memory (``self.W2``) and projected query (``self.W``) before reducing with
    ``self.W3``.

    Args:
        query_embed: 2-D query batch (it is fed to ``torch.mm`` / unsqueezed
            for ``torch.bmm``).
        in_memory_embed: 3-D memory batch whose last axis is the embedding.
        atten_mask: Optional mask; where it is 0 the score is replaced by
            ``-INF`` so those slots vanish in a later softmax.

    Returns:
        Unnormalised attention scores, one per memory slot.

    Raises:
        RuntimeError: If ``self.atten_type`` is not a known type.
    """
    kind = self.atten_type
    if kind not in ('simple', 'mul', 'add'):
        raise RuntimeError('Unknown atten_type: {}'.format(self.atten_type))

    if kind == 'add':  # additive attention
        batch_size = in_memory_embed.size(0)
        flat_memory = in_memory_embed.view(-1, in_memory_embed.size(-1))
        memory_term = torch.mm(flat_memory, self.W2).view(batch_size, -1, self.W2.size(-1))
        query_term = torch.mm(query_embed, self.W).unsqueeze(1)
        hidden = torch.tanh(memory_term + query_term)
        attention = torch.mm(hidden.view(-1, hidden.size(-1)), self.W3).view(hidden.size(0), -1)
    else:
        # 'simple' uses the raw query, 'mul' the query projected by W.
        query = query_embed if kind == 'simple' else torch.mm(query_embed, self.W)
        attention = torch.bmm(in_memory_embed, query.unsqueeze(2)).squeeze(2)

    if atten_mask is None:
        return attention
    # Exclude masked elements from the softmax
    return atten_mask * attention - (1 - atten_mask) * INF
开发者ID:hugochan,项目名称:BAMnet,代码行数:19,代码来源:modules.py

示例4: find_max_triples

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def find_max_triples(p1, p2, topN=5, prob_thd=None):
    """Find the best (start, end) index pairs with ``end >= start``.

    The score of a pair ``(k1, k2)`` is ``p1[k1] * p2[k2]``. Only the upper
    triangle of the outer product is kept, i.e. pairs with ``k2 >= k1``.

    Args:
        p1 (torch.Tensor): (N, L) batched start_idx probabilities
        p2 (torch.Tensor): (N, L) batched end_idx probabilities
        topN (int): number of pairs requested from ``topN_array_2d`` per item
        prob_thd (float): if given, drop triples whose third element
            (the confidence) is below this threshold
    Returns:
        batched_sorted_triple: N * [(st_idx, ed_idx, confidence), ...];
        an empty list for an empty batch.
    """
    product = torch.bmm(p1.unsqueeze(2), p2.unsqueeze(1))  # (N, L, L)
    # torch.triu works on the last two axes of a batch, so no Python loop is
    # needed. Unlike torch.stack over a list it also accepts N == 0.
    upper_product = torch.triu(product).detach().cpu().numpy()  # lower part becomes zeros

    batched_sorted_triple = []
    for e in upper_product:
        sorted_triple = topN_array_2d(e, topN=topN)
        if prob_thd is not None:
            sorted_triple = [t for t in sorted_triple if t[2] >= prob_thd]
        batched_sorted_triple.append(sorted_triple)
    return batched_sorted_triple
开发者ID:jayleicn,项目名称:TVQAplus,代码行数:22,代码来源:model_utils.py

示例5: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def forward(self, input, hidden, encoder_outputs):
    """Run one decoding step with attention over the encoder outputs.

    Args:
        input: Index of the current input token; its embedding is reshaped
            to ``(1, 1, -1)``.
        hidden: Current GRU hidden state; ``hidden[0]`` is concatenated with
            the embedded token to compute the attention logits.
        encoder_outputs: 2-D matrix of encoder states, one row per position.

    Returns:
        ``(output, hidden, attn_weights)``: log-probabilities from
        ``self.out``, the new hidden state and the attention weights.
    """
    token_vec = self.dropout(self.embedding(input).view(1, 1, -1))
    step_input = token_vec[0]

    # Attention weights come from the embedded token and the hidden state.
    attn_logits = self.attn(torch.cat((step_input, hidden[0]), 1))
    attn_weights = F.softmax(attn_logits, dim=1)
    context = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))

    combined = self.attn_combine(torch.cat((step_input, context[0]), 1))
    gru_input = F.relu(combined.unsqueeze(0))
    output, hidden = self.gru(gru_input, hidden)

    log_probs = F.log_softmax(self.out(output[0]), dim=1)
    return log_probs, hidden, attn_weights
开发者ID:EvilPsyCHo,项目名称:TaskBot,代码行数:19,代码来源:tutorial.py

示例6: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def forward(self, z_enc_out, u_enc_out, u_input_np, m_t_input, degree_input, last_hidden, z_input_np):
        """Run one decoding step that mixes generation scores with copy scores.

        ``u_input_np`` is accepted but not read in this body.

        Returns:
            A tuple ``(proba, last_hidden, gru_out)``: the combined
            distribution, the updated GRU hidden state and the GRU output.
        """
        # Built from z_input_np by a helper defined elsewhere; excluded from autograd.
        # NOTE(review): presumably a (sparse) one-hot [B,T,V'] selection matrix - confirm.
        sparse_z_input = Variable(self.get_sparse_selective_input(z_input_np), requires_grad=False)

        m_embed = self.emb(m_t_input)
        # Attention contexts over the two encoder outputs, both queried with last_hidden.
        z_context = self.attn_z(last_hidden, z_enc_out)
        u_context = self.attn_u(last_hidden, u_enc_out)
        # degree_input gets a leading axis so it can be concatenated on dim 2.
        gru_in = torch.cat([m_embed, u_context, z_context, degree_input.unsqueeze(0)], dim=2)
        gru_out, last_hidden = self.gru(gru_in, last_hidden)
        # Generation scores from the contexts and the GRU output.
        gen_score = self.proj(torch.cat([z_context, u_context, gru_out], 2)).squeeze(0)
        # Copy scores: projected z encoder states dotted with the GRU output.
        z_copy_score = F.tanh(self.proj_copy2(z_enc_out.transpose(0, 1)))
        z_copy_score = torch.matmul(z_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2)
        # The following exp / bmm / log section runs on the CPU.
        z_copy_score = z_copy_score.cpu()
        # Subtract the row maximum before exp and add it back after the log,
        # which keeps the exponentials from overflowing.
        z_copy_score_max = torch.max(z_copy_score, dim=1, keepdim=True)[0]
        z_copy_score = torch.exp(z_copy_score - z_copy_score_max)  # [B,T]
        z_copy_score = torch.log(torch.bmm(z_copy_score.unsqueeze(1), sparse_z_input)).squeeze(
            1) + z_copy_score_max  # [B,V]
        # NOTE(review): cuda_ is defined elsewhere; presumably moves the tensor back to the GPU - confirm.
        z_copy_score = cuda_(z_copy_score)

        # Normalise generation and copy scores jointly, then split them again
        # at cfg.vocab_size.
        scores = F.softmax(torch.cat([gen_score, z_copy_score], dim=1), dim=1)
        gen_score, z_copy_score = scores[:, :cfg.vocab_size], \
                                  scores[:, cfg.vocab_size:]
        # Copy mass for the first cfg.vocab_size entries is added to the
        # generation mass; any remaining copy entries are appended.
        proba = gen_score + z_copy_score[:, :cfg.vocab_size]  # [B,V]
        proba = torch.cat([proba, z_copy_score[:, cfg.vocab_size:]], 1)
        return proba, last_hidden, gru_out 
开发者ID:ConvLab,项目名称:ConvLab,代码行数:26,代码来源:tsd_net.py

示例7: calc_score

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def calc_score(self, att_query, att_keys):
    """Compute attention scores between queries and keys.

    att_query is: b x t_q x n
    att_keys is b x t_k x n
    return b x t_q x t_k scores

    ``self.mode`` selects the scoring function: 'bahdanau' feeds
    ``tanh(query + key)`` through ``self.linear_att``; 'dot_prod' is a
    batched dot product, multiplied by ``self.scale`` when that attribute
    exists.

    Raises:
        ValueError: If ``self.mode`` is neither 'bahdanau' nor 'dot_prod'
            (the original fell through to an UnboundLocalError).
    """
    b, t_k, n = list(att_keys.size())
    t_q = att_query.size(1)

    if self.mode == 'bahdanau':
        # Pair every query with every key before summing.
        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = (att_query + att_keys).view(b * t_q * t_k, n)
        # torch.tanh replaces the deprecated F.tanh.
        return self.linear_att(torch.tanh(sum_qk)).view(b, t_q, t_k)

    if self.mode == 'dot_prod':
        out = torch.bmm(att_query, att_keys.transpose(1, 2))
        if hasattr(self, 'scale'):
            out = out * self.scale
        return out

    raise ValueError('Unknown attention mode: {!r}'.format(self.mode))
开发者ID:nadavbh12,项目名称:Character-Level-Language-Modeling-with-Deeper-Self-Attention-pytorch,代码行数:22,代码来源:attention.py

示例8: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def forward(self, x):
    """Apply a gather/distribute attention block with a residual connection.

    Features from ``self.down`` are first gathered into a ``(b, c, c)``
    descriptor using softmax weights from ``self.gather_down``, then
    distributed back to every position using softmax weights from
    ``self.distribue_down``. The result goes through ``self.up`` and is added
    to the input.

    Args:
        x: 4-D input; the output of ``self.down`` is unpacked as (b, c, h, w).

    Returns:
        ``x`` plus the attention output.
    """
    feat = self.down(x)
    gather_logits = self.gather_down(x)
    b, c, h, _ = feat.size()

    feat = feat.view(b, c, -1)  # (b, c, h*w)
    gather_attn = self.softmax(gather_logits.view(b, c, -1)).transpose(1, 2)  # (b, h*w, c)
    global_desc = torch.bmm(feat, gather_attn)  # (b, c, c)

    distribute_logits = self.distribue_down(x).view(b, c, -1)  # (b, c, h*w)
    distribute_attn = self.softmax(distribute_logits).transpose(1, 2)  # (b, h*w, c)

    atten = torch.bmm(distribute_attn, global_desc)  # (b, h*w, c)
    atten = atten.transpose(1, 2).view(b, c, h, -1)
    return x + self.up(atten)
开发者ID:lxtGH,项目名称:Fast_Seg,代码行数:25,代码来源:operators.py

示例9: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def forward(self, x):
    """Extract PointNet-style features from a batch of point clouds.

    The points are first multiplied by the transform predicted by
    ``self.stn``, then passed through three conv/bn stages and max-pooled
    over the point axis into a 1024-wide global feature.

    Args:
        x: Input indexed as ``(batch, channel, n_pts)``.

    Returns:
        ``(features, trans)``. With ``self.global_feat`` set, ``features`` is
        the global feature of width 1024; otherwise the global feature is
        repeated for every point and concatenated with the per-point
        features of the first stage along axis 1.
    """
    n_pts = x.size()[2]
    trans = self.stn(x)

    # Apply the predicted transform in (batch, n_pts, channel) layout.
    x = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1)

    pointfeat = F.relu(self.bn1(self.conv1(x)))
    x = F.relu(self.bn2(self.conv2(pointfeat)))
    x = self.bn3(self.conv3(x))

    pooled = torch.max(x, 2, keepdim=True)[0].view(-1, 1024)
    if self.global_feat:
        return pooled, trans

    tiled = pooled.view(-1, 1024, 1).repeat(1, 1, n_pts)
    return torch.cat([tiled, pointfeat], 1), trans
开发者ID:seowok,项目名称:TreeGAN,代码行数:20,代码来源:pointnet.py

示例10: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def forward(self, inputs, y=None):
    """Apply a self-attention block with a learned residual gate.

    ``theta`` keeps full resolution; ``phi`` and ``g`` are max-pooled 2x2, so
    every position attends over a quarter as many key positions.

    Args:
        inputs: 4-D feature map; axes 2 and 3 are used as height and width.
        y: Not read by this method.

    Returns:
        ``self.gamma * attention_output + inputs``.
    """
    height, width = inputs.shape[2], inputs.shape[3]
    n_positions = height * width
    head_channels = self.channels // self.heads
    value_channels = self.channels // 2

    # Apply convs
    query = self.theta(inputs)
    key = F.max_pool2d(self.phi(inputs), [2, 2])
    value = F.max_pool2d(self.g(inputs), [2, 2])

    # Flatten the spatial axes
    query = query.view(-1, head_channels, n_positions)
    key = key.view(-1, head_channels, n_positions // 4)
    value = value.view(-1, value_channels, n_positions // 4)

    # Attention map: softmax over the pooled key positions
    beta = F.softmax(torch.bmm(query.transpose(1, 2), key), -1)

    # Attention map times value path, reshaped back to a feature map
    attended = torch.bmm(value, beta.transpose(1, 2))
    o = self.o(attended.view(-1, value_channels, height, width))
    return self.gamma * o + inputs
开发者ID:bayesiains,项目名称:nsf,代码行数:18,代码来源:attention.py

示例11: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def forward(self, context, question, context_padding, question_padding):
    """Compute co-attention between a context and a question.

    A sentinel embedding is prepended to both sequences (index 0 for the
    context, index 1 for the question) together with a matching never-padded
    mask entry. The sentinel position is dropped again from both outputs.

    Args:
        context: batch_size x context_length x features
        question: batch_size x question_length x features
        context_padding: padding mask of the context, one entry per position
        question_padding: padding mask of the question, one entry per position

    Returns:
        A pair ``(context_out, question_out)``; each concatenates a
        co-attended summary with a directly attended summary on the feature
        axis.
    """
    n_context, n_question = context.size(0), question.size(0)

    # Extend both masks with a "not padded" entry for the sentinel slot.
    context_keep = context.new_zeros((n_context, 1), dtype=torch.long) == 1
    context_padding = torch.cat([context_keep, context_padding], 1)
    question_keep = question.new_zeros((n_question, 1), dtype=torch.long) == 1
    question_padding = torch.cat([question_keep, question_padding], 1)

    # Prepend the sentinels; dropout only touches the context.
    context_sentinel = self.embed_sentinel(context.new_zeros((n_context, 1), dtype=torch.long))
    context = torch.cat([context_sentinel, self.dropout(context)], 1)  # batch_size x (context_length + 1) x features

    question_sentinel = self.embed_sentinel(question.new_ones((n_question, 1), dtype=torch.long))
    question = torch.cat([question_sentinel, question], 1)  # batch_size x (question_length + 1) x features
    question = torch.tanh(self.proj(question))

    # batch_size x (context_length + 1) x (question_length + 1)
    affinity = torch.bmm(context, question.transpose(1, 2))
    attn_over_context = self.normalize(affinity, context_padding)
    attn_over_question = self.normalize(affinity.transpose(1, 2), question_padding)

    sum_of_context = self.attn(attn_over_context, context)
    sum_of_question = self.attn(attn_over_question, question)
    coattn_context = self.attn(attn_over_question, sum_of_context)
    coattn_question = self.attn(attn_over_context, sum_of_question)

    context_out = torch.cat([coattn_context, sum_of_question], 2)[:, 1:]
    question_out = torch.cat([coattn_question, sum_of_context], 2)[:, 1:]
    return context_out, question_out
开发者ID:salesforce,项目名称:decaNLP,代码行数:21,代码来源:common.py

示例12: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def forward(self, h, adj):
    """Apply multi-head graph attention over one graph.

    Args:
        h: Node features of size n x f_in.
        adj: n x n adjacency with zero / False for "no edge". Boolean and
            0/1 integer tensors are both accepted.

    Returns:
        Tensor of size n_head x n x f_out (plus ``self.bias`` when it is set).
        A node without any neighbour gets an all ``-inf`` row before the
        softmax, exactly as in the original implementation.
    """
    n = h.size(0)  # h is of size n x f_in
    h_prime = torch.matmul(h.unsqueeze(0), self.w)  # n_head x n x f_out
    attn_src = torch.bmm(h_prime, self.a_src)  # n_head x n x 1
    attn_dst = torch.bmm(h_prime, self.a_dst)  # n_head x n x 1
    # score[i, j] = src[i] + dst[j]
    attn = attn_src.expand(-1, -1, n) + attn_dst.expand(-1, -1, n).permute(0, 2, 1)  # n_head x n x n

    attn = self.leaky_relu(attn)
    # ``1 - adj`` raises for boolean tensors and non-boolean masks are
    # rejected by current masked_fill, so derive a boolean "no edge" mask.
    # The out-of-place call also keeps the operation visible to autograd
    # instead of editing ``attn.data`` behind its back.
    attn = attn.masked_fill(adj == 0, float("-inf"))
    attn = self.softmax(attn)  # n_head x n x n
    attn = self.dropout(attn)
    output = torch.bmm(attn, h_prime)  # n_head x n x f_out

    if self.bias is not None:
        return output + self.bias
    return output
开发者ID:xptree,项目名称:DeepInf,代码行数:19,代码来源:gat_layers.py

示例13: convolutional_layer

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def convolutional_layer(self, inputs):
    """Weight the entries of every time step by a convolution-derived score.

    For each of the ``self.seq_len`` steps, ``self.conv`` is applied to each
    of the ``self.pad_size`` entries; the ``self.func_tanh`` of those outputs
    is used as a row of weights that is multiplied (``torch.bmm``) with the
    entries of that step.

    Args:
        inputs: Tensor indexed as ``inputs[:, step, entry]``.
            NOTE(review): assumes ``self.conv`` reduces one entry to a single
            value, so the squeezes below remove singleton axes - confirm.

    Returns:
        ``(convolution_all, conv_wts)``: the weighted sums per step and the
        weights that produced them, both stacked along axis 1.
    """
    step_vectors = []
    step_weights = []
    for i in range(self.seq_len):
        # One conv output per entry, stacked on a new leading axis.
        per_entry = torch.stack([
            self.conv(torch.unsqueeze(inputs[:, i, j], dim=1))
            for j in range(self.pad_size)
        ])
        # Move the batch axis to the front and the entry axis to the back.
        weights = torch.squeeze(per_entry, dim=3).transpose(0, 1).transpose(1, 2)
        weights = self.func_tanh(torch.squeeze(weights, dim=1))
        weights = torch.unsqueeze(weights, dim=1)

        step_vectors.append(torch.bmm(weights, inputs[:, i]))
        step_weights.append(weights)

    convolution_all = torch.squeeze(torch.stack(step_vectors, dim=1), dim=2)
    conv_wts = torch.squeeze(torch.stack(step_weights, dim=1), dim=2)
    return convolution_all, conv_wts
开发者ID:BarnesLab,项目名称:Patient2Vec,代码行数:24,代码来源:Patient2Vec.py

示例14: get_loss

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def get_loss(pred, y, criterion, mtr, a=0.5):
    """
    To calculate loss: ``criterion(pred, y)`` plus a weighted penalty on ``mtr``.

    For every batch item the penalty is the trace of the element-wise square
    of ``mtr @ mtr^T - I`` (the sum of its squared diagonal entries),
    averaged over the batch and multiplied by ``a``.

    The original version computed the penalty on ``.data`` and re-wrapped it
    in a new CPU ``FloatTensor``. The penalty therefore contributed no
    gradient and failed for inputs living on another device. Here it stays
    part of the autograd graph and on the device / dtype of ``mtr``.

    :param pred: predicted value
    :param y: actual value
    :param criterion: nn.CrossEntropyLoss
    :param mtr: beta matrix, indexed as (batch, row, column)
    :param a: weight of the penalty term
    :return: one-element tensor of shape (1,), as before
    """
    gram = torch.bmm(mtr, torch.transpose(mtr, 1, 2))
    identity = torch.eye(mtr.size(1), dtype=gram.dtype, device=gram.device)
    # Only the diagonal of (gram - I) enters the trace of its element-wise square.
    diag = torch.diagonal(gram - identity, dim1=1, dim2=2)
    penalty = (diag * diag).sum(dim=1).mean()
    # reshape(1) preserves the (1,) result shape of the original implementation.
    return torch.add(criterion(pred, y), (penalty * a).reshape(1))
开发者ID:BarnesLab,项目名称:Patient2Vec,代码行数:19,代码来源:Patient2Vec.py

示例15: train

# 需要导入模块: import torch [as 别名]
# 或者: from torch import bmm [as 别名]
def train(model_q, model_k, device, train_loader, queue, optimizer, epoch, temp=0.07):
    """Train the query encoder for one epoch with a contrastive loss.

    For every batch the positive logit is the dot product of the query and
    key embeddings of the same sample, and the negative logits are the dot
    products of the query with every entry of ``queue``. The positive sits in
    column 0, so the target class is always 0.

    Args:
        model_q: Query encoder; the one updated by ``optimizer``.
        model_k: Key encoder; its output is detached from the graph.
        device: Device the batches are moved to.
        train_loader: Yields ``(data, target)`` where ``data[0]`` / ``data[1]``
            are the two inputs fed to ``model_q`` / ``model_k``.
        queue: Negative keys; ``queue.shape[0]`` is used as the number of
            negatives.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the printed summary.
        temp: Temperature dividing the logits.

    Returns:
        The queue after the last batch. The original rebound ``queue``
        locally and returned ``None``, so the caller never saw the update.
        Callers that ignore the return value behave as before.
    """
    model_q.train()
    total_loss = 0
    # Stateless loss module: build it once instead of once per batch.
    cross_entropy_loss = nn.CrossEntropyLoss()

    for data, _ in train_loader:
        x_q, x_k = data[0].to(device), data[1].to(device)
        q = model_q(x_q)
        k = model_k(x_k).detach()  # no gradient flows into the key encoder

        N = data[0].shape[0]
        K = queue.shape[0]
        l_pos = torch.bmm(q.view(N, 1, -1), k.view(N, -1, 1))
        l_neg = torch.mm(q.view(N, -1), queue.T.view(-1, K))

        logits = torch.cat([l_pos.view(N, 1), l_neg], dim=1)
        labels = torch.zeros(N, dtype=torch.long, device=device)

        loss = cross_entropy_loss(logits / temp, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Helpers defined elsewhere in the project.
        # NOTE(review): presumably a momentum (EMA) update of model_k from
        # model_q, then enqueue of the new keys and dequeue of the oldest.
        momentum_update(model_q, model_k)

        queue = queue_data(queue, k)
        queue = dequeue_data(queue)

    # NOTE(review): per-batch mean losses are divided by the number of
    # samples, not the number of batches; kept as in the original.
    total_loss /= len(train_loader.dataset)

    print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, total_loss))
    return queue
开发者ID:peisuke,项目名称:MomentumContrast.pytorch,代码行数:42,代码来源:train.py


注:本文中的torch.bmm方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。