當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch squeeze用法及代碼示例


本文簡要介紹python語言中 torch.squeeze 的用法。

用法:

torch.squeeze(input, dim=None, *, out=None) → Tensor

參數

  • input(Tensor) -輸入張量。

  • dim(int,可選的) -如果給定,輸入將隻在這個維度上被壓縮

關鍵字參數

out(Tensor,可選的) -輸出張量。

返回一個張量,其中刪除了大小為 1input 的所有維度。

例如,如果 input 的形狀為: ,那麽 out 張量的形狀為:

當給定dim 時,僅在給定維度上進行擠壓操作。如果 input 的形狀為: squeeze(input, 0) 保持張量不變,但 squeeze(input, 1) 會將張量壓縮到形狀

注意

返回的張量與輸入張量共享存儲,因此更改其中一個的內容將更改另一個的內容。

警告

如果張量具有大小為 1 的批量維度,那麽 squeeze(input) 也會刪除批量維度,這可能會導致意外錯誤。

例子:

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.squeeze。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。