当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch index_select用法及代码示例


本文简要介绍python语言中 torch.index_select 的用法。

用法:

torch.index_select(input, dim, index, *, out=None) → Tensor

参数

  • input(Tensor) -输入张量。

  • dim(int) -我们索引的维度

  • index(IntTensor或者LongTensor) -包含要索引的索引的一维张量

关键字参数

out(Tensor,可选的) -输出张量。

返回一个新张量,该张量使用 index 中的条目沿维度 dim 索引 input 张量,该条目是 LongTensor

返回的张量与原始张量 (input) 具有相同的维数。第 dim 维度的大小与 index 的长度相同;其他维度的大小与原始张量中的大小相同。

注意

返回的张量不是使用与原始张量相同的存储。如果out如果形状与预期不同,我们会默默地将其更改为正确的形状,并在必要时重新分配底层存储。

例子:

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.index_select。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。