當前位置: 首頁>>技術問答>>正文


Torch張量的view方法有什麽作用?

我對以下代碼片段中的方法view()感到困惑,不知道這個view方法起什麽作用。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool  = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

有困惑問題在下麵這行。

x = x.view(-1, 16*5*5)

tensor.view()方法有什麽作用?我已經在許多地方看到它的用法,但我無法理解它如何解析它的’參數。

如果我將負值作為參數提供給view()函數會發生什麽?例如,如果我調用tensor_variable.view(1, 1, -1)會怎麽樣?

請舉例說明view()功能的主要原理。

最佳解釋

view函數旨在reshape張量形狀。

假設你有一個張量

import torch
a = torch.range(1, 16)

其中a是一個張量,有16個元素,從1到16(包括在內)。如果你想重塑這個張量使其成為4 x 4張量,那麽你可以使用

a = a.view(4, 4)

現在a將是一個4 x 4張量。(注意,重塑後元素的總數需要保持不變。將張量a重新整形為3 x 5張量是不合適的)

參數-1是什麽意思?

如果你不知道你想要多少行,但確定列數,那麽你可以將行數設置為-1(你可以將它擴展到具有更多維度的張量。隻有一個軸值可以是-1)。這是告訴係統Library:給我一個具有這麽多列的張量,並計算實現這一點所需的適當行數。

這可以在您上麵給出的神經網絡代碼中看到。在前向功能中的x = self.pool(F.relu(self.conv2(x)))行之後,您將擁有一個16深度的特征圖。您必須將其展平以將其提供給全連接的圖層。所以告訴pytorch重新塑造你獲得的張量,使其具有特定數量的列並讓它自己決定行數。

從numpy和pytorch之間的相似性來看,view類似於numpy的reshape函數。

補充解釋

讓我們舉一些例子,從簡到難。

  1. view方法返回張量與self張量相同的數據(這意味著返回的張量具有相同數量的元素),但具有不同的形狀。例如:

    a = torch.arange(1, 17)  # a's shape is (16,)
    
    a.view(4, 4) # output below
      1   2   3   4
      5   6   7   8
      9  10  11  12
     13  14  15  16
    [torch.FloatTensor of size 4x4]
    
    a.view(2, 2, 4) # output below
    (0 ,.,.) = 
    1   2   3   4
    5   6   7   8
    
    (1 ,.,.) = 
     9  10  11  12
    13  14  15  16
    [torch.FloatTensor of size 2x2x4]
    
  2. 假設-1不是其中一個參數,當將它們相乘時,結果必須等於張量中的元素數。如果您執行:a.view(3, 3),它將引發RuntimeError,因為對於具有16個元素的輸入,形狀(3 x 3)無效。換句話說:3 x 3不等於16但是9。

  3. 您可以使用-1作為傳遞給函數的參數之一,但隻能使用一次。所發生的事情是該方法將自動計算維度。例如,a.view(2, -1, 4)等同於a.view(2, 2, 4)。 [16 /(2 x 4)= 2]

  4. 特別請注意,返回的張量共享相同的數據。如果您相對”view”中進行更改,則需要更改原始張量數據:

    b = a.view(4, 4)
    b[0, 2] = 2
    a[2] == 3.0
    False
    
  5. 現在,對於更複雜的用例。文檔說每個新的視圖維度必須是原始維度的子空間,或者隻有跨度d,d + 1,…,d + k滿足以下contiguity-like條件,對於所有i = 0,… ,k – 1,stride [i] = stride [i + 1] x size [i + 1]。否則,需要在可以查看張量之前調用contiguous()。例如:

    a = torch.rand(5, 4, 3, 2) # size (5, 4, 3, 2)
    a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2, 4)
    
    # The commented line below will raise a RuntimeError, because one dimension
    # spans across two contiguous subspaces
    # a_t.view(-1, 4)
    
    # instead do:
    a_t.contiguous().view(-1, 4)
    
    # To see why the first one does not work and the second does,
    # compare a.stride() and a_t.stride()
    a.stride() # (24, 6, 2, 1)
    a_t.stride() # (24, 2, 1, 6)
    

    請注意,對於a_t,stride [0]!= stride [1] x size [1],因為24!= 2 x 3

參考資料

本文由《純淨天空》出品。文章地址: https://vimsky.com/zh-tw/article/3888.html,未經允許,請勿轉載。