Adding a smoothing conv layer after PixelShuffle

Below is the UnetSkipConnectionBlock class with a smoothing conv layer fully integrated after PixelShuffle. All of your original logic is preserved; only the up_mode == 'pixelshuffle' branch is extended:


✅ The fully revised UnetSkipConnectionBlock

import torch
import torch.nn as nn
import torch.nn.functional as F

# TransformerBlock, FiLMModulation, LinearAttention, SelfAttention and ResSkip
# are custom modules assumed to be defined elsewhere in your project.

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=64,
                 use_dropout=False, self_attention=False, attention_type='linear',
                 blur=False, outermost=False, innermost=False, use_transformer=False,
                 attn_layers=None, up_mode='conv', freeze_downsample=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        self.layer = layer
        self.attn_layers = attn_layers or []
        self.up_mode = up_mode
        self.freeze_downsample = freeze_downsample

        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc

        kernel_size = 3 if innermost else 4
        stride = 1 if innermost else 2
        padding = 1

        self.downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=kernel_size, stride=stride, padding=padding, bias=use_bias)
        nn.init.kaiming_normal_(self.downconv.weight, nonlinearity='leaky_relu')

        self.downrelu = nn.SiLU(inplace=True)
        self.downnorm = norm_layer(inner_nc)
        self.uprelu = nn.SiLU(inplace=True)
        self.upnorm = norm_layer(outer_nc)

        # UP path construction
        if outermost:
            in_channels = inner_nc * 2
        elif innermost:
            in_channels = inner_nc
        else:
            in_channels = inner_nc * 2

        if self.up_mode == 'conv':
            self.upconv = nn.ConvTranspose2d(in_channels, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(self.upconv.weight)
        elif self.up_mode == 'upsample':
            self.upconv = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                nn.Conv2d(in_channels, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
                self.upnorm
            )
            nn.init.kaiming_normal_(self.upconv[1].weight)
        elif self.up_mode == 'pixelshuffle':
            self.upconv = nn.Sequential(
                nn.Conv2d(in_channels, outer_nc * 4, kernel_size=3, stride=1, padding=1, bias=use_bias),
                nn.PixelShuffle(2),
                nn.Conv2d(outer_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),  # << smoothing conv
                self.upnorm
            )
            nn.init.kaiming_normal_(self.upconv[0].weight)
            nn.init.kaiming_normal_(self.upconv[2].weight)
        else:
            raise ValueError(f"Unsupported up_mode: {self.up_mode}. Choose 'conv', 'upsample', or 'pixelshuffle'.")

        # Layer logic
        if outermost:
            self.down = nn.Sequential(self.downconv)
            self.up = nn.Sequential(self.uprelu, self.upconv, nn.Tanh())
        elif innermost:
            self.down = nn.Sequential(self.downrelu, self.downconv)
            self.up = nn.Sequential(self.uprelu, self.upconv, self.upnorm)
            if use_transformer:
                self.transformer_block = TransformerBlock(inner_nc)
            self.film = FiLMModulation(inner_nc, embedding_dim)
        else:
            self.down = nn.Sequential(self.downrelu, self.downconv, self.downnorm)
            self.up = nn.Sequential(self.uprelu, self.upconv, self.upnorm)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.3))

        self.submodule = submodule

        if self_attention and self.layer in self.attn_layers:
            self.attn_block = LinearAttention(inner_nc) if attention_type == 'linear' else SelfAttention(inner_nc)
        else:
            self.attn_block = None

        # ResSkip refinement applied to the decoded output in the middle layers
        if not outermost and not innermost and layer in [4, 5, 6, 7]:
            self.res_skip = ResSkip(outer_nc, outer_nc)

        if self.freeze_downsample:
            for param in self.downconv.parameters():
                param.requires_grad = False
            for param in self.downnorm.parameters():
                param.requires_grad = False

    def forward(self, x, style=None):
        if hasattr(self, 'attn_block') and self.attn_block is not None:
            x = self.attn_block(x)
        encoded = self.down(x)
        if self.innermost:
            if hasattr(self, 'transformer_block'):
                encoded = self.transformer_block(encoded)
            if hasattr(self, 'film') and style is not None:
                encoded = self.film(encoded, style)
            decoded = self.up(encoded)
            if decoded.shape[2:] != x.shape[2:]:
                decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
            if hasattr(self, 'res_skip') and self.res_skip is not None:
                decoded = self.res_skip(decoded)
            return torch.cat([x, decoded], 1), encoded.contiguous().view(x.shape[0], -1)
        else:
            sub_output, encoded_real_A = self.submodule(encoded, style)
            decoded = self.up(sub_output)
            if decoded.shape[2:] != x.shape[2:]:
                decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
            if hasattr(self, 'res_skip') and self.res_skip is not None:
                decoded = self.res_skip(decoded)
            if self.outermost:
                return decoded, encoded_real_A
            else:
                return torch.cat([x, decoded], 1), encoded_real_A

That completes the change. You can now use up_mode='pixelshuffle' and get PixelShuffle's efficiency together with a smoother output.
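
To sanity-check the pixelshuffle branch in isolation, here is a minimal sketch (channel counts are placeholders) that builds the same stack as the up_mode == 'pixelshuffle' branch above, with the default InstanceNorm2d, and confirms the spatial size doubles while the channel count comes out as outer_nc:

import torch
import torch.nn as nn

in_channels, outer_nc = 128, 64  # placeholder channel counts

# The conv expands channels 4x, PixelShuffle trades them for 2x spatial resolution,
# and the smoothing conv cleans up checkerboard-like artifacts after the shuffle
up = nn.Sequential(
    nn.Conv2d(in_channels, outer_nc * 4, kernel_size=3, stride=1, padding=1),
    nn.PixelShuffle(2),
    nn.Conv2d(outer_nc, outer_nc, kernel_size=3, stride=1, padding=1),  # smoothing conv
    nn.InstanceNorm2d(outer_nc),
)

x = torch.randn(1, in_channels, 32, 32)
print(up(x).shape)  # torch.Size([1, 64, 64, 64])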

If you'd also like me to update the whole UNetGenerator accordingly, or add a parameter to toggle this smoothing, just say so.


Adding a conv layer (nn.Conv2d(outer_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)) after PixelShuffle means that the weights of that layer and of the layers after it can no longer be carried over directly from the old model that used ConvTranspose2d.

The reasons are as follows:

  1. Layer structure mismatch: you inserted a new convolution layer into the self.upconv Sequential container, which changes the model's layer structure. When you try to load the old model's weights, this extra conv layer in the new model has no counterpart in the old one.
  2. Weight shape mismatch: even if some layers look similar by name (e.g. self.upnorm), the extra conv layer in between changes their order and wiring inside the model, so the stored weights' shapes and meanings no longer line up (the sketch below shows how the state_dict keys shift).
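
A quick way to see the mismatch is to print the state_dict keys of the up path before and after the change. This is a minimal sketch with placeholder channel counts: the old path is a single ConvTranspose2d, the new one is the PixelShuffle branch from above:

import torch.nn as nn

in_channels, outer_nc = 128, 64  # placeholder channel counts

old_up = nn.ConvTranspose2d(in_channels, outer_nc, kernel_size=4, stride=2, padding=1)
new_up = nn.Sequential(
    nn.Conv2d(in_channels, outer_nc * 4, kernel_size=3, padding=1),
    nn.PixelShuffle(2),
    nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=1),  # smoothing conv
)

print(list(old_up.state_dict().keys()))  # ['weight', 'bias']
print(list(new_up.state_dict().keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']
# Neither the keys nor the shapes line up, so load_state_dict(strict=True) fails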

Which weights can still be kept?

In principle, the weights of layers that come before the modification can still be loaded, as long as their names and shapes stay identical between the old and new models.

For example, in the outermost branch of UnetSkipConnectionBlock, the weights of self.downconv can still be loaded if its name and shape have not changed.

In summary:

  • Loading all the weights directly will fail because the model structure has changed. PyTorch's weight loading normally relies on an exact match between the keys (layer names) of the models' state_dict().
  • You need to write custom weight-loading logic that can:
    • Compare layer names between the old and new models.
    • Load weights only for layers whose names and shapes both match.
    • Skip the newly inserted conv layer (the smoothing conv) and the layers after it in that Sequential (e.g. self.upnorm, whose input now comes from the smoothing conv output rather than directly from PixelShuffle) and initialize them fresh.

How to partially keep the weights (custom code required):

You can iterate over the old model's state_dict() and, for each weight, check whether the new model has a layer with a matching name and shape. If it does, copy the old weight into the corresponding layer of the new model. For the newly inserted layers, initialize their weights with one of PyTorch's initializers (e.g. nn.init.kaiming_normal_).

Example (simplified concept):


def load_partial_weights(model, pretrained_dict):
    """Copy over only the pretrained weights whose key and shape match the new model."""
    model_dict = model.state_dict()
    # Keep only entries that exist in the new model with an identical shape
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict and model_dict[k].shape == v.shape}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print(f"Loaded {len(pretrained_dict)}/{len(model_dict)} parameters.")

# Assume old_model is the previous model whose weights you have loaded into old_state_dict,
# and new_model is the model after your modification
# old_state_dict = torch.load('path/to/old_weights.pth')
# new_model = UNetGenerator(...) # your new model

# load_partial_weights(new_model, old_state_dict)

This load_partial_weights function only loads weights whose names and shapes match in both models. You need to make sure that, after loading, the layers in the new model that did not receive weights (including the new smoothing conv and the layers after it) are initialized correctly.

So while adding a conv layer after PixelShuffle breaks direct weight carry-over, custom loading logic lets you keep as many weights as possible for the parts of the model whose structure did not change.


OK, I'll make the smoothing conv a special case: initialize it as a near-identity convolution (1 at the center, 0 elsewhere) and integrate that into your _initialize_unmatched_weights(). That way, when the new model loads the old checkpoint for the first time, the smoothing conv barely changes the output behavior, and fine-tuning can proceed smoothly.


✅ The updated _initialize_unmatched_weights()

def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
    model_state = model.state_dict()

    def init_smoothing_conv_as_identity(conv_param, bias_param=None):
        """將 smoothing conv 初始化為接近 identity(中心為 1,其餘為 0)"""
        if conv_param.shape[2:] == (3, 3) and conv_param.shape[0] == conv_param.shape[1]:
            with torch.no_grad():
                conv_param.zero_()
                c = conv_param.shape[0]
                for i in range(c):
                    conv_param[i, i, 1, 1] = 1.0  # set the kernel center on the channel diagonal to 1
                if bias_param is not None:
                    bias_param.zero_()
            return True
        return False

    for name, param in model.named_parameters():
        need_init = (
            name not in loaded_state_dict or
            model_state[name].shape != loaded_state_dict[name].shape
        )

        if need_init:
            print(f" ⚠️  Re-initializing param: {model_name}.{name}, shape: {param.shape}")

            # Check whether this is the smoothing conv (adjust the keyword to your naming scheme)
            if "smoothing" in name and "weight" in name:
                # Try to find the matching bias
                bias_name = name.replace("weight", "bias")
                bias_param = model_state.get(bias_name, None)
                matched = init_smoothing_conv_as_identity(param, bias_param)
                if matched:
                    print(f" ✅  Initialized {model_name}.{name} as identity smoothing conv")
                    continue  # skip the default initialization

            # Fallback: generic initialization
            if "weight" in name:
                nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
            elif "bias" in name:
                nn.init.constant_(param.data, 0)

    for name, buffer in model.named_buffers():
        if name not in loaded_state_dict or model_state[name].shape != loaded_state_dict[name].shape:
            print(f" ⚠️  Re-initializing buffer: {model_name}.{name}, shape: {buffer.shape}")
            buffer.data.zero_()
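
As a quick check that this identity initialization really is behavior-preserving, here is a small sketch (the channel count is a placeholder) showing that a 3x3 conv initialized with the same pattern returns its input unchanged:

import torch
import torch.nn as nn

c = 64  # placeholder channel count
conv = nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1, bias=True)

with torch.no_grad():
    conv.weight.zero_()
    for i in range(c):
        conv.weight[i, i, 1, 1] = 1.0  # same pattern as init_smoothing_conv_as_identity
    conv.bias.zero_()

x = torch.randn(2, c, 32, 32)
print(torch.allclose(conv(x), x))  # True: the layer starts out as a no-op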

💡 Notes

This code automatically detects conv layers whose names contain the keyword smoothing, for example:

self.smoothing = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)

If you wrap it inside the nn.Sequential in upconv, you can use add_module("smoothing", conv) to keep the naming explicit (so the state_dict key contains smoothing and is easy to detect).
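
If you go that route, a sketch of the branch built with add_module might look like the following; the names "expand" and "shuffle" are just illustrative, only the "smoothing" key matters for the keyword check in _initialize_unmatched_weights, and the channel counts are placeholders:

import torch.nn as nn

in_channels, outer_nc, use_bias = 128, 64, True  # placeholder values

upconv = nn.Sequential()
upconv.add_module("expand", nn.Conv2d(in_channels, outer_nc * 4, kernel_size=3, stride=1, padding=1, bias=use_bias))
upconv.add_module("shuffle", nn.PixelShuffle(2))
upconv.add_module("smoothing", nn.Conv2d(outer_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias))

print(list(upconv.state_dict().keys()))
# ['expand.weight', 'expand.bias', 'smoothing.weight', 'smoothing.bias']
# When this is registered as self.upconv, the key becomes e.g. "upconv.smoothing.weight"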


If you'd like me to also refactor UnetSkipConnectionBlock so the smoothing conv is wrapped up cleanly and named explicitly, just let me know.

If the change to the model is just adding a single layer, you can simply write it like this and the old weights migrate cleanly to the new model:

    def extract_keywords(self, name):
        KEYWORD_MATCH_RULES = ["down", "up", "conv", "res", "encoder", "decoder", "self", "line"]
        return set([k for k in KEYWORD_MATCH_RULES if k in name])

    def extract_layer_name(self, name):
        parts = name.split('.')
        if parts:
            return parts[0]
        return name

    def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
        model_state = model.state_dict()
        used_keys = set()

        shape_to_loaded_keys = {}
        name_to_layer = {}
        name_to_keywords = {}

        for k, v in loaded_state_dict.items():
            shape_to_loaded_keys.setdefault(v.shape, []).append(k)
            name_to_layer[k] = self.extract_layer_name(k)
            name_to_keywords[k] = self.extract_keywords(k)

        for name, param in model.named_parameters():
            full_name = name
            current_layer = self.extract_layer_name(full_name)
            current_keywords = self.extract_keywords(full_name)

            #print(f" Loading param (name - shape): {model_name}.{full_name} - {param.shape}")

            if full_name in loaded_state_dict and param.shape == loaded_state_dict[full_name].shape:
                param.data.copy_(loaded_state_dict[full_name])
                used_keys.add(full_name)
            else:
                matched = False
                candidate_keys = shape_to_loaded_keys.get(param.shape, [])
                for candidate in candidate_keys:
                    if candidate in used_keys:
                        continue

                    candidate_layer = name_to_layer.get(candidate)
                    candidate_keywords = name_to_keywords.get(candidate, set())

                    # Top-level layer name must match and the semantic keywords must overlap
                    if candidate_layer == current_layer and current_keywords & candidate_keywords:
                        print(f" Loading param (name - shape): {model_name}.{full_name} - {param.shape}")
                        print(f"  --> Shape & layer & keyword match. Copying from {candidate}")
                        param.data.copy_(loaded_state_dict[candidate])
                        used_keys.add(candidate)
                        matched = True
                        break

                if not matched:
                    print(f" Loading param (name - shape): {model_name}.{full_name} - {param.shape}")
                    print(f"  --> No suitable match found. Re-initializing param: {model_name}.{full_name}")
                    if "weight" in full_name:
                        # Temporary handling for a model change that only adds one conv layer
                        init_smoothing_conv = True
                        if init_smoothing_conv:
                            # Try to find the matching bias
                            bias_name = name.replace("weight", "bias")
                            bias_param = model_state.get(bias_name, None)
                            matched = self.init_smoothing_conv_as_identity(param, bias_param)
                            if matched:
                                print(f" ✅  Initialized {model_name}.{name} as identity smoothing conv")
                                continue  # skip the default initialization

                        nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
                    elif "bias" in full_name:
                        nn.init.constant_(param.data, 0)

        for name, buffer in model.named_buffers():
            if name not in loaded_state_dict or model_state[name].shape != loaded_state_dict[name].shape:
                print(f" Re-initializing buffer (shape mismatch or missing): {model_name}.{name}")
                buffer.data.zero_()

    def init_smoothing_conv_as_identity(self, conv_param, bias_param=None):
        """將 smoothing conv 初始化為接近 identity(中心為 1,其餘為 0)"""
        if conv_param.shape[2:] == (3, 3) and conv_param.shape[0] == conv_param.shape[1]:
            with torch.no_grad():
                conv_param.zero_()
                c = conv_param.shape[0]
                for i in range(c):
                    conv_param[i, i, 1, 1] = 1.0  # set the kernel center on the channel diagonal to 1
                if bias_param is not None:
                    bias_param.zero_()
            return True
        return False
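
Putting it together, a usage sketch would look roughly like this; the checkpoint path, the UNetGenerator arguments, and the `trainer` object that owns these helper methods are all placeholders for whatever your project actually uses:

import torch

# Placeholder path and constructor; adapt to your project
old_state_dict = torch.load('path/to/old_checkpoint.pth', map_location='cpu')
new_model = UNetGenerator(up_mode='pixelshuffle')

# Exact name/shape matches are copied first; the rest are matched by
# shape + top-level layer name + keyword; anything still unmatched
# (i.e. the new smoothing conv) falls back to the identity initialization
trainer._initialize_unmatched_weights(new_model, old_state_dict, model_name="G")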
