当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch repeat_interleave用法及代码示例


本文简要介绍python语言中 torch.repeat_interleave 的用法。

用法:

torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) → Tensor

参数

  • input(Tensor) -输入张量。

  • repeats(Tensor或者int) -每个元素的重复次数。广播重复以适应给定轴的形状。

  • dim(int,可选的) -沿其重复值的维度。默认情况下,使用扁平化的输入数组,并返回一个扁平的输出数组。

关键字参数

output_size(int,可选的) -给定轴的总输出大小(例如重复总和)。如果给定,它将避免计算张量的输出形状所需的流同步。

返回

除了沿给定轴外,与输入具有相同形状的重复张量。

返回类型

Tensor

重复张量的元素。

警告

这与 torch.Tensor.repeat() 不同,但类似于 numpy.repeat

例子:

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])
>>> y = torch.tensor([[1, 2], [3, 4]])
>>> torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
        [3, 4],
        [3, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3)
tensor([[1, 2],
        [3, 4],
        [3, 4]])
torch.repeat_interleave(repeats, *, output_size=None) → Tensor

如果 repeatstensor([n1, n2, n3, …]) ,那么输出将是 tensor([0, 0, …, 1, 1, …, 2, 2, …, …]) 其中 0 出现 n1 次,1 出现 n2 次,2 出现 n3 次等。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.repeat_interleave。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。