1:
class ResSkip(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
nn.init.kaiming_normal_(self.conv.weight)
self.relu = nn.ReLU(inplace=True) # 您可以根據需要調整激活函數
def forward(self, x):
residual = x
out = self.conv(x)
out = self.relu(out)
out += residual
return out
2:
class ResSkip(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.depthwise = nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels)
self.pointwise = nn.Conv2d(in_channels, out_channels, 1)
self.norm = nn.InstanceNorm2d(out_channels)
self.act = nn.SiLU()
self.skip = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
def forward(self, x):
identity = self.skip(x)
out = self.depthwise(x)
out = self.pointwise(out)
out = self.act(self.norm(out))
return out + identity
這兩個 ResSkip
class 在設計理念和內部結構上存在顯著的差異,主要體現在它們如何處理輸入並引入殘差連接。以下是詳細的比較:
第一個 ResSkip
class:
- 結構:
- 包含一個標準的二維卷積層 (
nn.Conv2d
),kernel size 為 3,stride 為 1,padding 為 1。 - 使用 Kaiming Normal 初始化卷積層的權重。
- 在卷積操作後應用 ReLU 激活函數 (
nn.ReLU
). - 殘差連接: 直接將原始輸入
x
加到經過卷積和激活後的輸出上 (out += residual
)。這是一種最基本的殘差連接形式。
- 包含一個標準的二維卷積層 (
- 運作方式: 這個 block 的主要操作是對輸入進行一次 3×3 的卷積,然後將結果與原始輸入相加。它旨在學習輸入的殘差(即輸入與期望輸出的差異),並將其添加到原始輸入中,從而更容易訓練更深的網路。
- 參數效率: 相對簡單,參數數量取決於輸入和輸出通道數以及卷積核的大小。
- 設計目標: 專注於通過一個卷積層學習殘差特徵。
第二個 ResSkip
class:
- 結構:
- 使用深度可分離卷積 (Depthwise Separable Convolution):
depthwise
: 對每個輸入通道獨立進行 3×3 卷積 (groups=in_channels
)。pointwise
: 使用 1×1 卷積 (nn.Conv2d
) 混合不同通道的特徵。
- 包含一個
nn.InstanceNorm2d
進行實例歸一化。 - 使用 SiLU 激活函數 (
nn.SiLU
). - 跳躍連接 (
self.skip
): 使用一個 1×1 的卷積層 (nn.Conv2d
) 來調整輸入的通道數,使其與 block 的輸出通道數相匹配。如果輸入和輸出通道數相同,則使用nn.Identity()
,即直接跳過。 - 殘差連接: 將經過深度可分離卷積、歸一化和激活後的輸出與經過
self.skip
處理的輸入相加 (return out + identity
)。
- 使用深度可分離卷積 (Depthwise Separable Convolution):
- 運作方式: 這個 block 使用更複雜的深度可分離卷積來提取特徵,這通常比標準卷積更有效率。它還包含了歸一化層和不同的激活函數。跳躍連接
self.skip
的作用是確保殘差連接可以正確地進行元素級相加,即使輸入和輸出的通道數不同。 - 參數效率: 由於使用了深度可分離卷積,通常比具有相同輸入和輸出通道數的標準卷積層擁有更少的參數。
- 設計目標: 更注重參數效率和更精細的特徵提取,同時確保在通道數變化時也能進行有效的殘差連接。
總結差異:
特徵 | 第一個 ResSkip | 第二個 ResSkip |
卷積類型 | 標準 nn.Conv2d (3×3) | 深度可分離卷積 (depthwise 3×3 + pointwise 1×1) |
歸一化 | 無 | nn.InstanceNorm2d |
激活函數 | ReLU | SiLU |
跳躍連接 | 直接相加原始輸入 (通道數必須相同) | 使用 1×1 卷積調整通道數 (self.skip ),不同通道數也可相加 |
參數效率 | 相對較低 (與標準卷積相同) | 通常更高 (深度可分離卷積的優勢) |
複雜度 | 較簡單 | 較複雜 |
設計重點 | 學習殘差特徵 | 參數效率、精細特徵提取、處理通道數變化 |
選擇哪個 ResSkip
取決於您的具體需求:
- 如果您更注重模型的簡單性和快速實現,並且輸入和輸出通道數在殘差 block 中保持一致,那麼第一個
ResSkip
可能更適合。 - 如果您對模型的參數效率有更高的要求,或者需要在殘差 block 中改變通道數,並且希望使用更現代的卷積方式和歸一化技術,那麼第二個
ResSkip
可能會更好。
第二個 ResSkip
的設計更符合一些現代輕量級網路的風格,通過深度可分離卷積減少計算量和參數量,並通過 1×1 卷積和跳躍連接增強模型的靈活性。