本文簡要介紹python語言中 torch.nn.Unfold
的用法。
用法:
class torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
從批處理輸入張量中提取滑動局部塊。
考慮一個形狀為
input
張量,其中 是批處理維度, 是通道維度,而 表示任意空間維度。此操作將input
的空間維度內的每個滑動kernel_size
大小的塊展平為形狀為 的 3-Doutput
張量的列(即最後一維),其中 是總數每個塊內的值的數量(一個塊具有 空間位置,每個空間位置都包含一個 通道向量),而 是此類塊的總數: 的批處理其中
input
(上麵的 )的空間維度構成,而 覆蓋所有空間維度。 由因此,在最後一個維度(列維度)索引
output
會給出某個塊內的所有值。padding
、stride
和dilation
參數指定如何檢索滑動塊。stride
控製滑塊的步幅。padding
控製重塑前每個維度的padding
點的兩側隱含zero-paddings 的數量。dilation
控製內核點之間的間距;也稱為 à trous 算法。很難說明,但是這個link 很好地可視化了dilation
的作用。
如果
kernel_size
、dilation
、padding
或stride
是 int 或長度為 1 的元組,則它們的值將在所有空間維度上複製。對於兩個輸入空間維度的情況,此操作有時稱為
im2col
。
注意
Fold
通過對所有包含塊中的所有值求和來計算生成的大張量中的每個組合值。Unfold
通過從大張量複製來提取局部塊中的值。因此,如果塊重疊,它們就不是彼此相反的。一般來說,折疊和展開操作的關係如下。考慮使用相同參數創建的
Fold
和Unfold
實例:>>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) >>> fold = nn.Fold(output_size=..., **fold_params) >>> unfold = nn.Unfold(**fold_params)
然後對於任何(支持的)
input
張量,以下等式成立:fold(unfold(input)) == divisor * input
其中
divisor
是僅取決於input
的形狀和 dtype 的張量:>>> input_ones = torch.ones(input.shape, dtype=input.dtype) >>> divisor = fold(unfold(input_ones))
當
divisor
張量不包含零元素時,fold
和unfold
運算是彼此的逆運算(直到常數除數)。警告
目前,僅支持 4-D 輸入張量(批量 image-like 張量)。
- 形狀:
輸入:
輸出: 如上所述
例子:
>>> unfold = nn.Unfold(kernel_size=(2, 3)) >>> input = torch.randn(2, 5, 3, 4) >>> output = unfold(input) >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels) >>> # 4 blocks (2x3 kernels) in total in the 3x4 input >>> output.size() torch.Size([2, 30, 4]) >>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape) >>> inp = torch.randn(1, 3, 10, 12) >>> w = torch.randn(2, 3, 4, 5) >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5)) >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2) >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1)) >>> # or equivalently (and avoiding a copy), >>> # out = out_unf.view(1, 2, 7, 8) >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max() tensor(1.9073e-06)
參數:
相關用法
- Python PyTorch Unflatten用法及代碼示例
- Python PyTorch UnBatcher用法及代碼示例
- Python PyTorch Uniform用法及代碼示例
- Python PyTorch UnZipper用法及代碼示例
- Python PyTorch Upsample用法及代碼示例
- Python PyTorch UpsamplingBilinear2d用法及代碼示例
- Python PyTorch UpsamplingNearest2d用法及代碼示例
- Python PyTorch frexp用法及代碼示例
- Python PyTorch jvp用法及代碼示例
- Python PyTorch cholesky用法及代碼示例
- Python PyTorch vdot用法及代碼示例
- Python PyTorch ELU用法及代碼示例
- Python PyTorch ScaledDotProduct.__init__用法及代碼示例
- Python PyTorch gumbel_softmax用法及代碼示例
- Python PyTorch get_tokenizer用法及代碼示例
- Python PyTorch saved_tensors_hooks用法及代碼示例
- Python PyTorch positive用法及代碼示例
- Python PyTorch renorm用法及代碼示例
- Python PyTorch AvgPool2d用法及代碼示例
- Python PyTorch MaxUnpool3d用法及代碼示例
- Python PyTorch Bernoulli用法及代碼示例
- Python PyTorch Tensor.unflatten用法及代碼示例
- Python PyTorch Sigmoid用法及代碼示例
- Python PyTorch Tensor.register_hook用法及代碼示例
- Python PyTorch ShardedEmbeddingBagCollection.named_parameters用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.Unfold。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。