本文整理汇总了 Python 中 allennlp.nn.util.batched_index_select 方法的典型用法代码示例。如果您正苦于以下问题:Python util.batched_index_select 方法的具体用法?Python util.batched_index_select 怎么用?Python util.batched_index_select 使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 allennlp.nn.util 的用法示例。
在下文中一共展示了 util.batched_index_select 方法的 15 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: test_correct_sequence_elements_are_embedded
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_correct_sequence_elements_are_embedded(self):
    """An ``"x,y"`` endpoint extractor returns start and end embeddings side by side."""
    embedding_dim = 7
    sequence_tensor = torch.randn([2, 5, embedding_dim])
    # The "x,y" combination concatenates the start and end point embeddings.
    extractor = EndpointSpanExtractor(embedding_dim, "x,y")
    indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])

    span_representations = extractor(sequence_tensor, indices)

    assert list(span_representations.size()) == [2, 2, 14]
    assert extractor.get_output_dim() == 14
    assert extractor.get_input_dim() == 7

    # Split the concatenation back into its halves and compare each half
    # with a direct lookup of the corresponding endpoint indices.
    start_embeddings, end_embeddings = span_representations.split(embedding_dim, -1)
    start_indices, end_indices = indices.split(1, -1)
    halves = (
        (start_embeddings, start_indices),
        (end_embeddings, end_indices),
    )
    for actual, endpoint_indices in halves:
        expected = batched_index_select(sequence_tensor, endpoint_indices.squeeze())
        numpy.testing.assert_array_equal(actual.data.numpy(), expected.data.numpy())
示例2: test_masked_topk_selects_top_scored_items_and_respects_masking
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_topk_selects_top_scored_items_and_respects_masking(self):
    """``masked_topk`` with an integer ``k`` keeps the best unmasked items of each row."""
    items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
    # Pin selected items to the clamp ceiling so they get the highest possible score.
    items[0, :2, :] = 1
    for row in (1, 2):
        items[row, 2:, :] = 1
    scores = items.sum(-1)

    mask = torch.ones([3, 4]).bool()
    for column in (0, 3):
        mask[1, column] = 0

    pruned_scores, pruned_mask, pruned_indices = util.masked_topk(scores, mask, 2)

    # By score alone row 1 would keep items 2 and 3, but items 0 and 3 are
    # masked there, so items 1 and 2 are expected instead.
    expected_indices = numpy.array([[0, 1], [1, 2], [2, 3]])
    numpy.testing.assert_array_equal(pruned_indices.data.numpy(), expected_indices)
    numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.ones([3, 2]))

    # The returned scores must equal `scores` gathered at the kept indices.
    gathered = util.batched_index_select(scores.unsqueeze(-1), pruned_indices)
    correct_scores = gathered.squeeze(-1)
    self.assert_array_equal_with_mask(correct_scores, pruned_scores, pruned_mask)
示例3: test_batched_index_select
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_batched_index_select(self):
    """``batched_index_select`` gathers rows per batch entry for a 3-d index tensor."""
    # targets[b, i, :] is a vector filled with its own index i ...
    targets = torch.ones([2, 10, 3]).cumsum(1) - 1
    # ... and the second batch entry is doubled so the two entries differ.
    targets[1, :, :] *= 2
    indices = torch.tensor(
        numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), dtype=torch.long
    )

    selected = util.batched_index_select(targets, indices)

    assert list(selected.size()) == [2, 2, 2, 3]

    ones = numpy.ones([3])
    # (batch, row, column, expected fill value of the selected vector)
    expectations = [
        (0, 0, 0, 1),
        (0, 0, 1, 2),
        (0, 1, 0, 3),
        (0, 1, 1, 4),
        (1, 0, 0, 10),
        (1, 0, 1, 12),
        (1, 1, 0, 14),
        (1, 1, 1, 16),
    ]
    for batch, row, column, value in expectations:
        numpy.testing.assert_array_equal(
            selected[batch, row, column, :].data.numpy(), ones * value
        )
示例4: test_masked_indices_are_handled_correctly
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_indices_are_handled_correctly(self):
    """Spans switched off through ``span_indices_mask`` must come back as zeros."""
    sequence_tensor = torch.randn([2, 5, 7])
    # "x,y" concatenates the start and end point embeddings.
    extractor = EndpointSpanExtractor(7, "x,y")
    indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
    # Unmasked call first; its result is replaced by the masked call below.
    span_representations = extractor(sequence_tensor, indices)

    # Mask out every span of the second batch element.
    indices_mask = torch.tensor([[True, True], [False, False]])
    span_representations = extractor(
        sequence_tensor, indices, span_indices_mask=indices_mask
    )

    start_embeddings, end_embeddings = span_representations.split(7, -1)
    start_indices, end_indices = indices.split(1, -1)

    correct_start_embeddings = batched_index_select(
        sequence_tensor, start_indices.squeeze()
    ).data
    # The second batch element is fully masked, so zeros are expected there.
    correct_start_embeddings[1, :, :].fill_(0)
    correct_end_embeddings = batched_index_select(
        sequence_tensor, end_indices.squeeze()
    ).data
    correct_end_embeddings[1, :, :].fill_(0)

    comparisons = (
        (start_embeddings, correct_start_embeddings),
        (end_embeddings, correct_end_embeddings),
    )
    for actual, expected in comparisons:
        numpy.testing.assert_array_equal(actual.data.numpy(), expected.numpy())
示例5: test_masked_indices_are_handled_correctly_with_exclusive_indices
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_indices_are_handled_correctly_with_exclusive_indices(self):
    """Exclusive start indices that fall off the sequence use the start sentinel."""
    embedding_dim = 8
    sequence_tensor = torch.randn([2, 5, embedding_dim])
    # Start and end embeddings are concatenated; start indices are exclusive.
    extractor = EndpointSpanExtractor(
        embedding_dim, "x,y", use_exclusive_start_indices=True
    )
    indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [0, 1]]])
    sequence_mask = torch.tensor(
        [[True, True, True, True, True], [True, True, True, False, False]]
    )

    span_representations = extractor(sequence_tensor, indices, sequence_mask=sequence_mask)

    # Undo the concatenation so each half can be checked separately.
    start_embeddings, end_embeddings = span_representations.split(embedding_dim, -1)

    # The exclusive starts of the second batch element are really -1 (sentinel
    # positions). A valid placeholder index (1) is used so the lookup works;
    # those rows are overwritten with the sentinel right afterwards.
    correct_start_indices = torch.LongTensor([[0, 1], [1, 1]])
    correct_end_indices = torch.LongTensor([[3, 4], [2, 1]])

    correct_start_embeddings = batched_index_select(
        sequence_tensor.contiguous(), correct_start_indices
    )
    # Both spans of the second batch element start at sequence index 0, so
    # their exclusive start is the start sentinel.
    for span in (0, 1):
        correct_start_embeddings[1, span] = extractor._start_sentinel.data
    numpy.testing.assert_array_equal(
        start_embeddings.data.numpy(), correct_start_embeddings.data.numpy()
    )

    correct_end_embeddings = batched_index_select(
        sequence_tensor.contiguous(), correct_end_indices
    )
    numpy.testing.assert_array_equal(
        end_embeddings.data.numpy(), correct_end_embeddings.data.numpy()
    )
示例6: test_batched_index_select
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_batched_index_select(self):
    """``batched_index_select`` gathers correctly and rejects out-of-range indices."""
    # targets[b, i, :] is a vector filled with its own index i ...
    targets = torch.ones([2, 10, 3]).cumsum(1) - 1
    # ... and the second batch entry is doubled so the two entries differ.
    targets[1, :, :] *= 2
    indices = torch.tensor(
        numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), dtype=torch.long
    )

    selected = util.batched_index_select(targets, indices)

    assert list(selected.size()) == [2, 2, 2, 3]

    ones = numpy.ones([3])
    # (batch, row, column, expected fill value of the selected vector)
    expectations = [
        (0, 0, 0, 1),
        (0, 0, 1, 2),
        (0, 1, 0, 3),
        (0, 1, 1, 4),
        (1, 0, 0, 10),
        (1, 0, 1, 12),
        (1, 1, 0, 14),
        (1, 1, 1, 16),
    ]
    for batch, row, column, value in expectations:
        numpy.testing.assert_array_equal(
            selected[batch, row, column, :].data.numpy(), ones * value
        )

    # An index past the end of the sequence (11) and a negative index (-1)
    # must both be rejected.
    invalid_index_sets = (
        [[[1, 11], [3, 4]], [[5, 6], [7, 8]]],
        [[[1, -1], [3, 4]], [[5, 6], [7, 8]]],
    )
    for invalid in invalid_index_sets:
        indices = torch.tensor(numpy.array(invalid), dtype=torch.long)
        with pytest.raises(ConfigurationError):
            util.batched_index_select(targets, indices)
示例7: test_masked_topk_selects_top_scored_items_and_respects_masking_different_num_items
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_topk_selects_top_scored_items_and_respects_masking_different_num_items(self):
    """``masked_topk`` accepts a tensor holding a separate ``k`` for every row."""
    items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
    # (row, item, fill value): overwrite single items with known values.
    single_item_values = [
        (0, 0, 1.5),
        (0, 1, 2),
        (0, 3, 1),
        (2, 0, 1),
        (2, 1, 2),
        (2, 2, 1.5),
    ]
    for row, item, value in single_item_values:
        items[row, item, :] = value
    items[1, 1:3, :] = 1
    scores = items.sum(-1)

    mask = torch.ones([3, 4]).bool()
    mask[1, 3] = 0
    k = torch.tensor([3, 2, 1], dtype=torch.long)

    pruned_scores, pruned_mask, pruned_indices = util.masked_topk(scores, mask, k)

    # Row 0 asks for three items, row 1 for two (with item 3 masked) and row 2
    # for one. Slots beyond a row's own k repeat an index and are flagged 0 in
    # the expected mask.
    expected_indices = numpy.array([[0, 1, 3], [1, 2, 2], [1, 2, 2]])
    expected_mask = numpy.array([[1, 1, 1], [1, 1, 0], [1, 0, 0]])
    numpy.testing.assert_array_equal(pruned_indices.data.numpy(), expected_indices)
    numpy.testing.assert_array_equal(pruned_mask.data.numpy(), expected_mask)

    # The returned scores must equal `scores` gathered at the kept indices.
    gathered = util.batched_index_select(scores.unsqueeze(-1), pruned_indices)
    correct_scores = gathered.squeeze(-1)
    self.assert_array_equal_with_mask(correct_scores, pruned_scores, pruned_mask)
示例8: test_masked_topk_works_for_row_with_no_items_requested
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_topk_works_for_row_with_no_items_requested(self):
    """``masked_topk`` copes with a tensor ``k`` that requests zero items for a row."""
    items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
    # Pin selected items to the clamp ceiling so they get the highest possible score.
    items[0, :3, :] = 1
    for row in (1, 2):
        items[row, 2:, :] = 1
    scores = items.sum(-1)

    mask = torch.ones([3, 4]).bool()
    for column in (0, 3):
        mask[1, column] = 0
    # Here the number of items to keep is a tensor, not an int; the last row
    # requests nothing at all.
    k = torch.tensor([3, 2, 0], dtype=torch.long)

    pruned_scores, pruned_mask, pruned_indices = util.masked_topk(scores, mask, k)

    # Row 0 simply keeps its top three entries. Row 1 would keep entries 2 and
    # 3, but 0 and 3 are masked, so it keeps 1 and 2 and repeats the second
    # index. Row 2 ends up entirely masked and just repeats the largest index
    # that has a top-3 score.
    expected_indices = numpy.array([[0, 1, 2], [1, 2, 2], [3, 3, 3]])
    expected_mask = numpy.array([[1, 1, 1], [1, 1, 0], [0, 0, 0]])
    numpy.testing.assert_array_equal(pruned_indices.data.numpy(), expected_indices)
    numpy.testing.assert_array_equal(pruned_mask.data.numpy(), expected_mask)

    # The returned scores must equal `scores` gathered at the kept indices.
    gathered = util.batched_index_select(scores.unsqueeze(-1), pruned_indices)
    correct_scores = gathered.squeeze(-1)
    self.assert_array_equal_with_mask(correct_scores, pruned_scores, pruned_mask)
示例9: test_span_pruner_selects_top_scored_spans_and_respects_masking
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_span_pruner_selects_top_scored_spans_and_respects_masking(self):
    """``SpanPruner`` keeps the best-scoring unmasked spans of each batch element."""
    # Deliberately trivial scorer: the score of a span is the sum of its embedding.
    scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
    pruner = SpanPruner(scorer=scorer)

    spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
    # Pin selected spans to the clamp ceiling so they get the highest possible score.
    spans[0, :2, :] = 1
    for row in (1, 2):
        spans[row, 2:, :] = 1

    mask = torch.ones([3, 4])
    for column in (0, 3):
        mask[1, column] = 0

    pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2)

    # By score alone row 1 would keep spans 2 and 3, but spans 0 and 3 are
    # masked there, so spans 1 and 2 are expected instead.
    expected_indices = numpy.array([[0, 1], [1, 2], [2, 3]])
    numpy.testing.assert_array_equal(pruned_indices.data.numpy(), expected_indices)
    numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.ones([3, 2]))

    # The returned embeddings must equal `spans` gathered at the kept indices.
    correct_embeddings = batched_index_select(spans, pruned_indices)
    numpy.testing.assert_array_equal(
        correct_embeddings.data.numpy(), pruned_embeddings.data.numpy()
    )

    # The returned scores must equal the sums of those gathered embeddings.
    correct_scores = correct_embeddings.sum(-1).unsqueeze(-1)
    numpy.testing.assert_array_equal(
        correct_scores.data.numpy(), pruned_scores.data.numpy()
    )
示例10: test_span_scorer_works_for_completely_masked_rows
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_span_scorer_works_for_completely_masked_rows(self):
    """``SpanPruner`` must cope with a batch element whose spans are all masked."""
    # Deliberately trivial scorer: the score of a span is the sum of its embedding.
    scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
    pruner = SpanPruner(scorer=scorer)  # type: ignore

    spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
    # Pin selected spans to the clamp ceiling so they get the highest possible score.
    spans[0, :2, :] = 1
    for row in (1, 2):
        spans[row, 2:, :] = 1

    mask = torch.ones([3, 4])
    for column in (0, 3):
        mask[1, column] = 0
    # The last batch element has no usable span at all.
    mask[2, :] = 0

    pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2)

    # The indices of the fully masked last row cannot be checked, so only the
    # first two rows are compared; the last row is checked through its -inf
    # scores further down.
    numpy.testing.assert_array_equal(
        pruned_indices[:2].data.numpy(), numpy.array([[0, 1], [1, 2]])
    )
    numpy.testing.assert_array_equal(
        pruned_mask.data.numpy(), numpy.array([[1, 1], [1, 1], [0, 0]])
    )

    # The returned embeddings must equal `spans` gathered at the kept indices.
    correct_embeddings = batched_index_select(spans, pruned_indices)
    numpy.testing.assert_array_equal(
        correct_embeddings.data.numpy(), pruned_embeddings.data.numpy()
    )

    # The returned scores are the sums of those embeddings, except that the
    # masked row is expected to score -inf.
    correct_scores = correct_embeddings.sum(-1).unsqueeze(-1).data.numpy()
    correct_scores[2, :] = float(u"-inf")
    numpy.testing.assert_array_equal(correct_scores, pruned_scores.data.numpy())
示例11: test_correct_sequence_elements_are_embedded
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_correct_sequence_elements_are_embedded(self):
    """An ``"x,y"`` endpoint extractor returns start and end embeddings side by side."""
    embedding_dim = 7
    sequence_tensor = torch.randn([2, 5, embedding_dim])
    # The "x,y" combination concatenates the start and end point embeddings.
    extractor = EndpointSpanExtractor(embedding_dim, u"x,y")
    indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])

    span_representations = extractor(sequence_tensor, indices)

    assert list(span_representations.size()) == [2, 2, 14]
    assert extractor.get_output_dim() == 14
    assert extractor.get_input_dim() == 7

    # Split the concatenation back into its halves and compare each half
    # with a direct lookup of the corresponding endpoint indices.
    start_embeddings, end_embeddings = span_representations.split(embedding_dim, -1)
    start_indices, end_indices = indices.split(1, -1)
    halves = (
        (start_embeddings, start_indices),
        (end_embeddings, end_indices),
    )
    for actual, endpoint_indices in halves:
        expected = batched_index_select(sequence_tensor, endpoint_indices.squeeze())
        numpy.testing.assert_array_equal(actual.data.numpy(), expected.data.numpy())
示例12: test_masked_indices_are_handled_correctly
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_indices_are_handled_correctly(self):
    """Spans switched off through ``span_indices_mask`` must come back as zeros."""
    sequence_tensor = torch.randn([2, 5, 7])
    # "x,y" concatenates the start and end point embeddings.
    extractor = EndpointSpanExtractor(7, u"x,y")
    indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
    # Unmasked call first; its result is replaced by the masked call below.
    span_representations = extractor(sequence_tensor, indices)

    # Mask out every span of the second batch element.
    indices_mask = torch.LongTensor([[1, 1], [0, 0]])
    span_representations = extractor(
        sequence_tensor, indices, span_indices_mask=indices_mask
    )

    start_embeddings, end_embeddings = span_representations.split(7, -1)
    start_indices, end_indices = indices.split(1, -1)

    correct_start_embeddings = batched_index_select(
        sequence_tensor, start_indices.squeeze()
    ).data
    # The second batch element is fully masked, so zeros are expected there.
    correct_start_embeddings[1, :, :].fill_(0)
    correct_end_embeddings = batched_index_select(
        sequence_tensor, end_indices.squeeze()
    ).data
    correct_end_embeddings[1, :, :].fill_(0)

    comparisons = (
        (start_embeddings, correct_start_embeddings),
        (end_embeddings, correct_end_embeddings),
    )
    for actual, expected in comparisons:
        numpy.testing.assert_array_equal(actual.data.numpy(), expected.numpy())
示例13: forward
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def forward(self,
            sequence_tensor,
            span_indices,
            sequence_mask=None,
            span_indices_mask=None):
    """
    Build a representation for every span from its start and end embeddings.

    Parameters
    ----------
    sequence_tensor :
        Tensor the span endpoints are looked up in via
        ``util.batched_index_select``; presumably of shape
        (batch_size, sequence_length, embedding_dim) - TODO confirm.
    span_indices :
        Tensor whose last dimension holds ``[start, end]`` for each span;
        it is split into two along that dimension.
    sequence_mask :
        Accepted for interface compatibility; not read by this method.
    span_indices_mask :
        Optional mask of shape (batch_size, num_spans). When given, the
        endpoint indices are multiplied by it before lookup and masked spans
        are zeroed in the returned tensor.

    Returns
    -------
    The combined start/end embeddings (combined according to
    ``self._combination``), with span width embeddings appended on the last
    dimension when ``self._span_width_embedding`` is set.

    Raises
    ------
    ValueError
        If, with exclusive start indices enabled, an adjusted start index
        is still negative.
    """
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # Not strictly necessary, but the span indices may have been padded
        # with something other than 0 (such as -1, which would be an invalid
        # index), so padded entries are forced to 0 to be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    if not self._use_exclusive_start_indices:
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # The incoming span starts are inclusive; subtracting 1 makes them exclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        # Point sentinel positions at index 0 so the lookup below stays valid;
        # their embeddings are replaced by the sentinel afterwards.
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # Checked at runtime because a wrong index here is hard to debug.
        if (exclusive_span_starts < 0).any():
            # The message is formatted explicitly: previously the placeholder
            # was emitted literally because the string was never interpolated.
            raise ValueError("Adjusted span indices must lie inside the the sequence tensor, "
                             "but found: exclusive_span_starts: {exclusive_span_starts}."
                             .format(exclusive_span_starts=exclusive_span_starts))

        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        # Replace the embeddings of positions that fell outside the sequence
        # with the start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
            + float_start_sentinel_mask * self._start_sentinel

    combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate them to the rest of the representation.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        # Do not return here: the mask below must also cover the width
        # embeddings, otherwise masked spans would come back non-zero.
        combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()
    return combined_tensors
示例14: forward
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Sum the sequence embeddings covered by each span (span ends are inclusive).

    ``span_indices`` holds ``[start, end]`` on its last dimension.
    ``span_indices_mask`` is accepted but not read by this method.
    """
    # Both of shape (batch_size, num_spans, 1).
    starts, ends = span_indices.split(1, dim=-1)
    # Shape (batch_size, num_spans, 1). Off by one with respect to the real
    # width because the ends are inclusive.
    widths = ends - starts

    # The widest span in the batch fixes how many positions are gathered per
    # span; shorter spans get their surplus positions masked out below.
    max_batch_span_width = widths.max().item() + 1

    # Shape: (1, 1, max_batch_span_width)
    device = util.get_device_of(sequence_tensor)
    offsets = util.get_range_vector(max_batch_span_width, device).view(1, 1, -1)

    # Shape: (batch_size, num_spans, max_batch_span_width)
    # Broadcast comparison: an offset is valid for a span when it does not
    # exceed the span's width. `<=` is used because the ends are inclusive.
    inside_span = (offsets <= widths).float()

    # Walk backwards from each inclusive end.
    raw_span_indices = ends - offsets
    # Spans near the beginning of the sequence can produce negative positions
    # (end index < max_batch_span_width); those are masked out as well.
    inside_sequence = (raw_span_indices >= 0).float()
    span_mask = inside_span * inside_sequence

    # Clip negative positions to 0 so the lookup is valid; they carry no
    # weight thanks to the mask.
    gather_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_batch_span_width)
    flat_gather_indices = util.flatten_and_batch_shift_indices(
        gather_indices, sequence_tensor.size(1)
    )

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = util.batched_index_select(
        sequence_tensor, gather_indices, flat_gather_indices
    )

    masked_embeddings = span_embeddings * span_mask.unsqueeze(-1)
    return masked_embeddings.sum(dim=2)
示例15: forward
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Average the sequence embeddings covered by each span (span ends are inclusive).

    Parameters
    ----------
    sequence_tensor :
        Tensor the span positions are looked up in; its size along
        dimension 1 is used as the sequence length.
    span_indices :
        Tensor holding ``[start, end]`` on its last dimension.
    span_indices_mask :
        Accepted but not read by this method.

    Returns
    -------
    One averaged embedding per span. Spans without any valid position
    (for example a negative end index, or an end before the start) yield
    zeros rather than NaN.
    """
    # both of shape (batch_size, num_spans, 1)
    span_starts, span_ends = span_indices.split(1, dim=-1)

    # shape (batch_size, num_spans, 1)
    # These span widths are off by 1, because the span ends are `inclusive`.
    span_widths = span_ends - span_starts

    # The maximum span width determines how many positions are gathered per
    # span. Positions beyond the length of a shorter span are masked below.
    max_batch_span_width = span_widths.max().item() + 1

    # Shape: (1, 1, max_batch_span_width)
    max_span_range_indices = util.get_range_vector(
        max_batch_span_width, util.get_device_of(sequence_tensor)
    ).view(1, 1, -1)

    # Shape: (batch_size, num_spans, max_batch_span_width)
    # Broadcast comparison: for each span, positions of the range vector that
    # exceed the span's width are masked. `<=` is used (here and below)
    # because the span ends are inclusive.
    span_mask = (max_span_range_indices <= span_widths).float()
    raw_span_indices = span_ends - max_span_range_indices
    # Spans near the beginning of the sequence can produce indices below zero
    # (end index < max_batch_span_width); mask those out as well.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_batch_span_width)
    flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)
    text_embeddings = span_embeddings * span_mask.unsqueeze(-1)
    sum_text_embeddings = text_embeddings.sum(dim=2)

    # Number of valid positions per span.
    span_num = span_mask.unsqueeze(-1).sum(dim=2)
    # A span with no valid position has a count of 0, and 0 / 0 would put NaN
    # into the output. Counts are whole numbers, so clamping to at least 1
    # leaves every non-empty span unchanged and maps empty spans to zeros.
    mean_text_embeddings = sum_text_embeddings / span_num.clamp(min=1)
    return mean_text_embeddings
# sequence_tensor = torch.randn(2, 5, 5)
# span_indices = torch.LongTensor([[[0, 1]], [[1, 3]]])
# extractor = MeanSpanExtractor(5)
# print(extractor(sequence_tensor, span_indices))
# print("====")
# print((sequence_tensor[0][0] + sequence_tensor[0][1]) / 2)
# print((sequence_tensor[1][1] + sequence_tensor[1][2] + sequence_tensor[1][3])/3 )