當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch Conv1d用法及代碼示例


本文簡要介紹python語言中 torch.nn.Conv1d 的用法。

用法:

class torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

參數

  • in_channels(int) -輸入圖像中的通道數

  • out_channels(int) -卷積產生的通道數

  • kernel_size(int或者tuple) -卷積核的大小

  • stride(int或者tuple,可選的) -卷積的步幅。默認值:1

  • padding(int,tuple或者str,可選的) -填充添加到輸入的兩側。默認值:0

  • padding_mode(string,可選的) -'zeros''reflect''replicate''circular' 。默認值:'zeros'

  • dilation(int或者tuple,可選的) -內核元素之間的間距。默認值:1

  • groups(int,可選的) -從輸入通道到輸出通道的阻塞連接數。默認值:1

  • bias(bool,可選的) -如果 True ,則向輸出添加可學習的偏差。默認值:True

變量

  • ~Conv1d.weight(Tensor) -形狀模塊的可學習權重(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size}) .這些權重的值是從\mathcal{U}(-\sqrt{k}, \sqrt{k}) 其中k = \frac{groups}{C_\text{in} * \text{kernel\_size}}

  • ~Conv1d.bias(Tensor) -形狀模塊的可學習偏差(out_channels)。如果 biasTrue ,則這些權重的值是從 中采樣的,其中

對由多個輸入平麵組成的輸入信號應用一維卷積。

在最簡單的情況下,輸入大小為 和輸出 的層的輸出值可以精確地說明為:

其中 是有效的cross-correlation運算符, 是批量大小, 表示通道數, 是信號序列的長度。

該模塊支持 TensorFloat32。

  • stride 控製互相關、單個數字或單元素元組的步長。

  • padding 控製應用於輸入的填充量。它可以是字符串 {‘valid’, ‘same’} 或整數元組,給出兩側應用的隱式填充量。

  • dilation控製內核點之間的間距;也稱為 à trous 算法。很難說明,但是這個link 很好地可視化了dilation 的作用。

  • groups 控製輸入和輸出之間的連接。 in_channelsout_channels 都必須能被 groups 整除。例如,

    • 在 groups=1 時,所有輸入都卷積到所有輸出。

    • 在 groups=2 時,該操作等效於並排有兩個卷積層,每個卷積層看到一半的輸入通道並產生一半的輸出通道,並且隨後將兩者連接起來。

    • 在 groups= in_channels ,每個輸入通道都與自己的一組過濾器(大小為 )進行卷積。

注意

groups == in_channelsout_channels == K * in_channels 時,其中 K 是正整數,此操作也稱為 “depthwise convolution”。

換句話說,對於大小為 的輸入,可以使用參數 執行具有深度乘數 K 的深度卷積。

注意

在某些情況下,當在 CUDA 設備上給定張量並使用 CuDNN 時,此運算符可能會選擇非確定性算法來提高性能。如果這是不可取的,您可以嘗試通過設置 torch.backends.cudnn.deterministic = True 來使操作具有確定性(可能以性能為代價)。有關詳細信息,請參閱重現性。

注意

padding='valid' 與無填充相同。 padding='same' 填充輸入,使輸出具有作為輸入的形狀。但是,此模式不支持 1 以外的任何步幅值。

形狀:
  • 輸入:

  • 輸出:

例子:

>>> m = nn.Conv1d(16, 33, 3, stride=2)
>>> input = torch.randn(20, 16, 50)
>>> output = m(input)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.Conv1d。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。