当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch broadcast_shapes用法及代码示例


本文简要介绍python语言中 torch.broadcast_shapes 的用法。

用法:

torch.broadcast_shapes(*shapes) → Size

参数

*shapes(torch.Size) -张量的形状。

返回

与所有输入形状兼容的形状。

返回类型

形状(火炬.尺寸)

抛出

RuntimeError - 如果形状不兼容。

类似于 broadcast_tensors() ,但用于形状。

这相当于torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape但避免了创建中间张量的需要。这对于广播常见批处理形状但最右边形状不同的张量很有用,例如用协方差矩阵广播平均向量。

例子:

>>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
torch.Size([1, 3, 2])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.broadcast_shapes。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。