本文整理汇总了Python中torch.LongTensor.sum方法的典型用法代码示例。如果您正苦于以下问题：Python LongTensor.sum方法的具体用法是什么？LongTensor.sum怎么用？有哪些使用例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类torch.LongTensor的用法示例。
在下文中一共展示了LongTensor.sum方法的1个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。需要说明的是，示例中的mask在求和前已通过.float()转换为浮点张量，因此实际调用的是浮点张量的sum方法（mask.sum(0)）。
示例1: _joint_likelihood
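# 需要导入模块: import torch（示例通过 torch.Tensor、torch.LongTensor 等全名引用类型，仅 from torch import LongTensor 并不足够）
# 注意: sum 是张量的实例方法（示例中写作 mask.sum(0)），torch.LongTensor 是类而不是模块，因此不能写成 from torch.LongTensor import sum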
def _joint_likelihood(self,
                      logits: torch.Tensor,
                      tags: torch.Tensor,
                      mask: torch.LongTensor) -> torch.Tensor:
    """
    Computes the numerator term for the log-likelihood, which is just score(inputs, tags).

    The score of each tag sequence is the sum of:
      * the start transition into its first tag (if ``include_start_end_transitions``),
      * the emission score ``logits[b, t, tags[b, t]]`` at every unmasked position,
      * the transition ``transitions[tags[b, t], tags[b, t + 1]]`` for every step whose
        destination position ``t + 1`` is unmasked,
      * the end transition out of its last unmasked tag (if ``include_start_end_transitions``).

    Parameters
    ----------
    logits : torch.Tensor
        Shape ``(batch_size, sequence_length, num_tags)``.
    tags : torch.Tensor
        Shape ``(batch_size, sequence_length)``; integer (long) tag indices, since they
        are used as ``gather`` / ``index_select`` / indexing arguments.
    mask : torch.LongTensor
        Shape ``(batch_size, sequence_length)``; nonzero for real positions.
        The last tag of each row is taken to be at index ``mask[b].sum() - 1``, so each
        row should be a contiguous run of 1s starting at position 0. The start and end
        transitions are added without consulting the mask, and a fully masked row yields
        index -1, which ``gather`` rejects.

    Returns
    -------
    torch.Tensor
        Shape ``(batch_size,)``; the unnormalised score of each tag sequence.
    """
    batch_size, sequence_length, _ = logits.shape

    # Put the sequence dimension first so each time step is a (batch_size, ...) slice.
    logits = logits.transpose(0, 1).contiguous()      # (sequence_length, batch_size, num_tags)
    mask = mask.float().transpose(0, 1).contiguous()  # (sequence_length, batch_size)
    tags = tags.transpose(0, 1).contiguous()          # (sequence_length, batch_size)

    # Start with the transition scores from start_tag to the first tag in each input.
    if self.include_start_end_transitions:
        score = self.start_transitions.index_select(0, tags[0])
    else:
        score = 0.0

    # Add up the scores for the observed transitions and all the inputs but the last.
    for i in range(sequence_length - 1):
        current_tag, next_tag = tags[i], tags[i + 1]  # each (batch_size,)
        # Paired advanced indexing selects transitions[current_tag[b], next_tag[b]] per b,
        # without broadcasting the whole (num_tags, num_tags) matrix over the batch.
        transition_score = self.transitions[current_tag, next_tag]
        # The score for emitting current_tag at step i.
        emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)
        # The transition counts only if the next element is unmasked;
        # the emission only if this element is unmasked.
        score = score + transition_score * mask[i + 1] + emit_score * mask[i]

    # Find the last unmasked tag of each row (see the mask note in the docstring).
    last_tag_index = mask.sum(0).long() - 1                                    # (batch_size,)
    last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0)  # (batch_size,)

    # Compute score of transitioning to `stop_tag` from each "last tag".
    if self.include_start_end_transitions:
        last_transition_score = self.end_transitions.index_select(0, last_tags)
    else:
        last_transition_score = 0.0

    # Emission at the final time step. It is zeroed for rows whose final position is
    # padding; for those rows the loop above already added the emission of their
    # last real token.
    last_input_score = logits[-1].gather(1, last_tags.view(-1, 1)).squeeze(1)  # (batch_size,)
    score = score + last_transition_score + last_input_score * mask[-1]
    return score