当前位置: 首页>>代码示例>>Python>>正文


Python util.batched_index_select方法代码示例

本文整理汇总了Python中allennlp.nn.util.batched_index_select方法的典型用法代码示例。如果您正苦于以下问题:Python util.batched_index_select方法的具体用法?Python util.batched_index_select怎么用?Python util.batched_index_select使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在allennlp.nn.util的用法示例。


在下文中一共展示了util.batched_index_select方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_correct_sequence_elements_are_embedded

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_correct_sequence_elements_are_embedded(self):
        """Check that an "x,y" endpoint extractor concatenates start and end embeddings.

        The extractor output is split back into its two halves, and each half is
        compared with the rows ``batched_index_select`` picks from the input
        for the matching span endpoints.
        """
        embedding_dim = 7
        sequence_tensor = torch.randn([2, 5, embedding_dim])
        # The "x,y" combination joins the start and end embeddings.
        extractor = EndpointSpanExtractor(embedding_dim, "x,y")

        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
        span_representations = extractor(sequence_tensor, indices)

        assert list(span_representations.size()) == [2, 2, 2 * embedding_dim]
        assert extractor.get_output_dim() == 2 * embedding_dim
        assert extractor.get_input_dim() == embedding_dim

        # Undo the concatenation so each half can be checked on its own.
        start_embeddings, end_embeddings = span_representations.split(embedding_dim, -1)
        start_indices, end_indices = indices.split(1, -1)

        endpoint_checks = (
            (start_embeddings, start_indices),
            (end_embeddings, end_indices),
        )
        for actual, endpoint_indices in endpoint_checks:
            expected = batched_index_select(sequence_tensor, endpoint_indices.squeeze())
            numpy.testing.assert_array_equal(actual.data.numpy(), expected.data.numpy())
开发者ID:allenai,项目名称:allennlp,代码行数:27,代码来源:endpoint_span_extractor_test.py

示例2: test_masked_topk_selects_top_scored_items_and_respects_masking

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_topk_selects_top_scored_items_and_respects_masking(self):
        """Check ``util.masked_topk`` with a fixed ``k`` of 2 and two masked items."""
        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        # Pin the intended winners of every row to the clamp maximum.
        items[0, :2, :] = 1
        items[1:, 2:, :] = 1

        scores = items.sum(-1)

        mask = torch.ones([3, 4]).bool()
        mask[1, 0] = False
        mask[1, 3] = False

        top_scores, top_mask, top_indices = util.masked_topk(scores, mask, 2)

        # By score alone the second row would keep items 2 and 3, but items
        # 0 and 3 are masked there, so items 1 and 2 are expected instead.
        expected_indices = numpy.array([[0, 1], [1, 2], [2, 3]])
        numpy.testing.assert_array_equal(top_indices.data.numpy(), expected_indices)
        numpy.testing.assert_array_equal(top_mask.data.numpy(), numpy.ones([3, 2]))

        # The returned scores must match a lookup of the returned indices.
        gathered = util.batched_index_select(scores.unsqueeze(-1), top_indices)
        self.assert_array_equal_with_mask(gathered.squeeze(-1), top_scores, top_mask)
开发者ID:allenai,项目名称:allennlp,代码行数:26,代码来源:util_test.py

示例3: test_batched_index_select

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_batched_index_select(self):
        """Check ``util.batched_index_select`` on a (2, 10, 3) target with (2, 2, 2) indices."""
        # Every row of ``targets`` is a vector filled with its own position.
        targets = torch.ones([2, 10, 3]).cumsum(1) - 1
        # Double the second batch element so the two batches are distinguishable.
        targets[1, :, :] *= 2
        indices = torch.tensor(
            numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), dtype=torch.long
        )
        selected = util.batched_index_select(targets, indices)

        assert list(selected.size()) == [2, 2, 2, 3]

        # Expected fill value of each selected vector: the index itself for the
        # first batch element, twice the index for the second.
        expected_fill = numpy.array([[[1, 2], [3, 4]], [[10, 12], [14, 16]]])
        ones = numpy.ones([3])
        for position in numpy.ndindex(2, 2, 2):
            numpy.testing.assert_array_equal(
                selected[position].data.numpy(), ones * expected_fill[position]
            )
开发者ID:plasticityai,项目名称:magnitude,代码行数:25,代码来源:util_test.py

示例4: test_masked_indices_are_handled_correctly

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_indices_are_handled_correctly(self):
        """Check that spans disabled by ``span_indices_mask`` yield all-zero embeddings."""
        embedding_dim = 7
        sequence_tensor = torch.randn([2, 5, embedding_dim])
        # The "x,y" combination joins the start and end embeddings.
        extractor = EndpointSpanExtractor(embedding_dim, "x,y")

        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
        # Unmasked call; its result is replaced by the masked call below.
        span_representations = extractor(sequence_tensor, indices)

        # Mask out every span of the second batch element.
        indices_mask = torch.tensor([[True, True], [False, False]])
        span_representations = extractor(sequence_tensor, indices, span_indices_mask=indices_mask)

        start_embeddings, end_embeddings = span_representations.split(embedding_dim, -1)
        start_indices, end_indices = indices.split(1, -1)

        endpoint_checks = (
            (start_embeddings, start_indices),
            (end_embeddings, end_indices),
        )
        for actual, endpoint_indices in endpoint_checks:
            expected = batched_index_select(sequence_tensor, endpoint_indices.squeeze()).data
            # The second batch element is fully masked, so it should be all zeros.
            expected[1, :, :].fill_(0)
            numpy.testing.assert_array_equal(actual.data.numpy(), expected.numpy())
开发者ID:allenai,项目名称:allennlp,代码行数:30,代码来源:endpoint_span_extractor_test.py

示例5: test_masked_indices_are_handled_correctly_with_exclusive_indices

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_indices_are_handled_correctly_with_exclusive_indices(self):
        """Check exclusive-start extraction, including spans that start at index 0.

        With ``use_exclusive_start_indices=True`` the expected start embedding is
        taken one position before the span start; spans starting at 0 are
        expected to receive the extractor's ``_start_sentinel`` instead.
        """
        embedding_dim = 8
        sequence_tensor = torch.randn([2, 5, embedding_dim])
        extractor = EndpointSpanExtractor(embedding_dim, "x,y", use_exclusive_start_indices=True)
        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [0, 1]]])
        # The second sequence has only three unmasked positions.
        sequence_mask = torch.tensor(
            [[True, True, True, True, True], [True, True, True, False, False]]
        )

        span_representations = extractor(sequence_tensor, indices, sequence_mask=sequence_mask)

        # The output is the start and end embeddings concatenated; split them apart.
        start_embeddings, end_embeddings = span_representations.split(embedding_dim, -1)

        # Both spans of the second batch element start at 0, so their exclusive
        # start would be -1. Look them up with a valid placeholder index (1)
        # and overwrite those rows with the sentinel afterwards.
        placeholder_start_indices = torch.LongTensor([[0, 1], [1, 1]])
        expected_start = batched_index_select(
            sequence_tensor.contiguous(), placeholder_start_indices
        )
        for span in (0, 1):
            expected_start[1, span] = extractor._start_sentinel.data
        numpy.testing.assert_array_equal(
            start_embeddings.data.numpy(), expected_start.data.numpy()
        )

        # End indices are used as given.
        expected_end_indices = torch.LongTensor([[3, 4], [2, 1]])
        expected_end = batched_index_select(sequence_tensor.contiguous(), expected_end_indices)
        numpy.testing.assert_array_equal(
            end_embeddings.data.numpy(), expected_end.data.numpy()
        )
开发者ID:allenai,项目名称:allennlp,代码行数:43,代码来源:endpoint_span_extractor_test.py

示例6: test_batched_index_select

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_batched_index_select(self):
        """Check ``util.batched_index_select`` on valid indices and on out-of-range ones."""
        # Every row of ``targets`` is a vector filled with its own position.
        targets = torch.ones([2, 10, 3]).cumsum(1) - 1
        # Double the second batch element so the two batches are distinguishable.
        targets[1, :, :] *= 2
        indices = torch.tensor(
            numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), dtype=torch.long
        )
        selected = util.batched_index_select(targets, indices)

        assert list(selected.size()) == [2, 2, 2, 3]

        # Expected fill value of each selected vector: the index itself for the
        # first batch element, twice the index for the second.
        expected_fill = numpy.array([[[1, 2], [3, 4]], [[10, 12], [14, 16]]])
        ones = numpy.ones([3])
        for position in numpy.ndindex(2, 2, 2):
            numpy.testing.assert_array_equal(
                selected[position].data.numpy(), ones * expected_fill[position]
            )

        # An index past the end of the 10-long sequence axis, and a negative
        # index, must both be rejected.
        for bad_index in (11, -1):
            indices = torch.tensor(
                numpy.array([[[1, bad_index], [3, 4]], [[5, 6], [7, 8]]]), dtype=torch.long
            )
            with pytest.raises(ConfigurationError):
                util.batched_index_select(targets, indices)
开发者ID:allenai,项目名称:allennlp,代码行数:32,代码来源:util_test.py

示例7: test_masked_topk_selects_top_scored_items_and_respects_masking_different_num_items

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_topk_selects_top_scored_items_and_respects_masking_different_num_items(self):
        """Check ``util.masked_topk`` when ``k`` is a per-row tensor (3, 2 and 1 items)."""
        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        # Overwrite selected items with fixed values at or above the clamp maximum.
        items[0, 0, :] = 1.5
        items[0, 1, :] = 2
        items[0, 3, :] = 1
        items[1, 1:3, :] = 1
        items[2, 0, :] = 1
        items[2, 1, :] = 2
        items[2, 2, :] = 1.5

        scores = items.sum(-1)

        mask = torch.ones([3, 4]).bool()
        mask[1, 3] = False

        num_items_to_keep = torch.tensor([3, 2, 1], dtype=torch.long)

        top_scores, top_mask, top_indices = util.masked_topk(scores, mask, num_items_to_keep)

        # Rows that keep fewer than three items still return three slots; the
        # surplus slots carry a 0 in the returned mask.
        expected_indices = numpy.array([[0, 1, 3], [1, 2, 2], [1, 2, 2]])
        expected_mask = numpy.array([[1, 1, 1], [1, 1, 0], [1, 0, 0]])
        numpy.testing.assert_array_equal(top_indices.data.numpy(), expected_indices)
        numpy.testing.assert_array_equal(top_mask.data.numpy(), expected_mask)

        # The returned scores must match a lookup of the returned indices.
        gathered = util.batched_index_select(scores.unsqueeze(-1), top_indices)
        self.assert_array_equal_with_mask(gathered.squeeze(-1), top_scores, top_mask)
开发者ID:allenai,项目名称:allennlp,代码行数:33,代码来源:util_test.py

示例8: test_masked_topk_works_for_row_with_no_items_requested

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_topk_works_for_row_with_no_items_requested(self):
        """Check ``util.masked_topk`` with a tensor ``k`` that requests zero items for one row."""
        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        # Pin the intended winners of every row to the clamp maximum.
        items[0, :3, :] = 1
        items[1:, 2:, :] = 1

        scores = items.sum(-1)

        mask = torch.ones([3, 4]).bool()
        mask[1, 0] = False
        mask[1, 3] = False

        num_items_to_keep = torch.tensor([3, 2, 0], dtype=torch.long)

        top_scores, top_mask, top_indices = util.masked_topk(scores, mask, num_items_to_keep)

        # Row 0 keeps its three pinned items. Row 1 would keep items 2 and 3 by
        # score, but 0 and 3 are masked, so 1 and 2 are expected (with index 2
        # repeated in the unused slot). Row 2 requests no items: every slot is
        # masked out and the expected indices simply repeat 3.
        expected_indices = numpy.array([[0, 1, 2], [1, 2, 2], [3, 3, 3]])
        expected_mask = numpy.array([[1, 1, 1], [1, 1, 0], [0, 0, 0]])
        numpy.testing.assert_array_equal(top_indices.data.numpy(), expected_indices)
        numpy.testing.assert_array_equal(top_mask.data.numpy(), expected_mask)

        # The returned scores must match a lookup of the returned indices.
        gathered = util.batched_index_select(scores.unsqueeze(-1), top_indices)
        self.assert_array_equal_with_mask(gathered.squeeze(-1), top_scores, top_mask)
开发者ID:allenai,项目名称:allennlp,代码行数:34,代码来源:util_test.py

示例9: test_span_pruner_selects_top_scored_spans_and_respects_masking

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_span_pruner_selects_top_scored_spans_and_respects_masking(self):
        """Check that ``SpanPruner`` keeps the two best unmasked spans of every row."""

        def scorer(tensor):
            # Deliberately trivial scorer: sum over the embedding dimension.
            return tensor.sum(-1).unsqueeze(-1)

        pruner = SpanPruner(scorer=scorer)

        spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        # Pin the intended winners of every row to the clamp maximum.
        spans[0, :2, :] = 1
        spans[1:, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0
        kept_embeddings, kept_mask, kept_indices, kept_scores = pruner(spans, mask, 2)

        # By score alone the second row would keep spans 2 and 3, but spans
        # 0 and 3 are masked there, so spans 1 and 2 are expected instead.
        expected_indices = numpy.array([[0, 1], [1, 2], [2, 3]])
        numpy.testing.assert_array_equal(kept_indices.data.numpy(), expected_indices)
        numpy.testing.assert_array_equal(kept_mask.data.numpy(), numpy.ones([3, 2]))

        # The returned embeddings must match a lookup of the returned indices.
        expected_embeddings = batched_index_select(spans, kept_indices)
        numpy.testing.assert_array_equal(
            expected_embeddings.data.numpy(), kept_embeddings.data.numpy()
        )
        # The returned scores must be the sums of those embeddings.
        expected_scores = expected_embeddings.sum(-1).unsqueeze(-1)
        numpy.testing.assert_array_equal(
            expected_scores.data.numpy(), kept_scores.data.numpy()
        )
开发者ID:plasticityai,项目名称:magnitude,代码行数:31,代码来源:span_pruner_test.py

示例10: test_span_scorer_works_for_completely_masked_rows

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_span_scorer_works_for_completely_masked_rows(self):
        """Check ``SpanPruner`` output when one batch element has every span masked."""

        def scorer(tensor):
            # Deliberately trivial scorer: sum over the embedding dimension.
            return tensor.sum(-1).unsqueeze(-1)

        pruner = SpanPruner(scorer=scorer) # type: ignore

        spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        # Pin the intended winners of every row to the clamp maximum.
        spans[0, :2, :] = 1
        spans[1:, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0
        # The last batch element is masked out entirely.
        mask[2, :] = 0

        kept_embeddings, kept_mask, kept_indices, kept_scores = pruner(spans, mask, 2)

        # The indices of the fully masked last row are not checked; only its
        # mask and its -inf scores are verified below.
        numpy.testing.assert_array_equal(
            kept_indices[:2].data.numpy(), numpy.array([[0, 1], [1, 2]])
        )
        numpy.testing.assert_array_equal(
            kept_mask.data.numpy(), numpy.array([[1, 1], [1, 1], [0, 0]])
        )

        # The returned embeddings must match a lookup of the returned indices.
        expected_embeddings = batched_index_select(spans, kept_indices)
        numpy.testing.assert_array_equal(
            expected_embeddings.data.numpy(), kept_embeddings.data.numpy()
        )
        # The returned scores must be the sums of those embeddings, except for
        # the masked row, whose scores are expected to be -inf.
        expected_scores = expected_embeddings.sum(-1).unsqueeze(-1).data.numpy()
        expected_scores[2, :] = float(u"-inf")
        numpy.testing.assert_array_equal(expected_scores, kept_scores.data.numpy())
开发者ID:plasticityai,项目名称:magnitude,代码行数:36,代码来源:span_pruner_test.py

示例11: test_correct_sequence_elements_are_embedded

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_correct_sequence_elements_are_embedded(self):
        """Check that a u"x,y" endpoint extractor concatenates start and end embeddings.

        The extractor output is split back into its two halves, and each half is
        compared with the rows ``batched_index_select`` picks from the input
        for the matching span endpoints.
        """
        embedding_dim = 7
        sequence_tensor = torch.randn([2, 5, embedding_dim])
        # The "x,y" combination joins the start and end embeddings.
        extractor = EndpointSpanExtractor(embedding_dim, u"x,y")

        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
        span_representations = extractor(sequence_tensor, indices)

        assert list(span_representations.size()) == [2, 2, 2 * embedding_dim]
        assert extractor.get_output_dim() == 2 * embedding_dim
        assert extractor.get_input_dim() == embedding_dim

        # Undo the concatenation so each half can be checked on its own.
        start_embeddings, end_embeddings = span_representations.split(embedding_dim, -1)
        start_indices, end_indices = indices.split(1, -1)

        endpoint_checks = (
            (start_embeddings, start_indices),
            (end_embeddings, end_indices),
        )
        for actual, endpoint_indices in endpoint_checks:
            expected = batched_index_select(sequence_tensor, endpoint_indices.squeeze())
            numpy.testing.assert_array_equal(actual.data.numpy(), expected.data.numpy())
开发者ID:plasticityai,项目名称:magnitude,代码行数:28,代码来源:endpoint_span_extractor_test.py

示例12: test_masked_indices_are_handled_correctly

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def test_masked_indices_are_handled_correctly(self):
        """Check that spans disabled by an integer ``span_indices_mask`` yield zero embeddings."""
        embedding_dim = 7
        sequence_tensor = torch.randn([2, 5, embedding_dim])
        # The "x,y" combination joins the start and end embeddings.
        extractor = EndpointSpanExtractor(embedding_dim, u"x,y")

        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
        # Unmasked call; its result is replaced by the masked call below.
        span_representations = extractor(sequence_tensor, indices)

        # Mask out every span of the second batch element.
        indices_mask = torch.LongTensor([[1, 1], [0, 0]])
        span_representations = extractor(sequence_tensor, indices, span_indices_mask=indices_mask)

        start_embeddings, end_embeddings = span_representations.split(embedding_dim, -1)
        start_indices, end_indices = indices.split(1, -1)

        endpoint_checks = (
            (start_embeddings, start_indices),
            (end_embeddings, end_indices),
        )
        for actual, endpoint_indices in endpoint_checks:
            expected = batched_index_select(sequence_tensor, endpoint_indices.squeeze()).data
            # The second batch element is fully masked, so it should be all zeros.
            expected[1, :, :].fill_(0)
            numpy.testing.assert_array_equal(actual.data.numpy(), expected.numpy())
开发者ID:plasticityai,项目名称:magnitude,代码行数:29,代码来源:endpoint_span_extractor_test.py

示例13: forward

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def forward(self,
                sequence_tensor,
                span_indices,
                sequence_mask=None,
                span_indices_mask=None):
        """Build span representations from the embeddings at each span's endpoints.

        Parameters
        ----------
        sequence_tensor :
            Tensor the span endpoints index into via ``util.batched_index_select``
            (presumably ``(batch_size, sequence_length, embedding_dim)`` --
            confirm against callers).
        span_indices :
            Tensor whose last dimension holds inclusive ``(start, end)`` pairs.
        sequence_mask :
            Accepted for interface compatibility; not read by this method.
        span_indices_mask :
            Optional ``(batch_size, num_spans)`` mask. Masked spans have their
            indices zeroed before lookup and their representation zeroed in
            the returned tensor.

        Returns
        -------
        The start and end embeddings combined according to ``self._combination``,
        with span-width embeddings concatenated on the last dimension when
        ``self._span_width_embedding`` is set.

        Raises
        ------
        ValueError
            If exclusive start indices are in use and an adjusted start index
            is still negative after sentinel handling.
        """
        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            # Padding spans may hold something other than 0 (e.g. -1, which is
            # not a valid index), so force masked spans to index 0.
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask

        if not self._use_exclusive_start_indices:
            start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        else:
            # Span starts are inclusive, so the `exclusive` start is one
            # position earlier.
            # shape (batch_size, num_spans)
            exclusive_span_starts = span_starts - 1
            # shape (batch_size, num_spans, 1)
            start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
            # Spans starting at 0 would index -1; point them at 0 for the lookup
            # and substitute the sentinel embedding afterwards.
            exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

            # Check the indices at runtime: errors here are hard to debug
            # once they have propagated.
            if (exclusive_span_starts < 0).any():
                # The original message was missing its interpolation, so the
                # offending indices were never shown.
                raise ValueError("Adjusted span indices must lie inside the sequence tensor, "
                                 "but found: exclusive_span_starts: {}.".format(exclusive_span_starts))

            start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

            # Replace the embeddings of spans that started at 0 with the
            # start sentinel.
            float_start_sentinel_mask = start_sentinel_mask.float()
            start_embeddings = start_embeddings * (1 - float_start_sentinel_mask)\
                                        + float_start_sentinel_mask * self._start_sentinel

        combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(span_ends - span_starts,
                                                 num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            # Do not return here: the mask below must also cover the
            # width-embedding part, otherwise masked spans come back non-zero.
            combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return combined_tensors * span_indices_mask.unsqueeze(-1).float()
        return combined_tensors
开发者ID:plasticityai,项目名称:magnitude,代码行数:61,代码来源:endpoint_span_extractor.py

示例14: forward

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        """Represent every span as the sum of the embeddings it covers.

        ``span_indices`` holds inclusive ``(start, end)`` pairs in its last
        dimension. ``span_indices_mask`` is accepted but never read by this
        method.

        Returns the per-span sum taken over the span-width axis (``dim=2``) of
        the gathered embeddings.
        """
        # Each of shape (batch_size, num_spans, 1).
        starts, ends = span_indices.split(1, dim=-1)

        # Shape (batch_size, num_spans, 1). One less than the number of covered
        # positions, because the end index is inclusive.
        widths = ends - starts

        # Every span is gathered with the widest span's length; positions
        # beyond a shorter span's own width are masked out below.
        widest = widths.max().item() + 1

        # Shape: (1, 1, widest)
        device = util.get_device_of(sequence_tensor)
        offsets = util.get_range_vector(widest, device).view(1, 1, -1)

        # Shape: (batch_size, num_spans, widest)
        # Broadcast comparison: keep the offsets that fall inside each span.
        # `<=` rather than `<` because `widths` is already one short of the
        # true length.
        inside_span = (offsets <= widths).float()
        # Walk backwards from each span end.
        raw_positions = ends - offsets
        # Spans near the start of the sequence can run below position 0 when
        # their end is smaller than the widest width; mask those positions too.
        non_negative = (raw_positions >= 0).float()
        span_mask = inside_span * non_negative
        # Clamp negative positions to 0 so the lookup stays in range; they are
        # already masked.
        safe_positions = torch.nn.functional.relu(raw_positions.float()).long()

        # Shape: (batch_size * num_spans * widest)
        flat_positions = util.flatten_and_batch_shift_indices(safe_positions, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, widest, embedding_dim)
        gathered = util.batched_index_select(sequence_tensor, safe_positions, flat_positions)
        masked = gathered * span_mask.unsqueeze(-1)
        return masked.sum(dim=2)
开发者ID:changzhisun,项目名称:AntNRE,代码行数:47,代码来源:sum_span_extractor.py

示例15: forward

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import batched_index_select [as 别名]
def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        """Represent every span as the mean of the embeddings it covers.

        ``span_indices`` holds inclusive ``(start, end)`` pairs in its last
        dimension. ``span_indices_mask`` is accepted but never read by this
        method.

        Returns the per-span mean taken over the span-width axis (``dim=2``) of
        the gathered embeddings. A span with no valid position (for example a
        padding span with a negative end, or an end before its start) yields a
        zero vector rather than NaN.
        """
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # Every span is gathered with the widest span's length; positions
        # beyond a shorter span's own width are masked out below.
        max_batch_span_width = span_widths.max().item() + 1

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # Broadcast comparison: keep the offsets that fall inside each span.
        # `<=` rather than `<` because `span_widths` is already one short of
        # the true length (span ends are inclusive).
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # Spans near the start of the sequence can run below position 0 when
        # their end is smaller than max_batch_span_width; mask those too.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        # Clamp negative positions to 0 so the lookup stays in range; they are
        # already masked.
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)
        text_embeddings = span_embeddings * span_mask.unsqueeze(-1)
        sum_text_embeddings = text_embeddings.sum(dim=2)

        # Number of valid positions per span. A span with none would divide
        # 0 by 0 and give NaN, so the divisor is clamped to at least 1; any
        # span with a valid position already has a count of 1 or more, so its
        # mean is unchanged.
        span_num = span_mask.unsqueeze(-1).sum(dim=2).clamp(min=1)
        mean_text_embeddings = sum_text_embeddings / span_num
        return mean_text_embeddings


#  sequence_tensor = torch.randn(2, 5, 5)
#  span_indices = torch.LongTensor([[[0, 1]], [[1, 3]]])
#  extractor = MeanSpanExtractor(5)
#  print(extractor(sequence_tensor, span_indices))
#  print("====")
#  print((sequence_tensor[0][0] + sequence_tensor[0][1]) / 2)

#  print((sequence_tensor[1][1] + sequence_tensor[1][2] + sequence_tensor[1][3])/3 ) 
开发者ID:changzhisun,项目名称:AntNRE,代码行数:60,代码来源:mean_span_extractor.py


注:本文中的allennlp.nn.util.batched_index_select方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。