This article collects typical usage examples of the Python method torch.Tensor.index_select. If you have been wondering what Tensor.index_select does, how to use it, or where to find examples of it, the curated code samples below may help. You can also explore other methods of the torch.Tensor class.
Two code examples of Tensor.index_select are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code samples.
Example 1: flattened_index_select
# Required import: from torch import Tensor  [as alias]
# Or: from torch.Tensor import index_select  [as alias]
import torch
from allennlp.common.checks import ConfigurationError  # error type raised below
def flattened_index_select(target: torch.Tensor,
                           indices: torch.LongTensor) -> torch.Tensor:
    """
    The given ``indices`` of size ``(set_size, subset_size)`` specifies subsets of the ``target``
    that each of the set_size rows should select. The ``target`` has size
    ``(batch_size, sequence_length, embedding_size)``, and the resulting selected tensor has size
    ``(batch_size, set_size, subset_size, embedding_size)``.

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, sequence_length, embedding_size).
    indices : ``torch.LongTensor``, required.
        A LongTensor of shape (set_size, subset_size). All indices must be < sequence_length
        as this tensor is an index into the sequence_length dimension of the target.

    Returns
    -------
    selected : ``torch.Tensor``
        A Tensor of shape (batch_size, set_size, subset_size, embedding_size).
    """
    if indices.dim() != 2:
        raise ConfigurationError("Indices passed to flattened_index_select had shape {} but "
                                 "only 2 dimensional inputs are supported.".format(indices.size()))
    # Shape: (batch_size, set_size * subset_size, embedding_size)
    flattened_selected = target.index_select(1, indices.view(-1))
    # Shape: (batch_size, set_size, subset_size, embedding_size)
    selected = flattened_selected.view(target.size(0), indices.size(0), indices.size(1), -1)
    return selected
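For a quick sanity check, here is a minimal, self-contained sketch of the same pattern written directly against torch.Tensor.index_select; the toy shapes and values are assumptions chosen for illustration, not part of the original example.

import torch

# A toy batch: batch_size=2, sequence_length=5, embedding_size=3.
target = torch.arange(2 * 5 * 3, dtype=torch.float).view(2, 5, 3)
# Two subsets of two positions each: set_size=2, subset_size=2.
indices = torch.tensor([[0, 1], [3, 4]])

# Flatten the index matrix and gather along the sequence dimension ...
flattened = target.index_select(1, indices.view(-1))   # shape (2, 4, 3)
# ... then restore the (set_size, subset_size) structure.
selected = flattened.view(target.size(0), indices.size(0), indices.size(1), -1)
print(selected.shape)  # torch.Size([2, 2, 2, 3])

This is exactly what flattened_index_select wraps, plus the dimensionality check on indices.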
Example 2: sort_batch_by_length
# Required import: from torch import Tensor  [as alias]
# Or: from torch.Tensor import index_select  [as alias]
import torch
from allennlp.common.checks import ConfigurationError  # error type raised below
def sort_batch_by_length(tensor: torch.Tensor, sequence_lengths: torch.Tensor):
    """
    Sort a batch-first tensor by some specified lengths.

    Parameters
    ----------
    tensor : torch.FloatTensor, required.
        A batch-first PyTorch tensor.
    sequence_lengths : torch.LongTensor, required.
        A tensor representing the lengths of some dimension of the tensor which
        we want to sort by.

    Returns
    -------
    sorted_tensor : torch.FloatTensor
        The original tensor sorted along the batch dimension with respect to sequence_lengths.
    sorted_sequence_lengths : torch.LongTensor
        The original sequence_lengths sorted by decreasing size.
    restoration_indices : torch.LongTensor
        Indices into the sorted_tensor such that
        ``sorted_tensor.index_select(0, restoration_indices) == original_tensor``
    permutation_index : torch.LongTensor
        The indices used to sort the tensor. This is useful if you want to sort many
        tensors using the same ordering.
    """
    if not isinstance(tensor, torch.Tensor) or not isinstance(sequence_lengths, torch.Tensor):
        raise ConfigurationError("Both the tensor and sequence lengths must be torch.Tensors.")

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)

    index_range = sequence_lengths.new_tensor(torch.arange(0, len(sequence_lengths)))
    # This is the equivalent of zipping with index, sorting by the original
    # sequence lengths and returning the now sorted indices.
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index
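As a rough usage sketch of the mechanism the function above packages up (the batch contents and lengths below are made-up illustrations), the snippet sorts a padded batch by decreasing length and then applies the inverse permutation to recover the original order.

import torch

# A toy padded batch of 3 sequences, max length 5, embedding size 4.
tensor = torch.randn(3, 5, 4)
sequence_lengths = torch.tensor([2, 5, 3])

# Sort by decreasing length (the order pack_padded_sequence expects).
sorted_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
sorted_tensor = tensor.index_select(0, permutation_index)

# Invert the permutation to restore the original batch order afterwards.
_, reverse_mapping = permutation_index.sort(0, descending=False)
restored = sorted_tensor.index_select(0, reverse_mapping)
print(torch.equal(restored, tensor))  # True

The reverse_mapping here plays the role of restoration_indices in the returned tuple above.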