本文简要介绍python语言中 torch.nn.Unfold
的用法。
用法:
class torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
从批处理输入张量中提取滑动局部块。
考虑一个形状为
input
张量,其中 是批处理维度, 是通道维度,而 表示任意空间维度。此操作将input
的空间维度内的每个滑动kernel_size
大小的块展平为形状为 的 3-Doutput
张量的列(即最后一维),其中 是总数每个块内的值的数量(一个块具有 空间位置,每个空间位置都包含一个 通道向量),而 是此类块的总数: 的批处理其中
input
(上面的 )的空间维度构成,而 覆盖所有空间维度。 由因此,在最后一个维度(列维度)索引
output
会给出某个块内的所有值。padding
、stride
和dilation
参数指定如何检索滑动块。stride
控制滑块的步幅。padding
控制重塑前每个维度的padding
点的两侧隐含zero-paddings 的数量。dilation
控制内核点之间的间距;也称为 à trous 算法。很难说明,但是这个link 很好地可视化了dilation
的作用。
如果
kernel_size
、dilation
、padding
或stride
是 int 或长度为 1 的元组,则它们的值将在所有空间维度上复制。对于两个输入空间维度的情况,此操作有时称为
im2col
。
注意
Fold
通过对所有包含块中的所有值求和来计算生成的大张量中的每个组合值。Unfold
通过从大张量复制来提取局部块中的值。因此,如果块重叠,它们就不是彼此相反的。一般来说,折叠和展开操作的关系如下。考虑使用相同参数创建的
Fold
和Unfold
实例:>>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) >>> fold = nn.Fold(output_size=..., **fold_params) >>> unfold = nn.Unfold(**fold_params)
然后对于任何(支持的)
input
张量,以下等式成立:fold(unfold(input)) == divisor * input
其中
divisor
是仅取决于input
的形状和 dtype 的张量:>>> input_ones = torch.ones(input.shape, dtype=input.dtype) >>> divisor = fold(unfold(input_ones))
当
divisor
张量不包含零元素时,fold
和unfold
运算是彼此的逆运算(直到常数除数)。警告
目前,仅支持 4-D 输入张量(批量 image-like 张量)。
- 形状:
输入:
输出: 如上所述
例子:
>>> unfold = nn.Unfold(kernel_size=(2, 3)) >>> input = torch.randn(2, 5, 3, 4) >>> output = unfold(input) >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels) >>> # 4 blocks (2x3 kernels) in total in the 3x4 input >>> output.size() torch.Size([2, 30, 4]) >>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape) >>> inp = torch.randn(1, 3, 10, 12) >>> w = torch.randn(2, 3, 4, 5) >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5)) >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2) >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1)) >>> # or equivalently (and avoiding a copy), >>> # out = out_unf.view(1, 2, 7, 8) >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max() tensor(1.9073e-06)
参数:
相关用法
- Python PyTorch Unflatten用法及代码示例
- Python PyTorch UnBatcher用法及代码示例
- Python PyTorch Uniform用法及代码示例
- Python PyTorch UnZipper用法及代码示例
- Python PyTorch Upsample用法及代码示例
- Python PyTorch UpsamplingBilinear2d用法及代码示例
- Python PyTorch UpsamplingNearest2d用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
- Python PyTorch cholesky用法及代码示例
- Python PyTorch vdot用法及代码示例
- Python PyTorch ELU用法及代码示例
- Python PyTorch ScaledDotProduct.__init__用法及代码示例
- Python PyTorch gumbel_softmax用法及代码示例
- Python PyTorch get_tokenizer用法及代码示例
- Python PyTorch saved_tensors_hooks用法及代码示例
- Python PyTorch positive用法及代码示例
- Python PyTorch renorm用法及代码示例
- Python PyTorch AvgPool2d用法及代码示例
- Python PyTorch MaxUnpool3d用法及代码示例
- Python PyTorch Bernoulli用法及代码示例
- Python PyTorch Tensor.unflatten用法及代码示例
- Python PyTorch Sigmoid用法及代码示例
- Python PyTorch Tensor.register_hook用法及代码示例
- Python PyTorch ShardedEmbeddingBagCollection.named_parameters用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.Unfold。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。