zi2zi-pytorch UnetSkipConnectionBlock 的 def forward(self, x) 的 x 的 input 資料是從那裡來的?

解碼器路徑的 “第一個”(laryer=1),但實際上它接收的是 UNet 架構的 Bottleneck 部分 (瓶頸層) 的輸出。 Bottleneck 通常是編碼器路徑的最深層,最內層 (Innermost),也是連接編碼器和解碼器的橋樑。

雖然是layer=1, 實際上被呼叫的順序是在最後,

以 512×512 圖片做為輸入, x 的 shape[1] 以序如下:

  • self.layer 8
    x.shape[1] 1
    x.shape[2] 512
    x.shape[3] 512
    tensor([[[[ 1.0000, 1.0000, 1.0000, …, 1.0000, 1.0000, 1.0000],
  • layer 7
    x.shape[0] 8
    x.shape[1] 64
    x.shape[2] 256
    x.shape[3] 256
    tensor([[[[-0.0235, -0.0102, -0.0102, …, -0.0102, -0.0102, 0.0197],
  • layer 6
    x.shape[0] 8
    x.shape[1] 128
    x.shape[2] 128
    x.shape[3] 128
    tensor([[[[-6.3282e-02, 3.2165e-01, 3.2165e-01, …, 3.2165e-01, 3.2165e-01, -9.6392e-02],
  • layer 5
    x.shape[0] 8
    x.shape[1] 256
    x.shape[2] 64
    x.shape[3] 64
    tensor([[[[ 1.7172e+00, 1.0012e+00, 1.0012e+00, …, 1.0012e+00, 1.0012e+00, 1.3477e+00],
  • layer 4
    x.shape[0] 8
    x.shape[1] 512
    x.shape[2] 32
    x.shape[3] 32
    tensor([[[[ 3.5998e-01, 3.3396e-01, 5.3206e-02, …, 2.8828e-02, 2.8828e-02, 2.1134e-01],
  • layer 3
    x.shape[0] 8
    x.shape[1] 512
    x.shape[2] 16
    x.shape[3] 16
    tensor([[[[ 1.0598e+00, -9.8572e-01, -1.5966e-01, …, 1.0778e+00, 2.8010e-01, 3.9897e-01],
  • layer 2
    x.shape[0] 8
    x.shape[1] 512
    x.shape[2] 8
    x.shape[3] 8
    x tensor([[[[ 1.7219e+00, 7.3240e-01, 4.2313e-01, …, 9.2909e-01, 7.7978e-01, 1.8762e-01],
  • .layer 1
    x.shape[0] 8
    x.shape[1] 512
    x.shape[2] 4
    x.shape[3] 4
    x tensor([[[[-5.6431e-01, 5.9121e-01, 8.9240e-01, 2.8132e-01],

以上, x.shape[0] 固定時 batch_size.

以 256×256 圖片做為輸入, x 的 shape[1] 以序如下:

  • layer 8
    x.shape[0] 8
    x.shape[1] 1
    x.shape[2] 256
    x.shape[3] 256
    tensor([[[[1., 1., 1., …, 1., 1., 1.],
  • layer 7
    x.shape[0] 8
    x.shape[1] 64
    x.shape[2] 128
    x.shape[3] 128
    tensor([[[[-0.0285, -0.0182, -0.0182, …, -0.0182, -0.0182, 0.0147],
  • layer 6
    x.shape[0] 8
    x.shape[1] 128
    x.shape[2] 64
    x.shape[3] 64
    tensor([[[[ 0.0416, 0.3736, 0.3736, …, 0.3736, 0.3736, 0.0185],
  • layer 5
    x.shape[0] 8
    x.shape[1] 256
    x.shape[2] 32
    x.shape[3] 32
    tensor([[[[ 3.7615e-01, 3.2315e-02, -3.0772e-01, …, -3.9309e-01,
  • layer 4
    x.shape[0] 8
    x.shape[1] 512
    x.shape[2] 16
    x.shape[3] 16
    x tensor([[[[ 4.4770e-01, -1.2710e+00, -6.8913e-01, …, 2.0007e-01,
  • layer 3
    x.shape[0] 8
    x.shape[1] 512
    x.shape[2] 8
    x.shape[3] 8
    tensor([[[[ 3.6267e-01, -1.2073e+00, -8.4982e-02, …, 2.3464e+00, 1.1156e+00, 8.6918e-02],
  • layer 2
    x.shape[0] 8
    x.shape[1] 512
    x.shape[2] 4
    x.shape[3] 4
    tensor([[[[-1.1430e+00, 2.6779e-01, 6.2793e-01, 3.9685e-01],
  • layer 1
    x.shape[0] 21
    x.shape[1] 512
    x.shape[2] 2
    x.shape[3] 2
    tensor([[[[ 7.4590e-01, 6.0514e-01],


UnetSkipConnectionBlockdef forward(self, x) 函數中的輸入資料 x,其來源取決於 UnetSkipConnectionBlock 在 UNet 架構中所處的位置,以及它是屬於編碼器路徑 (Encoder Path) 還是解碼器路徑 (Decoder Path) 的一部分。

為了更清楚地說明 x 的來源,我們需要先回顧一下 UNet 架構的基本概念,以及 UnetSkipConnectionBlock 在其中的作用:

UNet 架構簡要回顧:

UNet 架構主要由兩個路徑組成,形狀類似英文字母 “U”:

  • 編碼器路徑 (Encoder Path) / 下採樣路徑 (Downsampling Path): 網路的左側部分,負責逐步提取輸入圖像的特徵,並降低解析度。可以想像成一個特徵提取器,將圖像壓縮成更小的、但包含豐富語義資訊的特徵表示。
  • 解碼器路徑 (Decoder Path) / 上採樣路徑 (Upsampling Path): 網路的右側部分,負責逐步將低解析度的特徵圖還原到原始解析度,並生成最終的輸出圖像。可以想像成一個特徵還原器,將壓縮的特徵資訊擴展並轉換回圖像。
  • 跳躍連接 (Skip Connection): 連接編碼器路徑和解碼器路徑的橋樑。從編碼器路徑的某些層提取的特徵,會直接跳躍連接到解碼器路徑中對應的層,幫助解碼器在還原圖像時,能利用到編碼器路徑中更細節的特徵資訊。

UnetSkipConnectionBlock 的作用:

zi2zi-pytorch 中,UnetSkipConnectionBlock 是 UNet 架構的基本建構區塊 (building block)。 它被重複使用在編碼器路徑和解碼器路徑中,用於建構整個 UNet 網路。 每個 UnetSkipConnectionBlock 內部通常包含卷積層、正規化層、激活函數等,負責進行特徵提取和處理。

UnetSkipConnectionBlockdef forward(self, x) 輸入 x 的來源 (根據不同位置):

現在,我們來具體說明 UnetSkipConnectionBlockforward 函數的輸入 x 是從哪裡來的,根據它在 UNet 架構中的位置,可以分為幾種情況:

1. 編碼器路徑 (Encoder Path) 中的 UnetSkipConnectionBlock

在編碼器路徑中,UnetSkipConnectionBlock 通常是 串聯 起來的,形成一個逐步下採樣的過程。

  • 第一個 UnetSkipConnectionBlock (編碼器路徑的起始): 對於編碼器路徑中的 第一個 UnetSkipConnectionBlock,它的 forward 函數的輸入 x 通常是 整個 UNet 網路的原始輸入圖像 (Input Image)
  • 後續的 UnetSkipConnectionBlock (編碼器路徑的中間層): 對於編碼器路徑中 後續的 UnetSkipConnectionBlock (非第一個),它的 forward 函數的輸入 x 來自於 前一個 UnetSkipConnectionBlock 的輸出。 也就是說,編碼器路徑中的資料是逐層向下傳遞的。

簡而言之,在編碼器路徑中,每個 UnetSkipConnectionBlock 的輸入 x,除了第一個 block 是原始輸入圖像外,其餘都是來自於前一個 UnetSkipConnectionBlock 的輸出。

資料流動示意 (編碼器路徑):

Input Image --> UnetSkipConnectionBlock_1 (input: Input Image, output: Feature Map 1)
              --> UnetSkipConnectionBlock_2 (input: Feature Map 1, output: Feature Map 2)
              --> UnetSkipConnectionBlock_3 (input: Feature Map 2, output: Feature Map 3)
              --> ... (更多 UnetSkipConnectionBlock ...)
              --> ... (編碼器路徑的最後一個 UnetSkipConnectionBlock)

2. 解碼器路徑 (Decoder Path) 中的 UnetSkipConnectionBlock

解碼器路徑中的 UnetSkipConnectionBlock 除了接收來自解碼器路徑前一層的輸入外,還會接收來自編碼器路徑的 跳躍連接 (Skip Connection) 的輸入。

  • 解碼器路徑的起始 UnetSkipConnectionBlock (通常在 UNet Bottleneck 之後): 對於解碼器路徑的 第一個UnetSkipConnectionBlock,它的 forward 函數的輸入 x 來自於:
    • 解碼器路徑前一層的輸出: 雖然是解碼器路徑的 “第一個”,但實際上它接收的是 UNet 架構的 Bottleneck 部分 (瓶頸層) 的輸出。 Bottleneck 通常是編碼器路徑的最深層,也是連接編碼器和解碼器的橋樑。
    • 來自編碼器路徑的跳躍連接: 同時,它也會接收來自編碼器路徑中 對應層UnetSkipConnectionBlock跳躍連接 輸出。 在 UNet 中,跳躍連接通常是將編碼器路徑的特徵圖與解碼器路徑的特徵圖 串聯 (Concatenate) 在一起。
  • 後續的 UnetSkipConnectionBlock (解碼器路徑的中間層和輸出層): 對於解碼器路徑中 後續的UnetSkipConnectionBlock (非第一個),它的 forward 函數的輸入 x ebenfalls aus zwei Quellen:
    • 解碼器路徑前一層的輸出: 來自解碼器路徑中 前一個 UnetSkipConnectionBlock 的輸出。
    • 來自編碼器路徑的跳躍連接: 以及來自編碼器路徑中 對應層UnetSkipConnectionBlock跳躍連接 輸出 (串聯在一起)。

簡而言之,在解碼器路徑中,每個 UnetSkipConnectionBlock 的輸入 x 都來自兩個部分: 前一個解碼器 UnetSkipConnectionBlock 的輸出,以及編碼器路徑中對應層的 UnetSkipConnectionBlock 的跳躍連接輸出 (串聯在一起)。

資料流動示意 (解碼器路徑與跳躍連接):

Encoder Path:
UnetSkipConnectionBlock_Encoder_1 --> Feature Map Encoder 1  ------- (Skip Connection) ------>  UnetSkipConnectionBlock_Decoder_1 Input (part 2)
UnetSkipConnectionBlock_Encoder_2 --> Feature Map Encoder 2  ------- (Skip Connection) ------>  UnetSkipConnectionBlock_Decoder_2 Input (part 2)
UnetSkipConnectionBlock_Encoder_3 --> Feature Map Encoder 3  ------- (Skip Connection) ------>  UnetSkipConnectionBlock_Decoder_3 Input (part 2)
...

Decoder Path:
(Bottleneck Output) --> UnetSkipConnectionBlock_Decoder_1 Input (part 1)
                                  UnetSkipConnectionBlock_Decoder_1 (input: [Input Part 1, Input Part 2 (Skip Connection from Encoder 1)], output: Feature Map Decoder 1)
                                  --> UnetSkipConnectionBlock_Decoder_2 Input (part 1)
                                                                   UnetSkipConnectionBlock_Decoder_2 (input: [Input Part 1, Input Part 2 (Skip Connection from Encoder 2)], output: Feature Map Decoder 2)
                                                                   --> UnetSkipConnectionBlock_Decoder_3 Input (part 1)
                                                                                                    UnetSkipConnectionBlock_Decoder_3 (input: [Input Part 1, Input Part 2 (Skip Connection from Encoder 3)], output: Feature Map Decoder 3)
                                                                                                    --> ... (解碼器路徑的後續 UnetSkipConnectionBlock)
                                                                                                    --> Final Output Image (透過解碼器路徑的最後一層)

3. 最內層 (Bottleneck) 的 UnetSkipConnectionBlock (如果存在):

有些 UNet 架構會在編碼器路徑和解碼器路徑之間有一個 最內層 (Innermost) 的 UnetSkipConnectionBlock,也稱為 Bottleneck (瓶頸層)

  • Bottleneck UnetSkipConnectionBlock Bottleneck 層的 UnetSkipConnectionBlockforward 函數的輸入 x 來自於 編碼器路徑中最後一個 UnetSkipConnectionBlock 的輸出。 它的輸出則會直接作為 解碼器路徑中第一個 UnetSkipConnectionBlock 的輸入之一 (解碼器路徑前一層的輸出部分)。 Bottleneck 層本身 不接收 來自跳躍連接的輸入,但它 會提供輸出 作為跳躍連接的一部分,傳遞給解碼器路徑的第一個 UnetSkipConnectionBlock

資料流動示意 (Bottleneck):

Encoder Path (Last Block) --> Bottleneck UnetSkipConnectionBlock Input
                                Bottleneck UnetSkipConnectionBlock (input: Encoder Path Last Block Output, output: Bottleneck Output)
                                --> Decoder Path (First Block) Input (Part 1)

總結:UnetSkipConnectionBlock 輸入 x 的來源

UnetSkipConnectionBlock 位置forward(self, x) 的輸入 x 的來源
編碼器路徑 (Encoder Path)
第一個 BlockUNet 網路的 原始輸入圖像 (Input Image)
後續 Block (非第一個)前一個 UnetSkipConnectionBlock 的輸出
解碼器路徑 (Decoder Path)
第一個 Block (Bottleneck 後)1. Bottleneck 層的輸出 (解碼器路徑前一層的輸出) <br/> 2. 來自 編碼器路徑中對應層 UnetSkipConnectionBlock 的跳躍連接輸出
後續 Block (非第一個, 包括輸出層前的)1. 前一個解碼器路徑 UnetSkipConnectionBlock 的輸出 (解碼器路徑前一層的輸出) <br/> 2. 來自 編碼器路徑中對應層 UnetSkipConnectionBlock 的跳躍連接輸出
Bottleneck 層 (如果存在)編碼器路徑中最後一個 UnetSkipConnectionBlock 的輸出
最外層 (Outermost) Block1. 前一個解碼器路徑 UnetSkipConnectionBlock 的輸出 (解碼器路徑前一層的輸出) <br/> 2. 來自 編碼器路徑中對應層 UnetSkipConnectionBlock 的跳躍連接輸出 (與後續 Block 相同)

更具體的程式碼範例 (概念性):

為了更清楚地理解,以下提供一個非常簡化的 UNet 模型程式碼範例,展示 UnetSkipConnectionBlock 如何連接以及資料如何流動:

Python

import torch
import torch.nn as nn

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_channels, inner_channels, submodule=None): # Simplified for demonstration
        super(UnetSkipConnectionBlock, self).__init__()
        self.conv = nn.Conv2d(inner_channels, outer_channels, kernel_size=3, padding=1)
        self.submodule = submodule # Submodule for deeper UNet

    def forward(self, x):
        if self.submodule is not None:
            x_submodule = self.submodule(x) # 遞迴呼叫 submodule
            x_output = self.conv(x_submodule)
            return torch.cat([x, x_output], dim=1) # Skip connection (concatenate input x and output)
        else: # Innermost block
            return self.conv(x)

class SimpleUNet(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(SimpleUNet, self).__init__()
        # Encoder Path
        self.encoder_block1 = UnetSkipConnectionBlock(64, input_channels) # First encoder block, input is original image
        self.encoder_block2 = UnetSkipConnectionBlock(128, 64, submodule=self.encoder_block1) # Input from previous encoder block
        self.encoder_block3 = UnetSkipConnectionBlock(256, 128, submodule=self.encoder_block2)

        # Bottleneck (Innermost - in this simplified example, encoder_block3 acts as bottleneck too)
        self.bottleneck = self.encoder_block3 # In this example, last encoder block is also the bottleneck

        # Decoder Path
        self.decoder_block1 = UnetSkipConnectionBlock(128, 256, submodule=self.bottleneck) # Input: Bottleneck output, Skip from encoder_block3
        self.decoder_block2 = UnetSkipConnectionBlock(64, 128, submodule=self.decoder_block1) # Input: Previous decoder block, Skip from encoder_block2
        self.decoder_block3 = UnetSkipConnectionBlock(output_channels, 64, submodule=self.decoder_block2) # Input: Previous decoder block, Skip from encoder_block1

        self.output_conv = nn.Conv2d(output_channels, output_channels, kernel_size=1) # Output layer

    def forward(self, x):
        # Encoder Path
        encoder_output1 = self.encoder_block1(x)
        encoder_output2 = self.encoder_block2(x)
        encoder_output3 = self.encoder_block3(x) # encoder_output3 is also the bottleneck output

        # Decoder Path
        decoder_input1 = encoder_output3 # Bottleneck output as first decoder input
        decoder_output1 = self.decoder_block1(decoder_input1) # decoder_block1 input is decoder_input1 + skip from encoder_block3 (in block's forward)

        decoder_input2 = decoder_output1 # Output of decoder_block1 as input to next decoder block
        decoder_output2 = self.decoder_block2(decoder_input2) # decoder_block2 input is decoder_input2 + skip from encoder_block2

        decoder_input3 = decoder_output2 # Output of decoder_block2 as input to next decoder block
        decoder_output3 = self.decoder_block3(decoder_input3) # decoder_block3 input is decoder_input3 + skip from encoder_block1

        final_output = self.output_conv(decoder_output3)
        return final_output


# Example usage
input_channels = 3
output_channels = 3
image_size = 256
batch_size = 1

unet_model = SimpleUNet(input_channels=input_channels, output_channels=output_channels)
dummy_input = torch.randn(batch_size, input_channels, image_size, image_size)
output = unet_model(dummy_input)

print("UNet Output Shape:", output.shape) # Expected output shape: torch.Size([1, 3, 256, 256])

程式碼範例解釋:

  • SimpleUNet 類別: 定義了一個簡化的 UNet 模型。
  • 編碼器路徑: encoder_block1, encoder_block2, encoder_block3 組成編碼器路徑,資料逐層向下傳遞。
  • Bottleneck: encoder_block3 同時作為 Bottleneck 層。
  • 解碼器路徑: decoder_block1, decoder_block2, decoder_block3 組成解碼器路徑,資料逐層向上傳遞,並接收來自編碼器路徑的跳躍連接 (雖然在 SimpleUNetforward 函數中沒有顯式地看到跳躍連接的串聯操作,但 UnetSkipConnectionBlockforward 函數中已經包含了 torch.cat 串聯)。
  • forward 函數: 清晰地展示了資料在編碼器路徑、Bottleneck 層和解碼器路徑之間的流動,以及每個 UnetSkipConnectionBlock 的輸入 x 的來源。

總結:

理解 UnetSkipConnectionBlockforward(self, x) 函數的輸入 x 的來源,需要結合 UNet 架構的整體資料流動來看。 在編碼器路徑中,x 主要來自前一層的輸出;在解碼器路徑中,x 則來自前一層解碼器輸出 以及 編碼器路徑對應層的跳躍連接。 透過這種跳躍連接的結構,UNet 能夠有效地融合多層次的特徵信息,生成更精細、更準確的圖像。 查看 UNet 模型的完整程式碼 (例如 zi2zi-pytorch 專案中的 generators.py 或相關檔案),結合 UNet 架構示意圖,可以更直觀地理解 UnetSkipConnectionBlock 在整個網路中的資料輸入和輸出關係。

About the author

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *