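This article collects typical usage examples of onmt.Models in Python. If you want to know how onmt.Models is used in practice or are looking for examples of it, the selected code example below may help. You can also explore further usage examples from onmt, the package this method belongs to.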
One code example of onmt.Models is shown below.
Example 1: forward
```python
# Module to import: import onmt [as alias]
# or: from onmt import Models [as alias]
def forward(self, src, tgt, lengths, dec_state=None):
    """Forward propagate a `src` and `tgt` pair for training.

    Possibly initialized with a beginning decoder state.

    Args:
        src (:obj:`Tensor`):
            a source sequence passed to the encoder.
            Typically for inputs this will be a padded :obj:`LongTensor`
            of size `[len x batch x features]`. However, it may be an
            image or other generic input depending on the encoder.
        tgt (:obj:`LongTensor`):
            a target sequence of size `[tgt_len x batch x features]`.
        lengths (:obj:`LongTensor`): the src lengths, pre-padding `[batch]`.
        dec_state (:obj:`DecoderState`, optional): initial decoder state

    Returns:
        (:obj:`FloatTensor`, `dict`, :obj:`onmt.Models.DecoderState`, ...):

            * decoder output `[tgt_len x batch x hidden]`
            * dictionary of attention dists of `[tgt_len x batch x src_len]`
            * final decoder state
            * distribution info returned by the decoder (`dist_info`)
            * baseline decoder output (`decoder_outputs_baseline`)
    """
    src_emb = self.encoder.dropout(self.encoder.embeddings(src))
    tgt_emb = self.decoder.dropout(self.decoder.embeddings(tgt))
    if self.dbg:
        # inference network only sees the past
        inftgt = tgt[:-1]
    else:
        # inference network sees the present (the tokens being predicted)
        inftgt = tgt[1:]
    inftgt_emb = tgt_emb[1:]
    tgt = tgt[:-1]  # exclude last target from inputs
    tgt_emb = tgt_emb[:-1]  # exclude last target from inputs
    # tgt is [tgt_len x batch x features]; the sizes are not used below
    tgt_length, batch_size, _ = tgt.size()
    enc_final, memory_bank = self.encoder(src, lengths, emb=src_emb)
    enc_state = self.decoder.init_decoder_state(
        src, memory_bank, enc_final)
    # enc_state.* should all be 0
    if self.inference_network is not None and not self.use_prior:
        # inference network q(z|x,y)
        # q_scores: [batch_size x tgt_length x src_length]
        q_scores = self.inference_network(
            src, inftgt, lengths, src_emb=src_emb, tgt_emb=inftgt_emb)
    else:
        q_scores = None
    decoder_outputs, dec_state, attns, dist_info, decoder_outputs_baseline = \
        self.decoder(tgt, memory_bank,
                     enc_state if dec_state is None else dec_state,
                     memory_lengths=lengths,
                     q_scores=q_scores,
                     tgt_emb=tgt_emb)
    if self.multigpu:
        # Not yet supported on multi-GPU
        dec_state = None
        attns = None
    return decoder_outputs, attns, dec_state, dist_info, decoder_outputs_baseline
```
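The slicing at the start of `forward` implements teacher forcing: the decoder reads every target token except the last, while the inference network receives either those same inputs (`self.dbg` set, "only see past") or the tokens being predicted (the default, "see present"). A minimal pure-Python sketch of the same index arithmetic on one token sequence follows. `split_targets` is a hypothetical helper, not part of OpenNMT-py, and plain lists stand in for the time-major tensors (slicing the first dimension behaves the same way). The "gold" slice is an assumption about how the loss is usually computed outside `forward`.

```python
from typing import List, Sequence, Tuple


def split_targets(
    tgt: Sequence[str], dbg: bool = False
) -> Tuple[List[str], List[str], List[str]]:
    """Mirror the target slicing in forward() for a single sequence.

    Returns (decoder_inputs, gold_tokens, inference_tokens).
    """
    tgt = list(tgt)
    # Decoder inputs drop the last token: tgt[:-1].
    decoder_inputs = tgt[:-1]
    # Assumption: the loss compares decoder outputs with the shifted tgt[1:].
    gold_tokens = tgt[1:]
    # Inference network tokens: past (dbg) or present (default).
    inference_tokens = tgt[:-1] if dbg else tgt[1:]
    return decoder_inputs, gold_tokens, inference_tokens


tgt = ["<s>", "the", "cat", "sat", "</s>"]
inputs, gold, inf_default = split_targets(tgt)
print(inputs)       # ['<s>', 'the', 'cat', 'sat']
print(gold)         # ['the', 'cat', 'sat', '</s>']
print(inf_default)  # ['the', 'cat', 'sat', '</s>']
_, _, inf_dbg = split_targets(tgt, dbg=True)
print(inf_dbg)      # ['<s>', 'the', 'cat', 'sat']
```

In `forward` itself, the embeddings handed to the inference network are always `tgt_emb[1:]`, even when `self.dbg` selects `tgt[:-1]` as its token input.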
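The `inference_network` call returns `q_scores`, which the code comment describes as `[batch_size x tgt_length x src_length]`: for every target position, a score over source positions that parameterizes q(z|x,y). The network's internals are not shown in this example. The sketch below is a hypothetical, minimal stand-in that scores each (target, source) pair with a dot product and normalizes with a softmax over the unpadded source positions, using `lengths` the way attention masks usually do. `posterior_attention` and `softmax` are illustrative names, not OpenNMT-py APIs, and the real network (including whether it returns normalized probabilities or raw scores) may differ.

```python
import math
from typing import List

Vector = List[float]


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax; -inf entries get probability 0."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def posterior_attention(
    src_emb: List[Vector], tgt_emb: List[Vector], src_length: int
) -> List[List[float]]:
    """Return a [tgt_len x src_len] matrix for ONE batch element.

    Each row is a distribution over source positions; positions at or
    beyond src_length (padding) are masked out.
    """
    if src_length < 1:
        raise ValueError("src_length must be at least 1")
    rows = []
    for t_vec in tgt_emb:
        scores = [
            sum(a * b for a, b in zip(t_vec, s_vec)) if s < src_length else -math.inf
            for s, s_vec in enumerate(src_emb)
        ]
        rows.append(softmax(scores))
    return rows


# Three source positions, the last one is padding (src_length=2).
src = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
tgt = [[2.0, 0.0]]
q = posterior_attention(src, tgt, src_length=2)
print([round(p, 4) for p in q[0]])  # [0.8808, 0.1192, 0.0]
```

Applying the function to each batch element and stacking the results gives the `[batch_size x tgt_length x src_length]` layout noted in the code comment.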