zi2zi-pytorch UnetSkipConnectionBlock 調整圖片解析度,256×256 改成 400×400, 在 forward() 的 torch.cat([x, dec], 1) 遇到問題, x 的 shape 是 torch.Size([8, 512, 3, 3]), dec 的 shape 是 torch.Size([8, 512, 2, 2]), 無法使用 torch.cat() 合併

錯誤訊息:

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 3 but got size 2 for tensor number 1 in the list.

這個錯誤訊息指出在 zi2zi-pytorch 專案中,當您將輸入圖片解析度從 256×256 修改為 400×400 後,UnetSkipConnectionBlockforward() 方法在最內層 (innermost) 遇到了 torch.cat() 無法合併張量的問題。 錯誤訊息 RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 3 but got size 2 for tensor number 1 in the list. 明確地告訴我們,torch.cat() 期望除了通道維度 (dimension 1) 以外,其他維度的尺寸必須完全一致。然而,在您的情況下,x 張量的空間維度是 3×3,而 dec 張量的空間維度是 2×2,尺寸不匹配,導致 torch.cat() 無法執行。

問題分析:解析度變更導致空間維度不匹配

這個問題的根本原因在於,當您將輸入圖片解析度從原本設計的 256×256 修改為 400×400 時,UNet 架構中各層的特徵圖 (feature map) 尺寸也會隨之改變。 UnetSkipConnectionBlock 的設計原先可能是針對特定輸入尺寸 (如 256×256) 進行優化,當輸入尺寸變更後,編碼器路徑 (encoder path) 和解碼器路徑 (decoder path) 在最內層 (innermost) 的輸出特徵圖空間尺寸產生了不一致,具體來說,一個是 3×3,另一個是 2×2。

UnetSkipConnectionBlockforward() 方法中,torch.cat([x, dec], 1) 這行程式碼的作用是將 跳躍連接 (skip connection) 的特徵圖 x (來自編碼器路徑) 和目前解碼器路徑區塊的特徵圖 dec 沿著通道維度 (dimension 1) 串聯在一起。跳躍連接的設計目的是為了融合不同層級的特徵資訊,提升生成品質。然而,進行 torch.cat() 之前,必須確保除了通道維度外,其他維度 (在這裡主要是指空間維度,即高度和寬度) 的尺寸是相同的。

解法:調整 dec 的空間尺寸以匹配 x

最直接的解決方案是在執行 torch.cat() 之前,調整 dec (或 x,但通常調整 dec 更為合理) 的空間尺寸,使其與 x 的空間尺寸一致。 由於 x 的 shape 是 torch.Size([8, 512, 3, 3])dec 的 shape 是 torch.Size([8, 512, 2, 2]),我們需要將 dec 的空間尺寸調整為 3×3,使其與 x 的空間尺寸匹配。

具體步驟:使用 nn.functional.interpolate 調整 dec 的尺寸

我們可以使用 torch.nn.functional.interpolate 函數來對 dec 進行尺寸調整 (上採樣)。 nn.functional.interpolate 函數提供了多種插值模式,例如 nearestlinearbilinearbicubictrilinear 等。 針對圖像特徵圖,bilinear 插值通常是一個不錯的選擇,它可以在調整尺寸的同時,保持圖像的平滑性。

修改 UnetSkipConnectionBlockforward() 方法 (innermost 情況)

您需要修改 zi2zi-pytorch 專案中 generators.py 檔案的 UnetSkipConnectionBlock 類別的 forward() 方法。 找到 innermost 條件下的程式碼區塊,並在 torch.cat([x, dec], 1) 之前,加入使用 nn.functional.interpolate 調整 dec 尺寸的程式碼。

以下是修改後的 forward() 方法 (僅展示 innermost 條件下的修改部分,完整程式碼請參考 zi2zi-pytorch 專案原始碼):

    def forward(self, x):
        if self.innermost: # 針對 innermost 的情況
            upconv = nn.ConvTranspose2d(self.inner_channels * 2, self.outer_channels,
                                        kernel_size=4, stride=2,
                                        padding=1)
            downconv = nn.Conv2d(self.inner_channels, self.outer_channels, kernel_size=4,
                                 stride=2, padding=1)
            norm = self.norm_layer(self.outer_channels)
            act = nn.ReLU(True)
            self.down = nn.Sequential(*[act, downconv, norm])
            self.up = nn.Sequential(*[act, upconv, norm])
            self.conv = nn.Sequential(*[act])
            down = self.down(x)
            up = self.up(down)

            # ---  修改開始  ---
            # 使用 nn.functional.interpolate 調整 up (即 dec) 的空間尺寸,使其匹配 x 的空間尺寸
            # 目標尺寸為 x 的空間尺寸,即 (3, 3)
            dec_resized = nn.functional.interpolate(up, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False)
            # 使用調整尺寸後的 dec_resized 進行 torch.cat() 操作
            return torch.cat([x, dec_resized], 1)
            # ---  修改結束  ---

        elif self.outermost: # ... (outermost 的程式碼,保持不變)
            # ...
        else: # ... (中間層的程式碼,保持不變)
            # ...

程式碼解釋:

  1. dec_resized = nn.functional.interpolate(up, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False): 這行程式碼是關鍵的修改。
    • nn.functional.interpolate(up, ...): 調用 nn.functional.interpolate 函數來進行插值上採樣。
    • up: up 變數在這裡對應到原始程式碼中的 up = self.up(down),也就是 UnetSkipConnectionBlock 的 “up” 路徑的輸出,我們需要調整的就是 up 的尺寸。
    • size=[x.size(2), x.size(3)]: size 參數指定了插值後的目標尺寸。 我們將目標尺寸設定為 x 的空間尺寸,即 [x.size(2), x.size(3)],其中 x.size(2) 代表高度維度的大小,x.size(3) 代表寬度維度的大小。 由於 x 的 shape 是 torch.Size([8, 512, 3, 3]),所以這裡的 size 會被設定為 [3, 3],也就是將 up (即 dec) 的高度和寬度都調整為 3 像素。
    • mode='bilinear': mode 參數指定了插值模式,我們選擇了 'bilinear' (雙線性插值)。 對於圖像特徵圖,bilinear 插值通常是一個較好的選擇,可以在調整尺寸的同時,保持圖像的平滑性,避免產生明顯的塊狀或鋸齒感。
    • align_corners=False: align_corners 參數設定為 False 是為了與 PyTorch 較新版本的 interpolate 行為保持一致,在大多數情況下,建議將 align_corners 設定為 False
  2. return torch.cat([x, dec_resized], 1): 在進行 torch.cat() 操作時,我們使用 調整尺寸後的 dec_resized,而不是原始的 up (即 dec),這樣就可以確保 xdec_resized 的空間尺寸一致,避免 RuntimeError 錯誤。

修改位置確認:

請務必確認您修改的是 UnetSkipConnectionBlock 類別的 forward() 方法中,if self.innermost: 條件判斷區塊內 的程式碼。 錯誤訊息指出問題發生在 innermost 的 torch.cat(),因此修改 innermost 的程式碼區塊是正確的。

其他可能的解決方案 (不推薦優先嘗試,僅供參考)

  • 調整 x 的尺寸以匹配 dec (不推薦): 雖然理論上也可以調整 x 的尺寸來匹配 dec,但通常不建議這樣做。 因為 x 是來自編碼器路徑的跳躍連接,通常認為編碼器路徑的特徵更重要,調整解碼器路徑的特徵 dec 來匹配 x 更為合理。
  • Padding 填充 dec (效果可能較差): 可以使用 Padding 操作 (例如 torch.nn.functional.pad) 將 dec 的空間尺寸填充到與 x 相同的大小。 但 Padding 通常會引入額外的零值,可能會影響特徵的有效性,效果可能不如插值上採樣好。
  • 修改 UNet 架構 (複雜且不必要): 更複雜的解決方案是修改 UnetSkipConnectionBlock 或整個 UNet 架構的設計,使其能夠更好地處理非標準的輸入尺寸。 例如,調整卷積層的 kernel size、stride、padding 參數,或者調整網路的層數,以確保在不同輸入尺寸下,跳躍連接的特徵圖尺寸仍然匹配。 但這種方法修改量較大,也需要更多實驗和驗證,對於解決您目前的問題來說,過於複雜且不必要。

建議的解法順序:

  1. 優先嘗試修改 UnetSkipConnectionBlockforward() 方法 (innermost 情況),加入 nn.functional.interpolate 調整 dec 尺寸的程式碼 (如上述程式碼範例所示)。 這是最直接、最簡單且最有可能解決問題的方法。
  2. 如果修改 forward() 方法後問題仍然存在,或者生成結果品質不理想, 可以仔細檢查您的資料預處理流程,確認輸入圖像的尺寸是否真的已經正確地調整為 400×400。 再次確認錯誤訊息是否仍然指向 innermost 的 torch.cat()
  3. 如果問題依然無法解決,並且您對 UNet 架構有更深入的理解,可以考慮更進階的調整方案 (例如,調整網路架構或卷積層參數),但通常情況下,修改 forward() 方法就足以解決空間尺寸不匹配的問題。

總結:

您遇到的 RuntimeError 是由於將輸入圖片解析度從 256×256 修改為 400×400 後,UNet 在最內層的跳躍連接處出現了空間維度不匹配的問題。 最推薦的解法是修改 UnetSkipConnectionBlockforward() 方法 (針對 innermost 的情況),在 torch.cat() 之前使用 nn.functional.interpolate 函數將 dec 的空間尺寸調整為與 x 相同的大小 (3×3),並使用 bilinear 插值模式。 這個方法可以有效地解決尺寸不匹配的問題,並且對模型結構的改動較小,易於實作。 修改完成後,請重新執行您的推論程式碼,觀察錯誤是否消失,並檢查生成結果是否符合預期。


同上, 新的錯誤訊息: torch.cat([x, dec], 1)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 13 but got size 12 for tensor number 1 in the list.

torch.cat([x, dec], 1) RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 13 but got size 12 for tensor number 1 in the list.

錯誤訊息指出在使用 torch.cat([x, dec], 1) 時遇到了 RuntimeError,錯誤訊息為 “Sizes of tensors must match except in dimension 1. Expected size 13 but got size 12 for tensor number 1 in the list.”。 這個錯誤表示在執行 torch.cat() 操作時,張量 xdec 除了維度 1 (通常是通道維度) 之外,在其他維度上的尺寸必須完全一致,但實際上尺寸並不匹配。具體來說,錯誤訊息指出在某個維度上,預期尺寸為 13,但其中一個張量的尺寸卻為 12。

問題分析:張量尺寸不匹配

torch.cat(tensors, dim) 這個函數的作用是將給定的張量序列 tensors 沿著指定的維度 dim 連接起來。為了成功執行連接操作,除了指定的連接維度 dim 之外,所有輸入張量在其他維度上的尺寸都必須完全相同。

在您的情況中,torch.cat([x, dec], 1) 嘗試沿著維度 1 進行連接。錯誤訊息 “Expected size 13 but got size 12 for tensor number 1 in the list” 清楚地表明,張量 xdec 在**除了維度 1 之外的某個維度上**,尺寸不一致。其中一個張量在這個維度上的尺寸是 13,而另一個是 12。

要解決這個問題,您需要讓張量 xdec 在除了維度 1 以外的所有維度上尺寸一致。以下提供幾種可能的解決方案:

解法 1:檢查張量的形狀 (Shape) 並找出不匹配的維度

首先,您需要檢查張量 xdec 的形狀 (shape),以找出具體是哪個維度的尺寸不匹配。您可以使用 print(x.shape)print(dec.shape) 來輸出它們的形狀。 假設 xdec 都是 4 維張量 (例如,批次大小, 通道數, 高度, 寬度),形狀可能類似於:

  • x 的形狀可能是: torch.Size([Batch_Size, C_x, H_x, W_x])
  • dec 的形狀可能是: torch.Size([Batch_Size, C_dec, H_dec, W_dec])

torch.cat([x, dec], 1) 沿著維度 1 (通道維度) 連接,所以通道數 C_xC_dec 的總和將會是輸出張量的通道數。 但是,**批次大小 (Batch_Size)、高度維度 (H) 和寬度維度 (W)** 對於 xdec 必須完全相同。 錯誤訊息 “Expected size 13 but got size 12” 表示在高度維度 (H) 或寬度維度 (W) 上,xdec 的尺寸不一致,一個是 13,另一個是 12。 您需要比較 x.shapedec.shape,找出是高度維度還是寬度維度 (或是其他維度,如果張量是更高維度的) 的尺寸不匹配。

解法 2:調整張量的尺寸以匹配

一旦您確定了是哪個維度的尺寸不匹配 (例如,假設是高度維度,且 x 的高度是 13,dec 的高度是 12),您就需要調整其中一個張量的尺寸,使其與另一個張量匹配。 常用的調整尺寸的方法有:

  1. 裁剪 (Crop): 如果尺寸差異不大,且您希望保留張量中心區域的資訊,可以裁剪尺寸較大的張量,使其尺寸與較小的張量一致。 例如,如果 x 的高度是 13,dec 的高度是 12,您可以將 x 在高度維度上裁剪掉 1 個像素,使其高度也變成 12。 在 PyTorch 中,您可以使用切片 (slicing) 操作來實現裁剪。
  2. 填充 (Padding): 如果尺寸較小的張量需要擴展尺寸才能與較大的張量匹配,可以使用填充。例如,如果 dec 的高度是 12,x 的高度是 13,您可以對 dec 在高度維度上進行填充,使其高度也變成 13。 PyTorch 提供了 torch.nn.functional.pad 函數來進行填充。 您可以選擇不同的填充模式,例如零填充、常數值填充、反射填充、複製填充等。
  3. 調整大小 (Resize/Interpolate): 更常見且可能更適合圖像或特徵圖處理的方法是使用調整大小 (或插值)。 您可以將尺寸較小的張量放大,或者將尺寸較大的張量縮小,使其與另一個張量的尺寸匹配。 PyTorch 提供了 torch.nn.functional.interpolate 函數來進行插值調整大小。 interpolate 函數提供了多種插值模式,例如最近鄰插值 (nearest)、線性插值 (linear)、雙線性插值 (bilinear)、雙三次插值 (bicubic) 等。 對於圖像特徵圖,雙線性插值通常是一個較好的選擇,可以在調整尺寸的同時保持圖像的平滑性。

解法 3:使用 torch.nn.functional.interpolate 調整 dec 的尺寸 (範例程式碼,假設是高度維度不匹配)

假設您發現是高度維度不匹配,並且 x 的高度是 13,dec 的高度是 12。 以下範例程式碼展示如何使用 torch.nn.functional.interpolatedec 的高度調整為 13,使其與 x 的高度匹配,然後再進行 torch.cat() 操作。 我們假設使用雙線性插值 bilinear

import torch
import torch.nn.functional as F

# 假設 x 和 dec 是您的張量,並且已經有數值 (這裡為了示範,隨機生成兩個張量)
# 假設 x 的 shape 是 torch.Size([8, 512, 13, 13])
# 假設 dec 的 shape 是 torch.Size([8, 512, 12, 12])
x = torch.randn(8, 512, 13, 13)
dec = torch.randn(8, 512, 12, 12)

print("原始 x 的 shape:", x.shape)
print("原始 dec 的 shape:", dec.shape)

# 使用 nn.functional.interpolate 調整 dec 的空間尺寸,使其高度和寬度都匹配 x 的空間尺寸
# 這裡我們只調整高度,如果寬度也不同,也需要調整寬度
# 目標尺寸為 x 的空間尺寸 (假設我們只需要調整高度)
dec_resized = F.interpolate(dec, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False)

print("調整尺寸後的 dec 的 shape:", dec_resized.shape)

# 現在 dec_resized 的 shape 應該是 torch.Size([8, 512, 13, 12]) (如果只調整高度)
# 如果 x 的 shape 是 torch.Size([8, 512, 13, 13]), 錯誤訊息 "Expected size 13 but got size 12" 仍然存在
# 錯誤訊息 "Expected size 13 but got size 12"  暗示了是某個維度期望是 13,但實際是 12
# 根據錯誤訊息 "Sizes of tensors must match except in dimension 1", dimension 1 是通道維度 (512)
# 所以不匹配的維度可能是維度 2 或維度 3,也就是高度或寬度
# 我們假設是高度維度 (維度 2) 不匹配,預期 13,實際 12,所以需要將 dec 的高度調整為 13

dec_resized_height_matched = F.interpolate(dec, size=[13, dec.size(3)], mode='bilinear', align_corners=False)
# 或更精確地匹配 x 的高度和寬度
dec_resized_matched = F.interpolate(dec, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False)


print("高度調整後的 dec 的 shape:", dec_resized_height_matched.shape)
print("完全匹配 x 尺寸調整後的 dec 的 shape:", dec_resized_matched.shape)


# 現在可以進行 torch.cat 操作了 (使用調整尺寸後的 dec_resized_matched)
try:
    # 使用 dec_resized_matched 進行 torch.cat() 操作
    output = torch.cat([x, dec_resized_matched], dim=1)
    print("torch.cat 操作成功,輸出 shape:", output.shape)
except RuntimeError as e:
    print("torch.cat 操作仍然失敗:", e)


# 再次嘗試使用 dec_resized_height_matched (僅高度調整)
try:
    # 使用 dec_resized_height_matched 進行 torch.cat() 操作
    output_height_matched = torch.cat([x, dec_resized_height_matched], dim=1)
    print("torch.cat 操作 (僅高度調整) 成功,輸出 shape:", output_height_matched.shape)
except RuntimeError as e:
    print("torch.cat 操作 (僅高度調整) 仍然失敗:", e)

程式碼解釋:

  1. dec_resized_matched = F.interpolate(dec, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False): 這行程式碼使用 torch.nn.functional.interpolate 函數來調整 dec 的尺寸。
    • dec: 是要調整尺寸的輸入張量。
    • size=[x.size(2), x.size(3)]: 指定了目標尺寸。 我們將 dec 的高度和寬度都調整為與 x 的高度和寬度一致,即 x.size(2) 代表高度,x.size(3) 代表寬度。
    • mode='bilinear': 選擇雙線性插值模式。
    • align_corners=False: align_corners=False 通常是較好的選擇,與 PyTorch 較新版本的行為一致。
  2. output = torch.cat([x, dec_resized_matched], dim=1): 使用調整尺寸後的 dec_resized_matchedx 進行 torch.cat() 操作。

解法 4:檢查網路架構或資料處理流程

如果調整尺寸後問題仍然存在,或者您認為不應該出現尺寸不匹配的情況,那麼可能需要檢查您的網路架構或資料處理流程:

  • 檢查網路架構設計: 仔細檢查網路架構的設計,特別是 torch.cat([x, dec], 1) 這行程式碼所在的網路層或區塊。 確保網路架構的設計邏輯上應該在這個位置將形狀兼容的張量進行連接。 有時候,網路架構的某些部分可能在設計上就沒有考慮到輸入尺寸的變化,或者存在程式碼錯誤導致了尺寸不匹配。
  • 檢查資料預處理流程: 檢查資料預處理流程,確保輸入到網路的資料 (以及中間層輸出的特徵圖) 的尺寸是符合預期的。 有時候,資料預處理的錯誤 (例如,錯誤的裁剪、縮放、填充等) 可能會導致張量尺寸不一致。

總結

解決 “RuntimeError: Sizes of tensors must match except in dimension 1” 的關鍵步驟是:

  1. 分析錯誤訊息: 理解錯誤訊息的含義,明確指出是哪個維度尺寸不匹配。
  2. 檢查張量形狀: 輸出 xdec 的 shape,找出尺寸不匹配的維度。
  3. 選擇合適的解決方案: 根據具體情況選擇裁剪、填充或調整大小 (插值) 等方法來調整張量的尺寸,使其在除了連接維度之外的所有維度上都與另一個張量匹配。 通常,對於圖像或特徵圖,插值調整大小是一個較為通用的解決方案。
  4. 驗證解決方案: 修改程式碼後,重新執行程式,驗證 torch.cat() 操作是否成功,並檢查模型的輸出是否符合預期。

在您提供的錯誤訊息 “Expected size 13 but got size 12” 的情況下,最可能且最推薦的解決方案是使用 torch.nn.functional.interpolate 將尺寸為 12 的張量調整大小為 13,使其與尺寸為 13 的張量匹配,然後再進行 torch.cat() 操作。 請先嘗試使用範例程式碼中的插值方法,並根據您的具體情況調整程式碼,例如選擇合適的插值模式和目標尺寸。

About the author

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *