本文整理汇总了Python中fairseq.utils.set_incremental_state方法的典型用法代码示例。如果您正苦于以下问题：Python utils.set_incremental_state方法的具体用法？Python utils.set_incremental_state怎么用？Python utils.set_incremental_state使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块fairseq.utils的用法示例。
在下文中一共展示了utils.set_incremental_state方法的14个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _split_encoder_out
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def _split_encoder_out(self, encoder_out, incremental_state):
    """Return ``encoder_out`` as a pair whose first element is transposed.

    ``encoder_out`` is unpacked into two parts. The first part has dims 1 and 2
    swapped and is made contiguous; the second part is passed through unchanged.
    The pair is looked up in (and, when ``incremental_state`` is given, stored
    into) the incremental state under the ``'encoder_out'`` key, so the
    transpose happens only once during incremental inference.
    """
    key = 'encoder_out'
    cached = utils.get_incremental_state(self, incremental_state, key)
    if cached is None:
        first, second = encoder_out
        # transpose only once to speed up attention layers
        cached = (first.transpose(1, 2).contiguous(), second)
        if incremental_state is not None:
            utils.set_incremental_state(self, incremental_state, key, cached)
    return cached
示例2: _split_encoder_out
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def _split_encoder_out(self, encoder_out, incremental_state, aux=False):
    """Split and transpose encoder outputs, caching the result incrementally.

    ``encoder_out`` is unpacked into ``(encoder_a, encoder_b)``. ``encoder_a``
    has dims 1 and 2 swapped and is made contiguous; ``encoder_b`` is returned
    unchanged.

    Args:
        encoder_out: a 2-element sequence ``(encoder_a, encoder_b)``.
        incremental_state: incremental-decoding state dict, or ``None``.
        aux: when truthy, cache under ``'auxencoder_out'`` instead of
            ``'encoder_out'`` so auxiliary encoder outputs do not collide with
            the main ones.

    Returns:
        The ``(encoder_a, encoder_b)`` tuple, taken from the cache if present.
    """
    state_name = 'encoder_out'
    if aux:  # plain truthiness instead of comparing to True (PEP 8)
        state_name = 'aux' + state_name
    cached_result = utils.get_incremental_state(self, incremental_state, state_name)
    if cached_result is not None:
        return cached_result
    # transpose only once to speed up attention layers
    encoder_a, encoder_b = encoder_out
    encoder_a = encoder_a.transpose(1, 2).contiguous()
    result = (encoder_a, encoder_b)
    if incremental_state is not None:
        utils.set_incremental_state(self, incremental_state, state_name, result)
    return result
示例3: reorder_incremental_state
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def reorder_incremental_state(self, incremental_state, new_order):
    """Reorder buffered internal state (for incremental generation).

    Every tensor in the ``"cached_state"`` entry is re-indexed along dim 0 with
    ``new_order``. Nested lists are handled recursively and ``None`` entries stay
    ``None``. Nothing happens when no cached state exists.
    """
    cached_state = utils.get_incremental_state(
        self, incremental_state, "cached_state"
    )
    if cached_state is not None:
        def _select(item):
            if isinstance(item, list):
                return [_select(sub) for sub in item]
            if item is None:
                return None
            return item.index_select(0, new_order)

        reordered = tuple(_select(item) for item in cached_state)
        utils.set_incremental_state(self, incremental_state, "cached_state", reordered)
示例4: reorder_incremental_state
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def reorder_incremental_state(self, incremental_state, new_order):
    """Reorder the cached recurrent state to follow ``new_order``.

    The parent's state (the attention model, per the original comment) is
    reordered first. Then every entry of ``"cached_state"`` except the last two
    is re-indexed along dim 1 with ``new_order``. The entries are written back
    into the same sequence object, which is then stored again.

    NOTE(review): the in-place item assignment assumes ``cached_state`` is a
    mutable sequence such as a list; a tuple would raise TypeError here.
    """
    # parent reorders attention model
    super().reorder_incremental_state(incremental_state, new_order)
    cached_state = utils.get_incremental_state(
        self, incremental_state, "cached_state"
    )
    if cached_state is not None:
        # The trailing two entries are encoder projections kept for ONNX export
        # and are deliberately left untouched.
        reorderable = cached_state[:-2]
        for position, tensor in enumerate(reorderable):
            cached_state[position] = tensor.index_select(1, new_order)
        utils.set_incremental_state(
            self, incremental_state, "cached_state", cached_state
        )
示例5: forward
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
    """Emit scripted output probabilities and random attention.

    In incremental mode only the last target token is used, and a step counter
    stored under ``'step'`` in ``incremental_state`` chooses which scripted row
    to emit. Otherwise one row is produced per target position.

    If ``self.args`` has a ``probs`` attribute (asserted to be 3-d), those rows
    are selected along dim 1. Otherwise, for each step below
    ``len(self.args.beam_probs)``, ``self.args.beam_probs[step]`` is written
    into the vocab slots starting at ``self.dictionary.eos()``. Later steps put
    probability 1.0 on eos.

    Returns:
        ``(Variable(probs), Variable(attn))``, where ``attn`` is a random
        ``(bsz, src_len, tgt_len)`` tensor.
    """
    is_incremental = incremental_state is not None
    if is_incremental:
        prev_output_tokens = prev_output_tokens[:, -1:]
    batch_size = prev_output_tokens.size(0)
    vocab_size = len(self.dictionary)
    # NOTE(review): dim 1 of encoder_out is treated as the source length; confirm layout
    src_len = encoder_out.size(1)
    tgt_len = prev_output_tokens.size(1)

    if not is_incremental:
        steps = list(range(tgt_len))
    else:
        # the step counter lives in the incremental state and advances once per call
        cached_step = utils.get_incremental_state(self, incremental_state, 'step')
        if cached_step is None:
            cached_step = 0
        utils.set_incremental_state(self, incremental_state, 'step', cached_step + 1)
        steps = [cached_step]

    if hasattr(self.args, 'probs'):
        assert self.args.probs.dim() == 3, \
            'expected probs to have size bsz*steps*vocab'
        probs = self.args.probs.index_select(1, torch.LongTensor(steps))
    else:
        probs = torch.FloatTensor(batch_size, len(steps), vocab_size).zero_()
        for col, step_id in enumerate(steps):
            # beam_probs rows cover the vocab from eos onwards:
            # eos, then unknown, then the rest of the vocab
            scripted = step_id < len(self.args.beam_probs)
            eos = self.dictionary.eos()
            if not scripted:
                probs[:, col, eos] = 1.0
            else:
                probs[:, col, eos:] = self.args.beam_probs[step_id]

    # random attention, laid out as (bsz, src_len, tgt_len)
    attn = torch.rand(batch_size, src_len, tgt_len)
    return Variable(probs), Variable(attn)
示例6: _set_input_buffer
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def _set_input_buffer(self, incremental_state, new_buffer):
    """Store ``new_buffer`` under the ``'input_buffer'`` key of the incremental state.

    Returns whatever ``utils.set_incremental_state`` returns.
    """
    buffer_key = 'input_buffer'
    outcome = utils.set_incremental_state(self, incremental_state, buffer_key, new_buffer)
    return outcome
示例7: _set_input_buffer
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def _set_input_buffer(self, incremental_state, buffer):
    """Save ``buffer`` in ``incremental_state`` under the ``'attn_state'`` key."""
    state_key = 'attn_state'
    utils.set_incremental_state(self, incremental_state, state_key, buffer)
示例8: reorder_incremental_state
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def reorder_incremental_state(self, incremental_state, new_order):
    """Reorder cached encoder outputs to follow ``new_order``.

    The parent's state is reordered first. If an ``'encoder_out'`` entry is
    cached, each of its elements is re-indexed along dim 0 with ``new_order``
    and the resulting tuple is stored back.
    """
    super().reorder_incremental_state(incremental_state, new_order)
    cached = utils.get_incremental_state(self, incremental_state, 'encoder_out')
    if cached is None:
        return
    reordered = tuple(part.index_select(0, new_order) for part in cached)
    utils.set_incremental_state(self, incremental_state, 'encoder_out', reordered)
示例9: reorder_incremental_state
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def reorder_incremental_state(self, incremental_state, new_order):
    """Reorder the cached decoder state to follow ``new_order``.

    The parent's state is reordered first. Then every tensor in the
    ``'cached_state'`` entry is re-indexed along dim 0 with ``new_order``.
    Nested lists are handled recursively, and ``None`` entries are passed
    through unchanged instead of raising AttributeError on
    ``None.index_select``. Nothing happens when no cached state exists.
    """
    super().reorder_incremental_state(incremental_state, new_order)
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is None:
        return

    def reorder_state(state):
        if isinstance(state, list):
            return [reorder_state(state_i) for state_i in state]
        if state is None:
            # optional slots may hold None; keep them as-is
            return None
        return state.index_select(0, new_order)

    new_state = tuple(map(reorder_state, cached_state))
    utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
示例10: _set_monotonic_buffer
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def _set_monotonic_buffer(self, incremental_state, buffer):
    """Save ``buffer`` in ``incremental_state`` under the ``'monotonic'`` key."""
    state_key = 'monotonic'
    utils.set_incremental_state(self, incremental_state, state_key, buffer)
示例11: set_pointer
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def set_pointer(self, incremental_state, p_choose):
    """Advance the monotonic step pointer and store it in the incremental state.

    The current pointer comes from ``self.get_pointer``. If it is empty, the
    step buffer starts as zeros shaped like ``p_choose``; otherwise the stored
    ``"step"`` tensor is reused. Every position where ``p_choose < 0.5`` is
    incremented by one, and ``{"step": buffer}`` is saved under the
    ``'monotonic'`` key.
    """
    curr_pointer = self.get_pointer(incremental_state)
    if not curr_pointer:
        buffer = torch.zeros_like(p_choose)
    else:
        # reuse the pointer fetched above rather than calling get_pointer again
        buffer = curr_pointer["step"]
    # in-place add: when reused, the previously stored "step" tensor is updated too
    buffer += (p_choose < 0.5).type_as(buffer)
    utils.set_incremental_state(
        self,
        incremental_state,
        'monotonic',
        {"step": buffer},
    )
示例12: forward
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
    """Emit scripted output probabilities plus random attention on the input's device.

    In incremental mode only the last target token is used, and a step counter
    stored under ``'step'`` in ``incremental_state`` chooses the scripted row.
    Otherwise one row is produced per target position. Rows come from
    ``self.args.probs`` (asserted 3-d, selected along dim 1) when that
    attribute exists. Otherwise they are built from ``self.args.beam_probs``
    starting at ``self.dictionary.eos()``, and steps past the end put 1.0 on
    eos.

    Returns:
        ``(probs, {"attn": [attn]})``, both moved to
        ``prev_output_tokens.device``. ``attn`` is a random
        ``(bsz, tgt_len, src_len)`` tensor.
    """
    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
    n_batch = prev_output_tokens.size(0)
    n_vocab = len(self.dictionary)
    # NOTE(review): encoder_out.encoder_out dim 1 is treated as source length; confirm
    n_src = encoder_out.encoder_out.size(1)
    n_tgt = prev_output_tokens.size(1)

    if incremental_state is None:
        steps = list(range(n_tgt))
    else:
        # the step counter is persisted in the incremental state between calls
        previous = utils.get_incremental_state(self, incremental_state, 'step')
        previous = 0 if previous is None else previous
        utils.set_incremental_state(self, incremental_state, 'step', previous + 1)
        steps = [previous]

    if not hasattr(self.args, 'probs'):
        probs = torch.FloatTensor(n_batch, len(steps), n_vocab).zero_()
        for slot, step_no in enumerate(steps):
            # beam_probs rows list eos, then unknown, then the rest of the vocab
            if step_no < len(self.args.beam_probs):
                probs[:, slot, self.dictionary.eos():] = self.args.beam_probs[step_no]
            else:
                probs[:, slot, self.dictionary.eos()] = 1.0
    else:
        assert self.args.probs.dim() == 3, \
            'expected probs to have size bsz*steps*vocab'
        probs = self.args.probs.index_select(1, torch.LongTensor(steps))

    # random attention, laid out as (bsz, tgt_len, src_len)
    attn = torch.rand(n_batch, n_tgt, n_src)
    target_device = prev_output_tokens.device
    extra = {"attn": [attn.to(target_device)]}
    return probs.to(target_device), extra
示例13: _set_input_buffer
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def _set_input_buffer(
    self, incremental_state, buffer, incremental_clone_id: str = ""
):
    """Save ``buffer`` under ``'attn_state'`` suffixed with ``incremental_clone_id``.

    The clone id is also recorded in ``self.incremental_clone_ids`` before the
    buffer is stored.
    """
    self.incremental_clone_ids.add(incremental_clone_id)
    state_key = "attn_state" + incremental_clone_id
    utils.set_incremental_state(self, incremental_state, state_key, buffer)
示例14: forward
# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import set_incremental_state [as 别名]
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
    """Emit scripted output probabilities and a random attention tensor.

    In incremental mode only the last target token is used, and a step counter
    stored under ``'step'`` in ``incremental_state`` chooses the scripted row.
    Otherwise one row is produced per target position. Rows come from
    ``self.args.probs`` (asserted 3-d, selected along dim 1) when present.
    Otherwise they are built from ``self.args.beam_probs`` starting at
    ``self.dictionary.eos()``, and steps past the end put 1.0 on eos.

    Returns:
        ``(probs, attn)``, where ``attn`` is a random
        ``(bsz, tgt_len, src_len)`` tensor.
    """
    incremental = incremental_state is not None
    if incremental:
        prev_output_tokens = prev_output_tokens[:, -1:]
    bsz = prev_output_tokens.size(0)
    vocab_size = len(self.dictionary)
    # NOTE(review): dim 1 of encoder_out is treated as the source length; confirm layout
    source_len = encoder_out.size(1)
    target_len = prev_output_tokens.size(1)

    if incremental:
        # advance the per-call step counter kept in the incremental state
        counter = utils.get_incremental_state(self, incremental_state, 'step')
        if counter is None:
            counter = 0
        utils.set_incremental_state(self, incremental_state, 'step', counter + 1)
        steps = [counter]
    else:
        steps = list(range(target_len))

    if hasattr(self.args, 'probs'):
        assert self.args.probs.dim() == 3, \
            'expected probs to have size bsz*steps*vocab'
        probs = self.args.probs.index_select(1, torch.LongTensor(steps))
    else:
        probs = torch.FloatTensor(bsz, len(steps), vocab_size).zero_()
        for idx, step_val in enumerate(steps):
            # beam_probs rows list eos, then unknown, then the rest of the vocab
            past_script = step_val >= len(self.args.beam_probs)
            eos_idx = self.dictionary.eos()
            if past_script:
                probs[:, idx, eos_idx] = 1.0
            else:
                probs[:, idx, eos_idx:] = self.args.beam_probs[step_val]

    # random attention, laid out as (bsz, tgt_len, src_len)
    attn = torch.rand(bsz, target_len, source_len)
    return probs, attn