This article collects typical usage examples of the Python method torch.autograd.Variable.split. If you have been wondering what Variable.split does, how to call it, or how it is used in practice, the curated code examples below may help. You can also explore further usage examples of its containing class, torch.autograd.Variable.
The section below shows 2 code examples of Variable.split, sorted by popularity by default. You can upvote the examples you like or find useful; your ratings help the system recommend better Python code examples.
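Before the examples, a minimal sketch of the method itself may be useful. This is a standalone illustration, not taken from the examples below, and it assumes a Variable-era PyTorch such as 0.3: split(split_size, dim=0) returns a tuple of chunks with at most split_size elements along dim, the last chunk holding the remainder.

import torch
from torch.autograd import Variable

x = Variable(torch.randn(10, 2, 8))      # (seq_len x batch x hid_dim)
chunks = x.split(4)                      # along dim 0: chunk sizes 4, 4 and 2
halves = x.split(4, -1)                  # along the last dim: two (10 x 2 x 4) pieces
print([list(c.size()) for c in chunks])  # [[4, 2, 8], [4, 2, 8], [2, 2, 8]]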
Example 1: run_batch
# Required import: from torch.autograd import Variable [as alias]
# Or: from torch.autograd.Variable import split [as alias]
# (This example also relies on the source project's utility module u;
# u.map_index remaps token indices in a tensor, as the comment below describes.)
def run_batch(self, batch_data, dataset='train', split=52, **kwargs):
    valid, loss = dataset != 'train', self.init_loss()
    pad, eos = self.model.src_dict.get_pad(), self.model.src_dict.get_eos()
    source, targets = batch_data
    # remove <eos> from the decoder targets, substituting it with <pad>
    decode_targets = Variable(u.map_index(targets[:-1].data, eos, pad))
    # remove <bos> from the loss targets
    loss_targets = targets[1:]
    # compute model output
    outs = self.model(source[1:], decode_targets)
    # detach outs from the computational graph
    det_outs = Variable(outs.data, requires_grad=not valid, volatile=valid)
    for out, trg in zip(det_outs.split(split), loss_targets.split(split)):
        # (seq_len x batch x hid_dim) -> (seq_len * batch x hid_dim)
        out = out.view(-1, out.size(2))
        pred = self.model.project(out)
        loss = self.update_loss(loss, self.criterion(pred, trg.view(-1)))
    if not valid:
        batch = outs.size(1)
        for l in loss:
            l.div(batch).backward()
        grad = None if det_outs.grad is None else det_outs.grad.data
        outs.backward(grad)
        self.optimizer_step()
    return tuple(l.data[0] for l in loss)
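In this example, split is a memory-saving device: the projection and loss run chunk by chunk over a detached copy of the decoder output, gradients accumulate in det_outs.grad, and a single outs.backward(grad) call then propagates them through the rest of the graph. Below is a stripped-down sketch of that pattern; all names are illustrative, not from the project above.

import torch
from torch.autograd import Variable

param = Variable(torch.randn(6, 2, 3), requires_grad=True)
outs = param * 2                                    # stands in for the model forward pass
det_outs = Variable(outs.data, requires_grad=True)  # detached copy of the output
for chunk in det_outs.split(2):                     # per-chunk losses; grads accumulate
    chunk.sum().backward()
outs.backward(det_outs.grad.data)                   # push the accumulated grad upstream
print(param.grad.size())                            # torch.Size([6, 2, 3])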
Example 2: test_correct_sequence_elements_are_embedded
# Required import: from torch.autograd import Variable [as alias]
# Or: from torch.autograd.Variable import split [as alias]
# (The test also uses numpy, torch, and, from AllenNLP, BidirectionalEndpointSpanExtractor
# (allennlp.modules.span_extractors) and batched_index_select (allennlp.nn.util).)
def test_correct_sequence_elements_are_embedded(self):
    sequence_tensor = Variable(torch.randn([2, 5, 8]))
    # Concatenate start and end points together to form our representation
    # for both the forward and backward directions.
    extractor = BidirectionalEndpointSpanExtractor(input_dim=8,
                                                   forward_combination="x,y",
                                                   backward_combination="x,y")
    indices = Variable(torch.LongTensor([[[1, 3],
                                          [2, 4]],
                                         [[0, 2],
                                          [3, 4]]]))
    span_representations = extractor(sequence_tensor, indices)
    assert list(span_representations.size()) == [2, 2, 16]
    assert extractor.get_output_dim() == 16
    assert extractor.get_input_dim() == 8
    # We just concatenated the start and end embeddings together, so
    # we can check they match the original indices if we split them apart.
    (forward_start_embeddings, forward_end_embeddings,
     backward_start_embeddings, backward_end_embeddings) = span_representations.split(4, -1)
    forward_sequence_tensor, backward_sequence_tensor = sequence_tensor.split(4, -1)
    # Forward direction => subtract 1 from the start indices to make them exclusive.
    correct_forward_start_indices = Variable(torch.LongTensor([[0, 1],
                                                               [-1, 2]]))
    # This index should be -1, so it will be replaced with a sentinel. Here,
    # we'll set it to a value other than -1 so we can index select the indices
    # and replace it later.
    correct_forward_start_indices[1, 0] = 1
    # Forward direction => end indices are the same.
    correct_forward_end_indices = Variable(torch.LongTensor([[3, 4], [2, 4]]))
    # Backward direction => start indices are exclusive, so add 1 to the end indices.
    correct_backward_start_indices = Variable(torch.LongTensor([[4, 5], [3, 5]]))
    # These exclusive end indices are outside the tensor, so they will be replaced
    # with the end sentinel. Here we replace them with ones so we can index select
    # using these indices without torch complaining.
    correct_backward_start_indices[0, 1] = 1
    correct_backward_start_indices[1, 1] = 1
    # Backward direction => end indices are inclusive and equal to the forward start indices.
    correct_backward_end_indices = Variable(torch.LongTensor([[1, 2], [0, 3]]))
    correct_forward_start_embeddings = batched_index_select(forward_sequence_tensor.contiguous(),
                                                            correct_forward_start_indices)
    # This element had a sequence_tensor index of 0, so its exclusive index is the start sentinel.
    correct_forward_start_embeddings[1, 0] = extractor._start_sentinel.data
    numpy.testing.assert_array_equal(forward_start_embeddings.data.numpy(),
                                     correct_forward_start_embeddings.data.numpy())
    correct_forward_end_embeddings = batched_index_select(forward_sequence_tensor.contiguous(),
                                                          correct_forward_end_indices)
    numpy.testing.assert_array_equal(forward_end_embeddings.data.numpy(),
                                     correct_forward_end_embeddings.data.numpy())
    correct_backward_end_embeddings = batched_index_select(backward_sequence_tensor.contiguous(),
                                                           correct_backward_end_indices)
    numpy.testing.assert_array_equal(backward_end_embeddings.data.numpy(),
                                     correct_backward_end_embeddings.data.numpy())
    correct_backward_start_embeddings = batched_index_select(backward_sequence_tensor.contiguous(),
                                                             correct_backward_start_indices)
    # This element had a sequence_tensor index == sequence_tensor.size(1),
    # so its exclusive index is the end sentinel.
    correct_backward_start_embeddings[0, 1] = extractor._end_sentinel.data
    correct_backward_start_embeddings[1, 1] = extractor._end_sentinel.data
    numpy.testing.assert_array_equal(backward_start_embeddings.data.numpy(),
                                     correct_backward_start_embeddings.data.numpy())
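The two split(4, -1) calls are what make this check possible: the extractor's 16-dimensional span representation is the concatenation of four 4-dimensional endpoint embeddings, so splitting along the last dimension recovers them in order. A tiny self-contained illustration of that behavior, with shapes assumed from the test above:

import torch
from torch.autograd import Variable

reps = Variable(torch.randn(2, 2, 16))
fwd_start, fwd_end, bwd_start, bwd_end = reps.split(4, -1)
assert all(list(t.size()) == [2, 2, 4] for t in (fwd_start, fwd_end, bwd_start, bwd_end))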