initialize_unmatched_weights: reusing weights with identical shapes but renamed variables

Your current code uses strict=False to skip entries whose names or shapes don't match, and relies on _initialize_unmatched_weights() to fill in whatever wasn't loaded. Your requirement, however, is this:

If a weight has the same shape and only its variable name changed, you still want to reuse the previously saved weight.

This can be done with name mapping + shape matching: manually pair up weights whose shapes agree, then load the corresponding values. Below is an improved _initialize_unmatched_weights() that tries to find parameters with the same shape but a different name and loads them accordingly.


✅ Improved _initialize_unmatched_weights implementation:

def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
    model_state = model.state_dict()
    used_keys = set()

    # Build a reverse lookup: shape -> checkpoint keys with that shape
    shape_to_loaded_keys = {}
    for k, v in loaded_state_dict.items():
        shape_to_loaded_keys.setdefault(v.shape, []).append(k)

    for name, param in model.named_parameters():
        full_name = name
        print(f" Loading param (name - shape): {model_name}.{full_name} - {param.shape}")

        if full_name in loaded_state_dict and param.shape == loaded_state_dict[full_name].shape:
            # Exact name match: load directly
            param.data.copy_(loaded_state_dict[full_name])
            used_keys.add(full_name)
        else:
            matched = False
            candidate_keys = shape_to_loaded_keys.get(param.shape, [])
            for candidate in candidate_keys:
                if candidate not in used_keys:
                    print(f"  --> Shape matched. Copying from {candidate}")
                    param.data.copy_(loaded_state_dict[candidate])
                    used_keys.add(candidate)
                    matched = True
                    break
            if not matched:
                print(f"  --> No match found. Re-initializing param: {model_name}.{full_name}")
                if "weight" in full_name and param.dim() >= 2:
                    nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
                elif "bias" in full_name:
                    nn.init.constant_(param.data, 0)
                elif "weight" in full_name:
                    # kaiming_normal_ requires >= 2 dims; 1-D weights (e.g. norm layers) get constant init
                    nn.init.constant_(param.data, 1)

    for name, buffer in model.named_buffers():
        if name not in loaded_state_dict or model_state[name].shape != loaded_state_dict[name].shape:
            print(f" Re-initializing buffer (shape mismatch or missing): {model_name}.{name}")
            buffer.data.zero_()

🔍 How this logic works:

  1. Build a shape-to-name reverse lookup, making it easy to find old-model parameters with a matching shape.
  2. If the name matches exactly, load it directly (the normal case).
  3. If the name differs, scan the as-yet-unused checkpoint entries with the same shape and take the first one.
  4. If nothing matches, re-initialize the parameter.
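The steps above can be sketched with a minimal, self-contained example. OldNet and NewNet here are hypothetical toy modules (not from your code): the layer fc was renamed to classifier, and the shape lookup lets the new model pick up the old weights anyway.

```python
import torch
import torch.nn as nn

class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(4, 2)  # same shape, new name

old_state = OldNet().state_dict()
new_model = NewNet()

# Build the shape -> [keys] reverse lookup, as in _initialize_unmatched_weights
shape_to_keys = {}
for k, v in old_state.items():
    shape_to_keys.setdefault(tuple(v.shape), []).append(k)

used = set()
for name, param in new_model.named_parameters():
    if name in old_state and old_state[name].shape == param.shape:
        # Exact name match
        param.data.copy_(old_state[name])
        used.add(name)
    else:
        # Fall back to the first unused checkpoint key with the same shape
        for cand in shape_to_keys.get(tuple(param.shape), []):
            if cand not in used:
                param.data.copy_(old_state[cand])
                used.add(cand)
                break

# classifier.* now holds the values saved under fc.*
```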

✅ You still need to make sure that:

  • The model architecture hasn't changed too much (e.g. the layer count is the same); otherwise shape matching can pair up unrelated parameters.
  • Parameters whose shapes agree across the old and new models are genuinely interchangeable (semantically consistent).

With this in place you can keep using the original weights even after variable renames, preserving as much of the pretrained model's knowledge as possible.


Want me to integrate this into your existing load_networks()? Or is there a specific model you'd like help with for the name mapping?


OK, here is the full integration. It also adds a layer-name check so that weights whose shapes happen to match but that belong to different layers (e.g. conv0 and conv5) are not swapped.

✅ Summary:

  • Parameter shapes must match.
  • The top-level layer names must also line up: extract_layer_name keeps the first dotted segment, so down.0.weight maps to down and downconv.weight maps to downconv. Note that the check below uses exact equality, so down and downconv would not match; use a prefix test if you want looser matching.
  • Only successful matches are loaded; everything else is re-initialized.
  • Already integrated into your load_networks().

🔧 Full integrated version:

import torch
import torch.nn as nn
import os

def extract_layer_name(name):
    """Extract the top-level layer name for matching, e.g. down.0.weight -> down"""
    parts = name.split('.')
    if parts:
        return parts[0]
    return name

def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
    model_state = model.state_dict()
    used_keys = set()

    # Build the shape -> [names] lookup and a name -> layer map
    shape_to_loaded_keys = {}
    name_to_layer = {}

    for k, v in loaded_state_dict.items():
        shape_to_loaded_keys.setdefault(v.shape, []).append(k)
        name_to_layer[k] = extract_layer_name(k)

    for name, param in model.named_parameters():
        full_name = name
        current_layer = extract_layer_name(full_name)

        print(f" Loading param (name - shape): {model_name}.{full_name} - {param.shape}")

        if full_name in loaded_state_dict and param.shape == loaded_state_dict[full_name].shape:
            # Exact match: load directly
            param.data.copy_(loaded_state_dict[full_name])
            used_keys.add(full_name)
        else:
            matched = False
            candidate_keys = shape_to_loaded_keys.get(param.shape, [])
            for candidate in candidate_keys:
                if candidate in used_keys:
                    continue
                candidate_layer = name_to_layer.get(candidate)
                # Require the same top-level layer
                if candidate_layer == current_layer:
                    print(f"  --> Shape matched & same layer. Copying from {candidate}")
                    param.data.copy_(loaded_state_dict[candidate])
                    used_keys.add(candidate)
                    matched = True
                    break
            if not matched:
                print(f"  --> No match found. Re-initializing param: {model_name}.{full_name}")
                if "weight" in full_name and param.dim() >= 2:
                    nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
                elif "bias" in full_name:
                    nn.init.constant_(param.data, 0)
                elif "weight" in full_name:
                    # kaiming_normal_ requires >= 2 dims; 1-D weights (e.g. norm layers) get constant init
                    nn.init.constant_(param.data, 1)

    for name, buffer in model.named_buffers():
        if name not in loaded_state_dict or model_state[name].shape != loaded_state_dict[name].shape:
            print(f" Re-initializing buffer (shape mismatch or missing): {model_name}.{name}")
            buffer.data.zero_()


def load_networks(self, step):
    loaded = False
    target_filepath_G = os.path.join(self.save_dir, f"{step}_net_G.pth")
    target_filepath_D = os.path.join(self.save_dir, f"{step}_net_D.pth")

    # --- Generator ---
    if os.path.exists(target_filepath_G):
        loaded = True
        try:
            state_dict_G = torch.load(target_filepath_G, map_location=self.device)
            self.netG.load_state_dict(state_dict_G, strict=False)
            self._initialize_unmatched_weights(self.netG, state_dict_G, model_name="netG")
        except Exception as e:
            print(f" Error loading Generator: {e}")
    else:
        print(f" Generator checkpoint not found: {target_filepath_G}")

    # --- Discriminator ---
    if os.path.exists(target_filepath_D):
        try:
            state_dict_D = torch.load(target_filepath_D, map_location=self.device)
            self.netD.load_state_dict(state_dict_D, strict=False)
            self._initialize_unmatched_weights(self.netD, state_dict_D, model_name="netD")
        except Exception as e:
            print(f" Error loading Discriminator: {e}")
    else:
        print(f" Discriminator checkpoint not found: {target_filepath_D}")

    if loaded:
        print(f" Model {step} loaded successfully")
    return loaded

✅ Usage:

Just call load_networks(step) as usual, and it will automatically:

  • Load parameters whose names match
  • Try to load weights with a matching shape and the same layer
  • Re-initialize any parameter with no match
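As a quick sanity check of the layer guard (the names here are hypothetical), the sketch below shows that a shape match alone is no longer enough: a conv0 parameter would not receive conv5's weights, even though their shapes agree.

```python
import torch

def extract_layer_name(name):
    # Same rule as in the integrated version: keep the first dotted segment
    return name.split('.')[0]

# Checkpoint contains only a conv5 weight; the new model asks for conv0
loaded = {"conv5.weight": torch.ones(8, 3, 3, 3)}
current_name = "conv0.weight"
param_shape = torch.Size([8, 3, 3, 3])

# Shape-only matching would accept conv5.weight...
candidates = [k for k, v in loaded.items() if v.shape == param_shape]
# ...but the layer check filters it out, so conv0 gets re-initialized
same_layer = [k for k in candidates
              if extract_layer_name(k) == extract_layer_name(current_name)]
```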

If you have a more complex renaming scheme (e.g. down.0.conv.weight → encoder.block0.conv.weight), we can also add a mapping dictionary or fuzzy matching; I can extend it for you if you'd like.
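For reference, such a mapping dictionary can be as simple as rewriting the checkpoint keys before calling load_state_dict. This is a minimal sketch; remap_state_dict and the prefix map are hypothetical names, not part of your code.

```python
import torch

def remap_state_dict(old_state, key_map):
    """Rename keys per key_map (old prefix -> new prefix); pass others through."""
    new_state = {}
    for k, v in old_state.items():
        for old_prefix, new_prefix in key_map.items():
            if k.startswith(old_prefix):
                k = new_prefix + k[len(old_prefix):]
                break
        new_state[k] = v
    return new_state

# Example: the encoder blocks were renamed between checkpoints
old_state = {"down.0.conv.weight": torch.zeros(4, 3, 3, 3)}
key_map = {"down.0.": "encoder.block0."}
remapped = remap_state_dict(old_state, key_map)
# remapped can now be passed to model.load_state_dict(..., strict=False)
```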

Would you like me to wrap this into a utility function or class method for easier reuse?
