当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch take_along_dim用法及代码示例


本文简要介绍python语言中 torch.take_along_dim 的用法。

用法:

torch.take_along_dim(input, indices, dim, *, out=None) → Tensor

参数

  • input(Tensor) -输入张量。

  • indices(张量) -input 中的索引。必须有长 dtype。

  • dim(int) -要选择的尺寸。

关键字参数

out(Tensor,可选的) -输出张量。

沿着给定的 dimindices 的一维索引处从 input 中选择值。

沿维度返回索引的函数,如 torch.argmax() torch.argsort() ,旨在与此函数一起使用。请参阅下面的示例。

注意

这个函数类似于 NumPy 的 take_along_axis 。另见 torch.gather()

例子:

>>> t = torch.tensor([[10, 30, 20], [60, 40, 50]])
>>> max_idx = torch.argmax(t)
>>> torch.take_along_dim(t, max_idx)
tensor([60])
>>> sorted_idx = torch.argsort(t, dim=1)
>>> torch.take_along_dim(t, sorted_idx, dim=1)
tensor([[10, 20, 30],
        [40, 50, 60]])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.take_along_dim。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。