本文整理匯總了Python中torch.LongTensor.split方法的典型用法代碼示例。如果您正苦於以下問題：Python LongTensor.split方法的具體用法？Python LongTensor.split怎麼用？Python LongTensor.split使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類torch.LongTensor的用法示例。
在下文中一共展示了LongTensor.split方法的3個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: forward
# 需要導入模塊: from torch import LongTensor [as 別名]
# 或者: from torch.LongTensor import split [as 別名]
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """Embed each span from the representations at its two endpoints.

    The start and end token representations are gathered from
    ``sequence_tensor`` and merged with
    ``util.combine_tensors(self._combination, ...)``. When
    ``self._use_exclusive_start_indices`` is set, the token *before* the span
    start is used instead of the start itself, and spans starting at position 0
    (whose exclusive start would be -1) get ``self._start_sentinel``. If
    ``self._span_width_embedding`` is set, an embedding of the span width is
    concatenated to each span representation.

    Parameters
    ----------
    sequence_tensor : torch.FloatTensor
        Token representations; presumably (batch_size, sequence_length,
        embedding_dim) — the shape is not checked here.
    span_indices : torch.LongTensor
        Inclusive ``[start, end]`` index pairs in the last dimension.
    sequence_mask : torch.LongTensor, optional
        Unused by this method.
    span_indices_mask : torch.LongTensor, optional
        Shape (batch_size, num_spans). When given, it is multiplied into the
        span indices before lookup and into the final span representations,
        so masked-out spans come back as zero vectors.

    Returns
    -------
    torch.FloatTensor
        One representation per span, with the span dimension second to last.

    Raises
    ------
    ValueError
        If an adjusted exclusive start index is negative.
    """
    # Shape (batch_size, num_spans) each.
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]
    if span_indices_mask is not None:
        # Padded spans may carry values other than 0 (e.g. -1, an invalid
        # index); zeroing them keeps the lookups below in range.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    if not self._use_exclusive_start_indices:
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # Exclusive starts: step back one token, because the span indices are
        # inclusive. Shape (batch_size, num_spans).
        exclusive_span_starts = span_starts - 1
        # Shape (batch_size, num_spans, 1); 1 where the span starts at position 0.
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        # Turn those -1 indices into 0 so the lookup is valid; their embeddings
        # are overwritten with the sentinel below.
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))
        # Checked at runtime because an out-of-range index here is hard to debug later.
        if (exclusive_span_starts < 0).any():
            raise ValueError(f"Adjusted span indices must lie inside the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}.")
        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
        # Replace start embeddings whose exclusive start fell before the
        # sequence with the learned start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = (start_embeddings * (1 - float_start_sentinel_mask)
                            + float_start_sentinel_mask * self._start_sentinel)

    span_embeddings = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])

    if self._span_width_embedding is not None:
        # end - start is one less than the token count, since ends are inclusive.
        span_widths = span_ends - span_starts
        if self._bucket_widths:
            span_widths = util.bucket_values(span_widths,
                                             num_total_buckets=self._num_width_embeddings)
        span_width_embeddings = self._span_width_embedding(span_widths)
        span_embeddings = torch.cat([span_embeddings, span_width_embeddings], -1)

    if span_indices_mask is not None:
        # Applied after the optional width embedding so that padded spans are
        # zero vectors in every configuration.
        return span_embeddings * span_indices_mask.unsqueeze(-1).float()
    return span_embeddings
示例2: forward
# 需要導入模塊: from torch import LongTensor [as 別名]
# 或者: from torch.LongTensor import split [as 別名]
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """Embed spans from the endpoints of a bidirectional sequence encoding.

    ``sequence_tensor`` is split along its last dimension into a forward half
    and a backward half (``int(self._input_dim / 2)`` features each). The
    forward half is read at the exclusive span start and the inclusive span
    end; the backward half is read at the exclusive span end and the inclusive
    span start. Each direction's endpoint pair is merged with
    ``util.combine_tensors`` (``self._forward_combination`` /
    ``self._backward_combination``) and the two results are concatenated.
    When ``self._use_sentinels`` is set, endpoints that fall outside the
    sequence are replaced with ``self._start_sentinel`` / ``self._end_sentinel``.

    NOTE: the end of this function (including its return statement) is
    omitted from this listing, so the final output is not shown here.

    Parameters
    ----------
    sequence_tensor : torch.FloatTensor
        Shape (batch_size, sequence_length, embedding_size); the first half of
        the last dimension is treated as the forward direction and the second
        half as the backward direction.
    span_indices : torch.LongTensor
        Inclusive ``[start, end]`` index pairs in the last dimension.
    sequence_mask : torch.LongTensor, optional
        Used to compute per-sequence lengths; when absent, every sequence is
        taken to have length ``sequence_tensor.size(1)``.
    span_indices_mask : torch.LongTensor, optional
        Shape (batch_size, num_spans); multiplied into the span indices, so
        masked-out spans index position 0.

    Raises
    ------
    ValueError
        If the adjusted span indices fall outside the sequence lengths.
    """
    # Both of shape (batch_size, sequence_length, embedding_size / 2)
    forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1)
    forward_sequence = forward_sequence.contiguous()
    backward_sequence = backward_sequence.contiguous()
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]
    if span_indices_mask is not None:
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask
    # We want `exclusive` span starts, so we remove 1 from the forward span starts
    # as the AllenNLP ``SpanField`` is inclusive.
    # shape (batch_size, num_spans)
    exclusive_span_starts = span_starts - 1
    # shape (batch_size, num_spans, 1); 1 where the span starts at position 0.
    start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
    # We want `exclusive` span ends for the backward direction
    # (so that the `start` of the span in that direction is exclusive), so
    # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
    exclusive_span_ends = span_ends + 1
    if sequence_mask is not None:
        # shape (batch_size)
        sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
    else:
        # shape (batch_size), filled with the sequence length size of the sequence_tensor.
        # NOTE(review): `util.ones_like` is not defined anywhere visible here; if `util`
        # has no such function this branch raises AttributeError. `torch.ones_like`
        # looks intended — confirm against the `util` module.
        sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)
    # shape (batch_size, num_spans, 1); 1 where the exclusive end equals the sequence length.
    end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)
    # As we added 1 to the span_ends to make them exclusive, which might have caused indices
    # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
    # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
    # The same argument follows for the exclusive span start indices.
    # Without sentinels, these erased endpoints simply read position 0.
    exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
    exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))
    # We'll check the indices here at runtime, because it's difficult to debug
    # if this goes wrong and it's tricky to get right.
    if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
        raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                         f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                         f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                         f"{sequence_lengths}.")
    # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
    forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
    # Forward Direction: end indices are inclusive, so we can just use span_ends.
    # Shape (batch_size, num_spans, input_size / 2)
    forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)
    # Backward Direction: The backward start embeddings use the `forward` end
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
    # Backward Direction: The backward end embeddings use the `forward` start
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)
    if self._use_sentinels:
        # If we're using sentinels, we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with either the start sentinel,
        # or the end sentinel.
        float_end_sentinel_mask = end_sentinel_mask.float()
        float_start_sentinel_mask = start_sentinel_mask.float()
        forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
            + float_start_sentinel_mask * self._start_sentinel
        backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
            + float_end_sentinel_mask * self._end_sentinel
    # Now we combine the forward and backward spans in the manner specified by the
    # respective combinations and concatenate these representations.
    # Shape (batch_size, num_spans, forward_combination_dim)
    forward_spans = util.combine_tensors(self._forward_combination,
                                         [forward_start_embeddings, forward_end_embeddings])
    # Shape (batch_size, num_spans, backward_combination_dim)
    backward_spans = util.combine_tensors(self._backward_combination,
                                          [backward_start_embeddings, backward_end_embeddings])
    # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
    span_embeddings = torch.cat([forward_spans, backward_spans], -1)
    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        # (end - start is one less than the token count, since ends are inclusive.)
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts
        span_width_embeddings = self._span_width_embedding(span_widths)
    # ......... the remainder of this function is omitted from the listing .........
示例3: forward
# 需要導入模塊: from torch import LongTensor [as 別名]
# 或者: from torch.LongTensor import split [as 別名]
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """Embed each span as an attention-weighted sum of the tokens it covers.

    Per-token attention logits come from ``self._global_attention`` applied
    to ``sequence_tensor``. For every span they are normalised with a masked
    softmax over the span's own tokens and used to pool the token embeddings.

    Parameters
    ----------
    sequence_tensor : torch.FloatTensor
        Token representations, shape (batch_size, sequence_length, embedding_dim).
    span_indices : torch.LongTensor
        Inclusive ``[start, end]`` pairs, shape (batch_size, num_spans, 2).
    sequence_mask : torch.LongTensor, optional
        Unused by this method.
    span_indices_mask : torch.LongTensor, optional
        Shape (batch_size, num_spans); when given, it is multiplied into the
        output so masked-out spans come back as zero vectors.

    Returns
    -------
    torch.FloatTensor
        Shape (batch_size, num_spans, embedding_dim). With zero spans, an
        empty tensor of shape (batch_size, 0, embedding_dim) is returned.
    """
    # Both of shape (batch_size, num_spans, 1).
    span_starts, span_ends = span_indices.split(1, dim=-1)
    # Shape (batch_size, num_spans, 1). These widths are one less than the
    # number of tokens in a span, because span ends are inclusive.
    span_widths = span_ends - span_starts

    # With no spans there is nothing to attend over, and ``max()`` on an empty
    # tensor raises, so return an empty result of the expected shape.
    if span_widths.numel() == 0:
        return sequence_tensor.new_zeros(span_indices.size(0),
                                         span_indices.size(1),
                                         sequence_tensor.size(-1))

    # Token positions gathered per span, sized to the widest span in the batch;
    # positions beyond a narrower span's width are masked out below.
    max_batch_span_width = span_widths.max().item() + 1

    # Shape (batch_size, sequence_length, 1)
    global_attention_logits = self._global_attention(sequence_tensor)

    # Shape (1, 1, max_batch_span_width)
    max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                   util.get_device_of(sequence_tensor)).view(1, 1, -1)
    # Shape (batch_size, num_spans, max_batch_span_width). Broadcasted comparison:
    # offset k is kept only when k <= span width (``<=`` because ends are inclusive).
    span_mask = (max_span_range_indices <= span_widths).float()
    # Positions are gathered walking backwards from each span's end.
    raw_span_indices = span_ends - max_span_range_indices
    # Spans near the start of the sequence can produce negative positions
    # (end index < max_batch_span_width); mask those out as well.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    # Clamp negatives to 0 so every entry is a valid gather index (those entries
    # are masked anyway). Clamping the LongTensor directly is exact, unlike a
    # float32 round trip, which cannot represent every integer above 2**24.
    gather_indices = raw_span_indices.clamp(min=0)

    # Shape (batch_size * num_spans * max_batch_span_width)
    flat_gather_indices = util.flatten_and_batch_shift_indices(gather_indices, sequence_tensor.size(1))
    # Shape (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = util.batched_index_select(sequence_tensor, gather_indices, flat_gather_indices)
    # Shape (batch_size, num_spans, max_batch_span_width)
    span_attention_logits = util.batched_index_select(global_attention_logits,
                                                      gather_indices,
                                                      flat_gather_indices).squeeze(-1)
    # Shape (batch_size, num_spans, max_batch_span_width)
    span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)
    # Weighted sum of each span's token embeddings under its attention distribution.
    # Shape (batch_size, num_spans, embedding_dim)
    attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

    if span_indices_mask is not None:
        # The mask above handled widths relative to the widest span; this one
        # zeroes spans that were passed in as padding.
        return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float()
    return attended_text_embeddings