當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch GroupNorm用法及代碼示例


本文簡要介紹python語言中 torch.nn.GroupNorm 的用法。

用法:

class torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)

參數

  • num_groups(int) -將通道分成的組數

  • num_channels(int) -輸入中預期的通道數

  • eps-加到分母上的值,以保證數值穩定性。默認值:1e-5

  • affine-一個布爾值,當設置為 True 時,此模塊具有可學習的每通道仿射參數,初始化為 1(用於權重)和 0(用於偏差)。默認值:True

如論文 Group Normalization 中所述,對小批量輸入應用組規範化

輸入通道分為num_groups 組,每個組包含num_channels / num_groups 通道。分別計算每組的平均值和標準差。 是大小為 num_channels 的可學習的每通道仿射變換參數向量,如果 affineTrue 。標準差是通過偏置估計器計算的,相當於 torch.var(input, unbiased=False)

該層使用從訓練和評估模式中的輸入數據計算的統計數據。

形狀:
  • 輸入: 其中

  • 輸出: (與輸入的形狀相同)

例子:

>>> input = torch.randn(20, 6, 10, 10)
>>> # Separate 6 channels into 3 groups
>>> m = nn.GroupNorm(3, 6)
>>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
>>> m = nn.GroupNorm(6, 6)
>>> # Put all 6 channels into a single group (equivalent with LayerNorm)
>>> m = nn.GroupNorm(1, 6)
>>> # Activating the module
>>> output = m(input)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.GroupNorm。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。