本文整理匯總了Python中onmt.Utils.sequence_mask方法的典型用法代碼示例。如果您正苦於以下問題:Python Utils.sequence_mask方法的具體用法?Python Utils.sequence_mask怎麽用?Python Utils.sequence_mask使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類onmt.Utils的用法示例。
在下文中一共展示了Utils.sequence_mask方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: forward
# 需要導入模塊: from onmt import Utils [as 別名]
# 或者: from onmt.Utils import sequence_mask [as 別名]
def forward(self, input, context, context_lengths=None, coverage=None):
    """
    Attend over `context` with each query in `input`.

    Args:
      input (`FloatTensor`): query vectors `[batch x tgt_len x hidden_size]`
      context (`FloatTensor`): source vectors `[batch x src_len x hidden_size]`
      context_lengths (`LongTensor`): the source context lengths `[batch]`.
        When given, every position where `sequence_mask` is zero is filled
        with `-inf` before the scores are normalized.
      coverage (`FloatTensor`): None (not supported yet); unused here.

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed vector `[tgt_len x batch x hidden_size]`
      * Attention distributions for each query
        `[tgt_len x batch x src_len]`
    """
    batch, sourceL, _ = context.size()
    batch_, targetL, hidden_size = input.size()
    aeq(batch, batch_)

    # Compute attention scores, as in Luong et al.
    # Shape: batch x tgt_len x src_len.
    align = self.score(input, context)

    if context_lengths is not None:
        mask = sequence_mask(context_lengths)
        mask = mask.unsqueeze(1)  # Make it broadcastable over tgt_len.
        # `mask.eq(0)` inverts the mask whatever its dtype is. The former
        # `1 - mask` raises when the mask is a boolean tensor, because
        # PyTorch does not support `-` on bool tensors.
        # The fill stays in place on `.data`, as before.
        align.data.masked_fill_(mask.eq(0), -float('inf'))

    # Normalize the scores into attention weights (one row per query).
    align_vectors = self.sm(align.view(batch * targetL, sourceL))
    align_vectors = align_vectors.view(batch, targetL, sourceL)

    # Each context vector c_t is the weighted average
    # over all the source hidden states.
    c = torch.bmm(align_vectors, context)

    # Concatenate context vector and query, then project back to hidden_size.
    concat_c = torch.cat([c, input], 2).view(batch * targetL, -1)
    attn_h = self.linear_out(concat_c).view(batch, targetL, hidden_size)
    if self.attn_type in ("general", "dot"):
        attn_h = self.tanh(attn_h)

    # Return time-major tensors: tgt_len x batch x ...
    attn_h = attn_h.transpose(0, 1).contiguous()
    align_vectors = align_vectors.transpose(0, 1).contiguous()
    return attn_h, align_vectors
示例2: forward
# 需要導入模塊: from onmt import Utils [as 別名]
# 或者: from onmt.Utils import sequence_mask [as 別名]
def forward(self, src, tgt, src_lengths=None, src_emb=None, tgt_emb=None):
    """
    Score every target position against every source position.

    Both sequences are encoded, the target states are passed through
    `self.W`, and a batched matrix product with the source states gives
    the raw scores. The encoder outputs are also stored on the instance
    as `self.q_src_h` and `self.q_tgt_h`.

    Args:
      src: source input, passed to `self.src_encoder`.
      tgt: target input, passed to `self.tgt_encoder`.
      src_lengths: optional source lengths. Forwarded to the source
        encoder; when not None, every position where `sequence_mask`
        is zero is filled with `self.mask_val`.
      src_emb: forwarded to the source encoder as `emb`.
      tgt_emb: forwarded to the target encoder as `emb`.

    Returns:
      `Params` whose `alpha` is laid out `T x N x S`
      (tgt_len x batch x src_len):

      * "categorical": `alpha` is the softmax of the scores and
        `log_alpha` the log-softmax, both over the source dimension.
      * "none": `alpha` holds the (masked) raw scores.

    Raises:
      Exception: if `self.dist_type` is neither "categorical" nor "none".
      AssertionError: if `self.dist_type` is "categorical" and
        `self.mask_val` is not `-inf`.
    """
    _, src_memory_bank = self.src_encoder(src, src_lengths, emb=src_emb)
    _, tgt_memory_bank = self.tgt_encoder(tgt, emb=tgt_emb)
    self.q_src_h = src_memory_bank
    self.q_tgt_h = tgt_memory_bank

    # src: (src_length, batch, rnn_size) -> (batch, rnn_size, src_length)
    src_memory_bank = src_memory_bank.transpose(0, 1).transpose(1, 2)
    # tgt: -> (batch, tgt_length, rnn_size)
    tgt_memory_bank = self.W(tgt_memory_bank.transpose(0, 1))

    dist_type = self.dist_type
    if dist_type not in ("categorical", "none"):
        raise Exception("Unsupported dist_type")

    # batch x tgt_length x src_length
    scores = torch.bmm(tgt_memory_bank, src_memory_bank)

    if dist_type == "categorical":
        # Softmax masking only works with -inf.
        assert (self.mask_val == -float('inf'))

    # Mask source attention.
    if src_lengths is not None:
        mask = sequence_mask(src_lengths).unsqueeze(1)
        # `mask.eq(0)` inverts the mask whatever its dtype is. The former
        # `1-mask` raises when the mask is a boolean tensor, because
        # PyTorch does not support `-` on bool tensors.
        # The fill stays in place on `.data`, as before.
        scores.data.masked_fill_(mask.eq(0), self.mask_val)

    if dist_type == "none":
        return Params(
            alpha=scores.transpose(0, 1),
            dist_type=dist_type,
        )

    # Categorical: normalize over the source dimension.
    log_alpha = F.log_softmax(scores, dim=-1)
    alpha = F.softmax(scores, dim=-1)
    # Make outputs T x N x S.
    return Params(
        alpha=alpha.transpose(0, 1),
        log_alpha=log_alpha.transpose(0, 1),
        dist_type=dist_type,
    )