This article collects typical usage examples of onmt.Models in Python. If you want to know how onmt.Models is used in practice, the selected example below may help. You can also look further into usage examples of onmt, the package it belongs to.
One code example of onmt.Models is shown below.
Example 1: forward
```python
# Required import: import onmt [as alias]
# or: from onmt import Models [as alias]
def forward(self, src, tgt, lengths, dec_state=None):
    """Forward propagate a `src` and `tgt` pair for training.

    Possibly initialized with a beginning decoder state.

    Args:
        src (:obj:`Tensor`):
            a source sequence passed to the encoder.
            Typically for inputs this will be a padded :obj:`LongTensor`
            of size `[len x batch x features]`. However, it may be an
            image or other generic input depending on the encoder.
        tgt (:obj:`LongTensor`):
            a target sequence of size `[tgt_len x batch x features]`.
        lengths (:obj:`LongTensor`): the src lengths, pre-padding `[batch]`.
        dec_state (:obj:`DecoderState`, optional): initial decoder state.

    Returns:
        (:obj:`FloatTensor`, `dict`, :obj:`onmt.Models.DecoderState`,
        `dist_info`, `decoder_outputs_baseline`):

            * decoder output `[tgt_len x batch x hidden]`
            * dictionary of attention dists, each `[tgt_len x batch x src_len]`
            * final decoder state
            * `dist_info` as returned by the decoder
            * `decoder_outputs_baseline` as returned by the decoder

        On multi-GPU, the attention dictionary and the final decoder
        state are returned as None.
    """
    src_emb = self.encoder.dropout(self.encoder.embeddings(src))
    tgt_emb = self.decoder.dropout(self.decoder.embeddings(tgt))
    if self.dbg:
        # only see past
        inftgt = tgt[:-1]
    else:
        # see present (the token predicted at each step)
        inftgt = tgt[1:]
    # embeddings for the inference network: tgt_emb[1:] in both branches
    inftgt_emb = tgt_emb[1:]
    tgt = tgt[:-1]  # exclude last target from inputs
    tgt_emb = tgt_emb[:-1]  # exclude last target from inputs
    # Unpacking three values requires tgt to be 3-D; these are not used below.
    tgt_length, batch_size, rnn_size = tgt.size()

    enc_final, memory_bank = self.encoder(src, lengths, emb=src_emb)
    enc_state = self.decoder.init_decoder_state(src, memory_bank, enc_final)
    # enc_state.* should all be 0

    if self.inference_network is not None and not self.use_prior:
        # inference network q(z|x,y)
        # q_scores: [batch_size x tgt_length x src_length]
        q_scores = self.inference_network(
            src, inftgt, lengths, src_emb=src_emb, tgt_emb=inftgt_emb)
    else:
        q_scores = None

    decoder_outputs, dec_state, attns, dist_info, decoder_outputs_baseline = \
        self.decoder(tgt, memory_bank,
                     enc_state if dec_state is None else dec_state,
                     memory_lengths=lengths,
                     q_scores=q_scores,
                     tgt_emb=tgt_emb)

    if self.multigpu:
        # Not yet supported on multi-gpu
        dec_state = None
        attns = None
    return decoder_outputs, attns, dec_state, dist_info, decoder_outputs_baseline
```
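The target slicing at the top of forward is the least obvious step. The following is a minimal sketch that reproduces only the token-level slicing on a plain Python list, so it runs without PyTorch. The helper split_targets is hypothetical (it is not part of OpenNMT-py), and the string tokens stand in for the target tensor. It does not model inftgt_emb, which the original computes as tgt_emb[1:] in both branches.

```python
from typing import List, Sequence, Tuple


def split_targets(tgt: Sequence[str], dbg: bool) -> Tuple[List[str], List[str]]:
    """Hypothetical helper mirroring the token slicing in forward().

    Returns (decoder_inputs, inference_inputs):
      * decoder_inputs = tgt[:-1]: every token except the last is fed to the
        decoder (the last token is only ever a prediction target).
      * inference_inputs = tgt[:-1] when dbg is set ("only see past"),
        otherwise tgt[1:] ("see present"), i.e. the token the decoder must
        predict at each step.
    """
    decoder_inputs = list(tgt[:-1])
    inference_inputs = list(tgt[:-1]) if dbg else list(tgt[1:])
    return decoder_inputs, inference_inputs


tgt = ["<s>", "the", "cat", "sat", "</s>"]
dec_in, inf_in = split_targets(tgt, dbg=False)
for step, (x, y) in enumerate(zip(dec_in, inf_in)):
    print(step, x, "->", y)
# 0 <s> -> the
# 1 the -> cat
# 2 cat -> sat
# 3 sat -> </s>
```

With dbg=False, step t pairs the decoder input tgt[t] with tgt[t+1], the token being predicted, which is what lets the inference network q(z|x,y) condition on the target. With dbg=True, the two sequences are identical.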