解碼器路徑的 “第一個”(laryer=1),但實際上它接收的是 UNet 架構的 Bottleneck 部分 (瓶頸層) 的輸出。 Bottleneck 通常是編碼器路徑的最深層,最內層 (Innermost),也是連接編碼器和解碼器的橋樑。
雖然是layer=1, 實際上被呼叫的順序是在最後,
以 512×512 圖片做為輸入, x 的 shape[1] 以序如下:
- self.layer 8
x.shape[1] 1
x.shape[2] 512
x.shape[3] 512
tensor([[[[ 1.0000, 1.0000, 1.0000, …, 1.0000, 1.0000, 1.0000], - layer 7
x.shape[0] 8
x.shape[1] 64
x.shape[2] 256
x.shape[3] 256
tensor([[[[-0.0235, -0.0102, -0.0102, …, -0.0102, -0.0102, 0.0197], - layer 6
x.shape[0] 8
x.shape[1] 128
x.shape[2] 128
x.shape[3] 128
tensor([[[[-6.3282e-02, 3.2165e-01, 3.2165e-01, …, 3.2165e-01, 3.2165e-01, -9.6392e-02], - layer 5
x.shape[0] 8
x.shape[1] 256
x.shape[2] 64
x.shape[3] 64
tensor([[[[ 1.7172e+00, 1.0012e+00, 1.0012e+00, …, 1.0012e+00, 1.0012e+00, 1.3477e+00], - layer 4
x.shape[0] 8
x.shape[1] 512
x.shape[2] 32
x.shape[3] 32
tensor([[[[ 3.5998e-01, 3.3396e-01, 5.3206e-02, …, 2.8828e-02, 2.8828e-02, 2.1134e-01], - layer 3
x.shape[0] 8
x.shape[1] 512
x.shape[2] 16
x.shape[3] 16
tensor([[[[ 1.0598e+00, -9.8572e-01, -1.5966e-01, …, 1.0778e+00, 2.8010e-01, 3.9897e-01], - layer 2
x.shape[0] 8
x.shape[1] 512
x.shape[2] 8
x.shape[3] 8
x tensor([[[[ 1.7219e+00, 7.3240e-01, 4.2313e-01, …, 9.2909e-01, 7.7978e-01, 1.8762e-01], - .layer 1
x.shape[0] 8
x.shape[1] 512
x.shape[2] 4
x.shape[3] 4
x tensor([[[[-5.6431e-01, 5.9121e-01, 8.9240e-01, 2.8132e-01],
以上, x.shape[0] 固定時 batch_size.
以 256×256 圖片做為輸入, x 的 shape[1] 以序如下:
- layer 8
x.shape[0] 8
x.shape[1] 1
x.shape[2] 256
x.shape[3] 256
tensor([[[[1., 1., 1., …, 1., 1., 1.], - layer 7
x.shape[0] 8
x.shape[1] 64
x.shape[2] 128
x.shape[3] 128
tensor([[[[-0.0285, -0.0182, -0.0182, …, -0.0182, -0.0182, 0.0147], - layer 6
x.shape[0] 8
x.shape[1] 128
x.shape[2] 64
x.shape[3] 64
tensor([[[[ 0.0416, 0.3736, 0.3736, …, 0.3736, 0.3736, 0.0185], - layer 5
x.shape[0] 8
x.shape[1] 256
x.shape[2] 32
x.shape[3] 32
tensor([[[[ 3.7615e-01, 3.2315e-02, -3.0772e-01, …, -3.9309e-01, - layer 4
x.shape[0] 8
x.shape[1] 512
x.shape[2] 16
x.shape[3] 16
x tensor([[[[ 4.4770e-01, -1.2710e+00, -6.8913e-01, …, 2.0007e-01, - layer 3
x.shape[0] 8
x.shape[1] 512
x.shape[2] 8
x.shape[3] 8
tensor([[[[ 3.6267e-01, -1.2073e+00, -8.4982e-02, …, 2.3464e+00, 1.1156e+00, 8.6918e-02], - layer 2
x.shape[0] 8
x.shape[1] 512
x.shape[2] 4
x.shape[3] 4
tensor([[[[-1.1430e+00, 2.6779e-01, 6.2793e-01, 3.9685e-01], - layer 1
x.shape[0] 21
x.shape[1] 512
x.shape[2] 2
x.shape[3] 2
tensor([[[[ 7.4590e-01, 6.0514e-01],
UnetSkipConnectionBlock
的 def forward(self, x)
函數中的輸入資料 x
,其來源取決於 UnetSkipConnectionBlock
在 UNet 架構中所處的位置,以及它是屬於編碼器路徑 (Encoder Path) 還是解碼器路徑 (Decoder Path) 的一部分。
為了更清楚地說明 x
的來源,我們需要先回顧一下 UNet 架構的基本概念,以及 UnetSkipConnectionBlock
在其中的作用:
UNet 架構簡要回顧:
UNet 架構主要由兩個路徑組成,形狀類似英文字母 “U”:
- 編碼器路徑 (Encoder Path) / 下採樣路徑 (Downsampling Path): 網路的左側部分,負責逐步提取輸入圖像的特徵,並降低解析度。可以想像成一個特徵提取器,將圖像壓縮成更小的、但包含豐富語義資訊的特徵表示。
- 解碼器路徑 (Decoder Path) / 上採樣路徑 (Upsampling Path): 網路的右側部分,負責逐步將低解析度的特徵圖還原到原始解析度,並生成最終的輸出圖像。可以想像成一個特徵還原器,將壓縮的特徵資訊擴展並轉換回圖像。
- 跳躍連接 (Skip Connection): 連接編碼器路徑和解碼器路徑的橋樑。從編碼器路徑的某些層提取的特徵,會直接跳躍連接到解碼器路徑中對應的層,幫助解碼器在還原圖像時,能利用到編碼器路徑中更細節的特徵資訊。
UnetSkipConnectionBlock
的作用:
在 zi2zi-pytorch
中,UnetSkipConnectionBlock
是 UNet 架構的基本建構區塊 (building block)。 它被重複使用在編碼器路徑和解碼器路徑中,用於建構整個 UNet 網路。 每個 UnetSkipConnectionBlock
內部通常包含卷積層、正規化層、激活函數等,負責進行特徵提取和處理。
UnetSkipConnectionBlock
的 def forward(self, x)
輸入 x
的來源 (根據不同位置):
現在,我們來具體說明 UnetSkipConnectionBlock
的 forward
函數的輸入 x
是從哪裡來的,根據它在 UNet 架構中的位置,可以分為幾種情況:
1. 編碼器路徑 (Encoder Path) 中的 UnetSkipConnectionBlock
:
在編碼器路徑中,UnetSkipConnectionBlock
通常是 串聯 起來的,形成一個逐步下採樣的過程。
- 第一個
UnetSkipConnectionBlock
(編碼器路徑的起始): 對於編碼器路徑中的 第一個UnetSkipConnectionBlock
,它的forward
函數的輸入x
通常是 整個 UNet 網路的原始輸入圖像 (Input Image)。 - 後續的
UnetSkipConnectionBlock
(編碼器路徑的中間層): 對於編碼器路徑中 後續的UnetSkipConnectionBlock
(非第一個),它的forward
函數的輸入x
來自於 前一個UnetSkipConnectionBlock
的輸出。 也就是說,編碼器路徑中的資料是逐層向下傳遞的。
簡而言之,在編碼器路徑中,每個 UnetSkipConnectionBlock
的輸入 x
,除了第一個 block 是原始輸入圖像外,其餘都是來自於前一個 UnetSkipConnectionBlock
的輸出。
資料流動示意 (編碼器路徑):
Input Image --> UnetSkipConnectionBlock_1 (input: Input Image, output: Feature Map 1)
--> UnetSkipConnectionBlock_2 (input: Feature Map 1, output: Feature Map 2)
--> UnetSkipConnectionBlock_3 (input: Feature Map 2, output: Feature Map 3)
--> ... (更多 UnetSkipConnectionBlock ...)
--> ... (編碼器路徑的最後一個 UnetSkipConnectionBlock)
2. 解碼器路徑 (Decoder Path) 中的 UnetSkipConnectionBlock
:
解碼器路徑中的 UnetSkipConnectionBlock
除了接收來自解碼器路徑前一層的輸入外,還會接收來自編碼器路徑的 跳躍連接 (Skip Connection) 的輸入。
- 解碼器路徑的起始
UnetSkipConnectionBlock
(通常在 UNet Bottleneck 之後): 對於解碼器路徑的 第一個UnetSkipConnectionBlock
,它的forward
函數的輸入x
來自於:- 解碼器路徑前一層的輸出: 雖然是解碼器路徑的 “第一個”,但實際上它接收的是 UNet 架構的 Bottleneck 部分 (瓶頸層) 的輸出。 Bottleneck 通常是編碼器路徑的最深層,也是連接編碼器和解碼器的橋樑。
- 來自編碼器路徑的跳躍連接: 同時,它也會接收來自編碼器路徑中 對應層 的
UnetSkipConnectionBlock
的 跳躍連接 輸出。 在 UNet 中,跳躍連接通常是將編碼器路徑的特徵圖與解碼器路徑的特徵圖 串聯 (Concatenate) 在一起。
- 後續的
UnetSkipConnectionBlock
(解碼器路徑的中間層和輸出層): 對於解碼器路徑中 後續的UnetSkipConnectionBlock
(非第一個),它的forward
函數的輸入x
ebenfalls aus zwei Quellen:- 解碼器路徑前一層的輸出: 來自解碼器路徑中 前一個
UnetSkipConnectionBlock
的輸出。 - 來自編碼器路徑的跳躍連接: 以及來自編碼器路徑中 對應層 的
UnetSkipConnectionBlock
的 跳躍連接 輸出 (串聯在一起)。
- 解碼器路徑前一層的輸出: 來自解碼器路徑中 前一個
簡而言之,在解碼器路徑中,每個 UnetSkipConnectionBlock
的輸入 x
都來自兩個部分: 前一個解碼器 UnetSkipConnectionBlock
的輸出,以及編碼器路徑中對應層的 UnetSkipConnectionBlock
的跳躍連接輸出 (串聯在一起)。
資料流動示意 (解碼器路徑與跳躍連接):
Encoder Path:
UnetSkipConnectionBlock_Encoder_1 --> Feature Map Encoder 1 ------- (Skip Connection) ------> UnetSkipConnectionBlock_Decoder_1 Input (part 2)
UnetSkipConnectionBlock_Encoder_2 --> Feature Map Encoder 2 ------- (Skip Connection) ------> UnetSkipConnectionBlock_Decoder_2 Input (part 2)
UnetSkipConnectionBlock_Encoder_3 --> Feature Map Encoder 3 ------- (Skip Connection) ------> UnetSkipConnectionBlock_Decoder_3 Input (part 2)
...
Decoder Path:
(Bottleneck Output) --> UnetSkipConnectionBlock_Decoder_1 Input (part 1)
UnetSkipConnectionBlock_Decoder_1 (input: [Input Part 1, Input Part 2 (Skip Connection from Encoder 1)], output: Feature Map Decoder 1)
--> UnetSkipConnectionBlock_Decoder_2 Input (part 1)
UnetSkipConnectionBlock_Decoder_2 (input: [Input Part 1, Input Part 2 (Skip Connection from Encoder 2)], output: Feature Map Decoder 2)
--> UnetSkipConnectionBlock_Decoder_3 Input (part 1)
UnetSkipConnectionBlock_Decoder_3 (input: [Input Part 1, Input Part 2 (Skip Connection from Encoder 3)], output: Feature Map Decoder 3)
--> ... (解碼器路徑的後續 UnetSkipConnectionBlock)
--> Final Output Image (透過解碼器路徑的最後一層)
3. 最內層 (Bottleneck) 的 UnetSkipConnectionBlock
(如果存在):
有些 UNet 架構會在編碼器路徑和解碼器路徑之間有一個 最內層 (Innermost) 的 UnetSkipConnectionBlock
,也稱為 Bottleneck (瓶頸層)。
- Bottleneck
UnetSkipConnectionBlock
: Bottleneck 層的UnetSkipConnectionBlock
的forward
函數的輸入x
來自於 編碼器路徑中最後一個UnetSkipConnectionBlock
的輸出。 它的輸出則會直接作為 解碼器路徑中第一個UnetSkipConnectionBlock
的輸入之一 (解碼器路徑前一層的輸出部分)。 Bottleneck 層本身 不接收 來自跳躍連接的輸入,但它 會提供輸出 作為跳躍連接的一部分,傳遞給解碼器路徑的第一個UnetSkipConnectionBlock
。
資料流動示意 (Bottleneck):
Encoder Path (Last Block) --> Bottleneck UnetSkipConnectionBlock Input
Bottleneck UnetSkipConnectionBlock (input: Encoder Path Last Block Output, output: Bottleneck Output)
--> Decoder Path (First Block) Input (Part 1)
總結:UnetSkipConnectionBlock
輸入 x
的來源
UnetSkipConnectionBlock 位置 | forward(self, x) 的輸入 x 的來源 |
---|---|
編碼器路徑 (Encoder Path) | |
第一個 Block | UNet 網路的 原始輸入圖像 (Input Image) |
後續 Block (非第一個) | 前一個 UnetSkipConnectionBlock 的輸出 |
解碼器路徑 (Decoder Path) | |
第一個 Block (Bottleneck 後) | 1. Bottleneck 層的輸出 (解碼器路徑前一層的輸出) <br/> 2. 來自 編碼器路徑中對應層 UnetSkipConnectionBlock 的跳躍連接輸出 |
後續 Block (非第一個, 包括輸出層前的) | 1. 前一個解碼器路徑 UnetSkipConnectionBlock 的輸出 (解碼器路徑前一層的輸出) <br/> 2. 來自 編碼器路徑中對應層 UnetSkipConnectionBlock 的跳躍連接輸出 |
Bottleneck 層 (如果存在) | 編碼器路徑中最後一個 UnetSkipConnectionBlock 的輸出 |
最外層 (Outermost) Block | 1. 前一個解碼器路徑 UnetSkipConnectionBlock 的輸出 (解碼器路徑前一層的輸出) <br/> 2. 來自 編碼器路徑中對應層 UnetSkipConnectionBlock 的跳躍連接輸出 (與後續 Block 相同) |
更具體的程式碼範例 (概念性):
為了更清楚地理解,以下提供一個非常簡化的 UNet 模型程式碼範例,展示 UnetSkipConnectionBlock
如何連接以及資料如何流動:
Python
import torch
import torch.nn as nn
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_channels, inner_channels, submodule=None): # Simplified for demonstration
super(UnetSkipConnectionBlock, self).__init__()
self.conv = nn.Conv2d(inner_channels, outer_channels, kernel_size=3, padding=1)
self.submodule = submodule # Submodule for deeper UNet
def forward(self, x):
if self.submodule is not None:
x_submodule = self.submodule(x) # 遞迴呼叫 submodule
x_output = self.conv(x_submodule)
return torch.cat([x, x_output], dim=1) # Skip connection (concatenate input x and output)
else: # Innermost block
return self.conv(x)
class SimpleUNet(nn.Module):
def __init__(self, input_channels, output_channels):
super(SimpleUNet, self).__init__()
# Encoder Path
self.encoder_block1 = UnetSkipConnectionBlock(64, input_channels) # First encoder block, input is original image
self.encoder_block2 = UnetSkipConnectionBlock(128, 64, submodule=self.encoder_block1) # Input from previous encoder block
self.encoder_block3 = UnetSkipConnectionBlock(256, 128, submodule=self.encoder_block2)
# Bottleneck (Innermost - in this simplified example, encoder_block3 acts as bottleneck too)
self.bottleneck = self.encoder_block3 # In this example, last encoder block is also the bottleneck
# Decoder Path
self.decoder_block1 = UnetSkipConnectionBlock(128, 256, submodule=self.bottleneck) # Input: Bottleneck output, Skip from encoder_block3
self.decoder_block2 = UnetSkipConnectionBlock(64, 128, submodule=self.decoder_block1) # Input: Previous decoder block, Skip from encoder_block2
self.decoder_block3 = UnetSkipConnectionBlock(output_channels, 64, submodule=self.decoder_block2) # Input: Previous decoder block, Skip from encoder_block1
self.output_conv = nn.Conv2d(output_channels, output_channels, kernel_size=1) # Output layer
def forward(self, x):
# Encoder Path
encoder_output1 = self.encoder_block1(x)
encoder_output2 = self.encoder_block2(x)
encoder_output3 = self.encoder_block3(x) # encoder_output3 is also the bottleneck output
# Decoder Path
decoder_input1 = encoder_output3 # Bottleneck output as first decoder input
decoder_output1 = self.decoder_block1(decoder_input1) # decoder_block1 input is decoder_input1 + skip from encoder_block3 (in block's forward)
decoder_input2 = decoder_output1 # Output of decoder_block1 as input to next decoder block
decoder_output2 = self.decoder_block2(decoder_input2) # decoder_block2 input is decoder_input2 + skip from encoder_block2
decoder_input3 = decoder_output2 # Output of decoder_block2 as input to next decoder block
decoder_output3 = self.decoder_block3(decoder_input3) # decoder_block3 input is decoder_input3 + skip from encoder_block1
final_output = self.output_conv(decoder_output3)
return final_output
# Example usage
input_channels = 3
output_channels = 3
image_size = 256
batch_size = 1
unet_model = SimpleUNet(input_channels=input_channels, output_channels=output_channels)
dummy_input = torch.randn(batch_size, input_channels, image_size, image_size)
output = unet_model(dummy_input)
print("UNet Output Shape:", output.shape) # Expected output shape: torch.Size([1, 3, 256, 256])
程式碼範例解釋:
SimpleUNet
類別: 定義了一個簡化的 UNet 模型。- 編碼器路徑:
encoder_block1
,encoder_block2
,encoder_block3
組成編碼器路徑,資料逐層向下傳遞。 - Bottleneck:
encoder_block3
同時作為 Bottleneck 層。 - 解碼器路徑:
decoder_block1
,decoder_block2
,decoder_block3
組成解碼器路徑,資料逐層向上傳遞,並接收來自編碼器路徑的跳躍連接 (雖然在SimpleUNet
的forward
函數中沒有顯式地看到跳躍連接的串聯操作,但UnetSkipConnectionBlock
的forward
函數中已經包含了torch.cat
串聯)。 forward
函數: 清晰地展示了資料在編碼器路徑、Bottleneck 層和解碼器路徑之間的流動,以及每個UnetSkipConnectionBlock
的輸入x
的來源。
總結:
理解 UnetSkipConnectionBlock
的 forward(self, x)
函數的輸入 x
的來源,需要結合 UNet 架構的整體資料流動來看。 在編碼器路徑中,x
主要來自前一層的輸出;在解碼器路徑中,x
則來自前一層解碼器輸出 以及 編碼器路徑對應層的跳躍連接。 透過這種跳躍連接的結構,UNet 能夠有效地融合多層次的特徵信息,生成更精細、更準確的圖像。 查看 UNet 模型的完整程式碼 (例如 zi2zi-pytorch
專案中的 generators.py
或相關檔案),結合 UNet 架構示意圖,可以更直觀地理解 UnetSkipConnectionBlock
在整個網路中的資料輸入和輸出關係。