將 PixelShuffleUpBlock 中的 conv 替換為 ResSkip

希望將 PixelShuffleUpBlock 中的 self.conv_out 替換為 self.residual = ResSkip(out_channels, out_channels),並且嘗試保留 self.conv_out 中已訓練的權重。

由於 ResSkip 通常的結構是將輸入與經過一個或多個卷積層處理後的輸出相加,它的權重結構與一個單純的 nn.Conv2d (您的 self.conv_out) 是不同的,因此無法直接將 self.conv_out 的權重完全轉移到 ResSkip 的卷積層中並保持其原有的功能。

然而,您可以嘗試將 self.conv_out 的權重作為 ResSkip 中第一個卷積層的初始權重,這或許能在一定程度上保留一些已學習到的特徵。

以下是修改後的 PixelShuffleUpBlock class:

Python

import torch
import torch.nn as nn

class ResSkip(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        nn.init.kaiming_normal_(self.conv.weight)
        self.relu = nn.ReLU(inplace=True) # 您可以根據需要調整激活函數

    def forward(self, x):
        residual = x
        out = self.conv(x)
        out = self.relu(out)
        out += residual
        return out

class PixelShuffleUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer=nn.InstanceNorm2d, use_bias=False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        nn.init.kaiming_normal_(self.conv.weight)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.residual = ResSkip(out_channels, out_channels)
        self.norm = norm_layer(out_channels)
        self.conv_out_weight = None # 用於暫存 conv_out 的權重

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        if self.conv_out_weight is not None:
            # 手動將之前 conv_out 的權重載入到 residual 的第一個卷積層
            with torch.no_grad():
                if self.residual.conv.weight.shape == self.conv_out_weight.shape:
                    self.residual.conv.weight.data.copy_(self.conv_out_weight.data)
                    self.conv_out_weight = None # 權重已載入,清空暫存
                else:
                    print("警告:residual 的卷積層權重形狀與原始 conv_out 不匹配,無法載入權重。")
        x = self.residual(x)
        x = self.norm(x)
        return x

    def load_conv_out_weight(self, weight):
        """用於載入原始 conv_out 的權重"""
        self.conv_out_weight = weight

UnetSkipConnectionBlock 中的修改:

您需要修改 UnetSkipConnectionBlock 中創建 PixelShuffleUpBlock 的部分,使其不再包含 conv_out,並在模型加載權重後,將原始 conv_out 的權重載入到 PixelShuffleUpBlockresidual 層中。

Python

class UnetSkipConnectionBlock(nn.Module):
    # ... (__init__ 方法保持不變,但創建 pixelshuffle 部分需要修改)

    # === outermost 層 ===
    if outermost:
        # ...
        elif self.up_mode == 'pixelshuffle':
            upconv = PixelShuffleUpBlock(inner_nc * 2, outer_nc, norm_layer, use_bias)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        # ...

    # === innermost 層 ===
    elif innermost:
        # ...
        elif self.up_mode == 'pixelshuffle':
            upconv = PixelShuffleUpBlock(inner_nc, outer_nc, norm_layer, use_bias)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        # ...

    # === 中間層 ===
    else:
        # ...
        elif self.up_mode == 'pixelshuffle':
            upconv = PixelShuffleUpBlock(inner_nc * 2, outer_nc, norm_layer, use_bias)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        # ...

載入權重的步驟:

  1. 加載預訓練模型的權重: 使用 torch.load() 加載您之前訓練好的模型權重。
  2. 找到原始 conv_out 的權重: 遍歷模型的 state_dict(),找到對應於 PixelShuffleUpBlockconv_out 層的權重。這通常會是一個類似於 module.up.1.conv_out.weight 的鍵(具體路徑取決於您的模型結構)。
  3. 將權重載入到新的 PixelShuffleUpBlock 中: 找到新的 PixelShuffleUpBlock 實例,並調用其 load_conv_out_weight() 方法,將之前找到的權重傳遞給它。您需要在模型加載權重後進行這個操作。

Python

# 假設您已經加載了預訓練模型的權重到 `pretrained_dict`
pretrained_dict = torch.load('your_pretrained_model.pth')
model = YourUnetModel(...) # 創建您的 U-Net 模型

# 獲取模型當前的 state_dict
model_dict = model.state_dict()

# 更新模型權重
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

# 找到 PixelShuffleUpBlock 的實例並載入權重
for name, module in model.named_modules():
    if isinstance(module, PixelShuffleUpBlock):
        # 假設原始 conv_out 的權重路徑是 'module.up.1.conv_out.weight'
        # 您需要根據您的實際模型結構調整這個路徑
        conv_out_weight_key = name + '.conv_out.weight'
        if conv_out_weight_key in pretrained_dict:
            module.load_conv_out_weight(pretrained_dict[conv_out_weight_key])
            print(f"成功載入 {name} 的 conv_out 權重到 residual 層。")
        else:
            print(f"警告:在預訓練權重中找不到 {conv_out_weight_key}。")

# 現在您的模型中,PixelShuffleUpBlock 的 residual 層已經嘗試載入了原始 conv_out 的權重

重要注意事項:

  • 權重形狀匹配: 上述程式碼中包含了對權重形狀的檢查。如果原始 conv_out 的權重形狀與 ResSkip 中第一個卷積層的權重形狀不匹配(這在通道數不同的情況下很可能發生),則無法直接載入。
  • 功能差異: 即使權重形狀匹配並成功載入,ResSkip 的前向傳播與單純的 nn.Conv2d 是不同的(它會加上殘差連接和激活函數)。因此,替換後的行為可能與原始的 conv_out 不完全一致,您可能需要進行微調以獲得最佳結果。
  • 激活函數: ResSkip 中我添加了 ReLU 激活函數。如果原始 conv_out 後面沒有激活函數,或者使用了不同的激活函數,您可能需要調整 ResSkip 中的激活函數以更接近原始行為。

總之,雖然無法完美地將 conv_out 的權重轉移到 ResSkip 並保持完全相同的推論結果,但將其權重作為 ResSkip 中卷積層的初始權重是一個值得嘗試的方法。之後,您可能需要根據實際的推論結果進行微調。

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *