当前位置: 首页>>代码示例>>Python>>正文


Python Variable.split方法代码示例

本文整理汇总了 Python 中 torch.autograd.Variable.split 方法的典型用法代码示例。如果您想了解 Variable.split 方法的具体用法、调用方式或使用示例，这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在类 torch.autograd.Variable 的其他用法示例。注意：自 PyTorch 0.4 起，Variable 已与 Tensor 合并，split 可直接在 Tensor 上调用。


下文共展示了 Variable.split 方法的 2 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更优质的 Python 代码示例。

示例1: run_batch

# 需要导入模块: from torch.autograd import Variable [as 别名]
# 或者: from torch.autograd.Variable import split [as 别名]
 def run_batch(self, batch_data, dataset='train', split=52, **kwargs):
     """Run one batch through the model, computing the loss in chunks.

     The decoder output is detached from the autograd graph, and the
     projection + criterion are applied to ``split``-sized slices taken
     along dim 0. In training mode, the batch-normalised loss terms are
     backpropagated into the detached copy. Its gradient is then pushed
     back through the model with a single ``outs.backward`` call, and an
     optimizer step follows.

     Args:
         batch_data: ``(source, targets)`` pair.
         dataset: ``'train'`` enables backprop and an optimizer step; any
             other value is treated as evaluation (no gradients).
         split: slice length along dim 0 used when computing the loss.
         **kwargs: accepted but unused here.

     Returns:
         tuple holding ``.data[0]`` of every loss term produced by
         ``self.update_loss``.
     """
     is_eval = dataset != 'train'
     loss = self.init_loss()
     vocab = self.model.src_dict
     pad, eos = vocab.get_pad(), vocab.get_eos()
     source, targets = batch_data

     # Decoder input: all target steps but the last, with <eos> mapped to <pad>
     # (per u.map_index(data, eos, pad)).
     decoder_input = Variable(u.map_index(targets[:-1].data, eos, pad))
     # Loss targets: all steps after the first (presumably <bos>).
     gold = targets[1:]

     # NOTE(review): the source also drops its first step - confirm intended.
     outs = self.model(source[1:], decoder_input)
     # Detach the model output so the chunked loss builds a graph rooted at
     # `detached`; its gradient is forwarded to `outs` further down.
     detached = Variable(outs.data, requires_grad=not is_eval, volatile=is_eval)

     chunk_pairs = zip(detached.split(split), gold.split(split))
     for out_chunk, gold_chunk in chunk_pairs:
         # (seq_len x batch x hid_dim) -> (seq_len * batch x hid_dim)
         flat = out_chunk.view(-1, out_chunk.size(2))
         scores = self.model.project(flat)
         chunk_loss = self.criterion(scores, gold_chunk.view(-1))
         loss = self.update_loss(loss, chunk_loss)

     if not is_eval:
         batch_size = outs.size(1)
         for term in loss:
             term.div(batch_size).backward()
         # Forward whatever gradient reached the detached copy (None if absent).
         upstream = detached.grad.data if detached.grad is not None else None
         outs.backward(upstream)
         self.optimizer_step()

     return tuple(term.data[0] for term in loss)
开发者ID:mikekestemont,项目名称:seqmod,代码行数:27,代码来源:trainer.py

示例2: test_correct_sequence_elements_are_embedded

# 需要导入模块: from torch.autograd import Variable [as 别名]
# 或者: from torch.autograd.Variable import split [as 别名]
    def test_correct_sequence_elements_are_embedded(self):
        """Every piece of the span representation matches a manual index-select.

        With ``"x,y"`` combinations the extractor concatenates start and end
        embeddings for both directions. The 16-dim output can therefore be
        split into four 4-dim pieces and compared against embeddings gathered
        by hand from the forward and backward halves of the input. Exclusive
        indices that fall outside the sequence are expected to be replaced by
        the extractor's start/end sentinels.
        """
        def check(actual, expected):
            # Compare the underlying arrays of two Variables element-wise.
            numpy.testing.assert_array_equal(actual.data.numpy(), expected.data.numpy())

        sequence_tensor = Variable(torch.randn([2, 5, 8]))
        # "x,y" concatenates the start and end points for each direction.
        extractor = BidirectionalEndpointSpanExtractor(input_dim=8,
                                                       forward_combination="x,y",
                                                       backward_combination="x,y")
        indices = Variable(torch.LongTensor([[[1, 3], [2, 4]],
                                             [[0, 2], [3, 4]]]))

        span_representations = extractor(sequence_tensor, indices)

        assert list(span_representations.size()) == [2, 2, 16]
        assert extractor.get_output_dim() == 16
        assert extractor.get_input_dim() == 8

        # Undo the concatenation so each endpoint embedding can be checked alone.
        (forward_start_embeddings, forward_end_embeddings,
         backward_start_embeddings, backward_end_embeddings) = span_representations.split(4, -1)
        forward_half, backward_half = sequence_tensor.split(4, -1)

        # Forward start indices are exclusive (start - 1). Span [1, 0] starts at
        # position 0, so its exclusive index would be -1 and the start sentinel
        # is used instead; 1 is only an in-range placeholder, overwritten below.
        correct_forward_start_indices = Variable(torch.LongTensor([[0, 1], [1, 2]]))
        # Forward end indices are inclusive, i.e. unchanged.
        correct_forward_end_indices = Variable(torch.LongTensor([[3, 4], [2, 4]]))
        # Backward start indices are exclusive (end + 1). For spans [0, 1] and
        # [1, 1] that equals the sequence length (5), which maps to the end
        # sentinel; 1 is only an in-range placeholder, overwritten below.
        correct_backward_start_indices = Variable(torch.LongTensor([[4, 1], [3, 1]]))
        # Backward end indices are inclusive and equal the forward start indices.
        correct_backward_end_indices = Variable(torch.LongTensor([[1, 2], [0, 3]]))

        expected = batched_index_select(forward_half.contiguous(),
                                        correct_forward_start_indices)
        expected[1, 0] = extractor._start_sentinel.data
        check(forward_start_embeddings, expected)

        expected = batched_index_select(forward_half.contiguous(),
                                        correct_forward_end_indices)
        check(forward_end_embeddings, expected)

        expected = batched_index_select(backward_half.contiguous(),
                                        correct_backward_end_indices)
        check(backward_end_embeddings, expected)

        expected = batched_index_select(backward_half.contiguous(),
                                        correct_backward_start_indices)
        expected[0, 1] = extractor._end_sentinel.data
        expected[1, 1] = extractor._end_sentinel.data
        check(backward_start_embeddings, expected)
开发者ID:Jordan-Sauchuk,项目名称:allennlp,代码行数:73,代码来源:bidirectional_endpoint_span_extractor_test.py


注:本文中的torch.autograd.Variable.split方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。