

Python Utils.sequence_mask method code examples

This article collects typical usage examples of the onmt.Utils.sequence_mask method in Python. If you are wondering how to call Utils.sequence_mask, what it is used for, or what real-world usage looks like, the curated code examples below may help. You can also explore further usage examples from onmt.Utils, the module this method belongs to.


Two code examples of the Utils.sequence_mask method are shown below, sorted by popularity by default.
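
Before the examples, here is a minimal sketch of what a sequence_mask-style helper does, assuming the behavior commonly used in OpenNMT-py: a (batch, max_len) mask that is 1 for positions before each sequence's length and 0 for padding. This standalone rewrite is illustrative, not the exact onmt.Utils implementation:

import torch


def sequence_mask(lengths, max_len=None):
    # Illustrative sketch (assumption): 1 for valid positions (index < length),
    # 0 for padded positions; result has shape (batch, max_len).
    batch_size = lengths.numel()
    max_len = max_len or lengths.max().item()
    positions = torch.arange(0, max_len, device=lengths.device)  # (max_len,)
    return positions.unsqueeze(0).expand(batch_size, max_len).lt(lengths.unsqueeze(1))


# lengths [3, 1] with max_len 4 ->
# [[True, True, True, False],
#  [True, False, False, False]]
print(sequence_mask(torch.tensor([3, 1]), max_len=4))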

Example 1: forward

# Required module import: from onmt import Utils [as alias]
# Or: from onmt.Utils import sequence_mask [as alias]
def forward(self, input, context, context_lengths=None, coverage=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x hidden_size]`
          context (`FloatTensor`): source vectors `[batch x src_len x hidden_size]`
          context_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x hidden_size]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        batch, sourceL, context_size = context.size()
        batch_, targetL, hidden_size = input.size()
        aeq(batch, batch_)

        # compute attention scores, as in Luong et al.
        align = self.score(input, context)  # (batch, tgt_len, src_len)

        if context_lengths is not None:
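            # Hide padded source positions with -inf so the softmax assigns them zero weight.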
            mask = sequence_mask(context_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(1 - mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch*targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, context)

        # concatenate
        concat_c = torch.cat([c, input], 2).view(batch*targetL, -1)
        attn_h = self.linear_out(concat_c).view(batch, targetL, hidden_size)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()

        # Check output sizes
        targetL_, batch_, dim_ = attn_h.size()
        # aeq(targetL, targetL_)
        # aeq(batch, batch_)
        # aeq(hidden_size, dim_)
        targetL_, batch_, sourceL_ = align_vectors.size()
        # aeq(targetL, targetL_)
        # aeq(batch, batch_)
        # aeq(sourceL, sourceL_)

        return attn_h, align_vectors 
Developer: matthewmackay, Project: reversible-rnn, Lines: 61, Source file: MultiSizeAttention.py
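
The masking pattern in Example 1 can be exercised on its own. The following sketch uses made-up shapes and a plain dot-product scorer in place of self.score (so it is an illustration, not the MultiSizeAttention module itself), and reuses the sequence_mask sketch shown earlier:

import torch
import torch.nn.functional as F

# Hypothetical sizes: batch=2, tgt_len=3, src_len=4, hidden=8.
batch, tgt_len, src_len, hidden = 2, 3, 4, 8
query = torch.randn(batch, tgt_len, hidden)
context = torch.randn(batch, src_len, hidden)
context_lengths = torch.tensor([4, 2])  # second sequence has 2 padded positions

# Dot-product scores, analogous to self.score(input, context): (batch, tgt_len, src_len).
align = torch.bmm(query, context.transpose(1, 2))

# Broadcast the (batch, src_len) mask over the target dimension and hide padding.
mask = sequence_mask(context_lengths, max_len=src_len).unsqueeze(1)
align = align.masked_fill(~mask, -float('inf'))

# Normalized attention weights and the weighted average of source states.
align_vectors = F.softmax(align, dim=-1)        # (batch, tgt_len, src_len)
c = torch.bmm(align_vectors, context)           # (batch, tgt_len, hidden)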

Example 2: forward

# Required module import: from onmt import Utils [as alias]
# Or: from onmt.Utils import sequence_mask [as alias]
def forward(self, src, tgt, src_lengths=None, src_emb=None, tgt_emb=None):
        src_final, src_memory_bank = self.src_encoder(src, src_lengths, emb=src_emb)
        src_length, batch_size, rnn_size = src_memory_bank.size()

        tgt_final, tgt_memory_bank = self.tgt_encoder(tgt, emb=tgt_emb)
        self.q_src_h = src_memory_bank
        self.q_tgt_h = tgt_memory_bank

        src_memory_bank = src_memory_bank.transpose(0,1) # batch_size, src_length, rnn_size
        src_memory_bank = src_memory_bank.transpose(1,2) # batch_size, rnn_size, src_length
        tgt_memory_bank = self.W(tgt_memory_bank.transpose(0,1)) # batch_size, tgt_length, rnn_size

        if self.dist_type == "categorical":
            scores = torch.bmm(tgt_memory_bank, src_memory_bank)
            # mask source attention
            assert (self.mask_val == -float('inf'))
            if src_lengths is not None:
                mask = sequence_mask(src_lengths)
                mask = mask.unsqueeze(1)
                scores.data.masked_fill_(1-mask, self.mask_val)
            # scoresF should be softmax
            log_scores = F.log_softmax(scores, dim=-1)
            scores = F.softmax(scores, dim=-1)

            # Make scores : T x N x S
            scores = scores.transpose(0, 1)
            log_scores = log_scores.transpose(0, 1)

            scores = Params(
                alpha=scores,
                log_alpha=log_scores,
                dist_type=self.dist_type,
            )
        elif self.dist_type == "none":
            scores = torch.bmm(tgt_memory_bank, src_memory_bank)
            # mask source attention
            if src_lengths is not None:
                mask = sequence_mask(src_lengths)
                mask = mask.unsqueeze(1)
                scores.data.masked_fill_(1-mask, self.mask_val)
            scores = Params(
                alpha=scores.transpose(0, 1),
                dist_type=self.dist_type,
            )
        else:
            raise Exception("Unsupported dist_type")

        # T x N x S
        return scores 
Developer: harvardnlp, Project: var-attn, Lines: 51, Source file: ViModels.py
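
The categorical branch in Example 2 returns both probabilities and log-probabilities wrapped in a Params container. A minimal standalone sketch of that pattern, with a hypothetical Params namedtuple whose fields are taken from the example above (the real definition in var-attn may differ) and the sequence_mask sketch from earlier, might look like this:

from collections import namedtuple

import torch
import torch.nn.functional as F

# Hypothetical stand-in for the project's Params container.
Params = namedtuple("Params", ["alpha", "log_alpha", "dist_type"])

# Toy scores: batch=1, tgt_len=2, src_len=4, with the last two source positions padded.
scores = torch.randn(1, 2, 4)
mask = sequence_mask(torch.tensor([2]), max_len=4).unsqueeze(1)   # (1, 1, 4)
scores = scores.masked_fill(~mask, -float('inf'))

log_alpha = F.log_softmax(scores, dim=-1)   # -inf at padded positions
alpha = F.softmax(scores, dim=-1)           # exactly 0 at padded positions

# Transpose to tgt_len x batch x src_len, as in the example's return value.
params = Params(alpha=alpha.transpose(0, 1),
                log_alpha=log_alpha.transpose(0, 1),
                dist_type="categorical")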


Note: The onmt.Utils.sequence_mask examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by their respective authors, who retain copyright; please consult each project's license before distributing or using the code. Do not reproduce this article without permission.