Your current code uses strict=False to skip name or shape mismatches, then relies on _initialize_unmatched_weights() to fill in whatever was not loaded. Your requirement, however, is this:
If a weight's shape is unchanged and only its variable name differs, you still want to reuse the previously saved weight.
This can be done by manually pairing weights with matching shapes, i.e. name mapping + shape matching, and then copying the matched weights in. Below is an improved version of _initialize_unmatched_weights() that tries to find parameters with the same shape but a different name and load them accordingly.
✅ Improved _initialize_unmatched_weights implementation:
```python
import torch.nn as nn


def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
    model_state = model.state_dict()
    used_keys = set()

    # Build a reverse lookup: shape -> list of saved parameter names
    shape_to_loaded_keys = {}
    for k, v in loaded_state_dict.items():
        shape_to_loaded_keys.setdefault(v.shape, []).append(k)

    for name, param in model.named_parameters():
        full_name = name
        print(f"  Loading param (name - shape): {model_name}.{full_name} - {param.shape}")
        if full_name in loaded_state_dict and param.shape == loaded_state_dict[full_name].shape:
            # Exact name match: copy directly
            param.data.copy_(loaded_state_dict[full_name])
            used_keys.add(full_name)
        else:
            matched = False
            candidate_keys = shape_to_loaded_keys.get(param.shape, [])
            for candidate in candidate_keys:
                if candidate not in used_keys:
                    print(f"  --> Shape matched. Copying from {candidate}")
                    param.data.copy_(loaded_state_dict[candidate])
                    used_keys.add(candidate)
                    matched = True
                    break
            if not matched:
                print(f"  --> No match found. Re-initializing param: {model_name}.{full_name}")
                if "weight" in full_name and param.dim() >= 2:
                    nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
                elif "weight" in full_name:
                    # Kaiming init needs >= 2 dims; 1-D weights (e.g. norm scales) get 1.0
                    nn.init.constant_(param.data, 1)
                elif "bias" in full_name:
                    nn.init.constant_(param.data, 0)

    for name, buffer in model.named_buffers():
        if name not in loaded_state_dict or model_state[name].shape != loaded_state_dict[name].shape:
            print(f"  Re-initializing buffer (shape mismatch or missing): {model_name}.{name}")
            buffer.data.zero_()
```
🔍 How this logic works:
- Build a reverse lookup from shape to names, so we can quickly find which saved parameters have a given shape.
- If the name matches exactly, use it directly (the normal case).
- If the name differs, look for an as-yet-unused saved parameter with the same shape and take the first one found.
- If nothing matches, re-initialize the parameter.
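To see the shape-based fallback in action, here is a minimal self-contained sketch of the same matching idea. The toy model and the old checkpoint key names (fc.weight, fc.bias) are hypothetical, chosen only to simulate a rename:

```python
import torch
import torch.nn as nn

# Toy "new" model: one linear layer (parameter names: 0.weight, 0.bias)
model = nn.Sequential(nn.Linear(4, 3))

# Simulated "old" checkpoint where the same layer had a different name
old_state = {
    "fc.weight": torch.ones(3, 4),  # same shape as 0.weight
    "fc.bias": torch.zeros(3),      # same shape as 0.bias
}

# Reverse lookup: shape -> list of saved parameter names
shape_to_keys = {}
for k, v in old_state.items():
    shape_to_keys.setdefault(v.shape, []).append(k)

used = set()
for name, param in model.named_parameters():
    # Exact name match first (fails here, since the names changed)
    if name in old_state and old_state[name].shape == param.shape:
        param.data.copy_(old_state[name])
        used.add(name)
        continue
    # Fallback: first unused saved parameter with the same shape
    for candidate in shape_to_keys.get(param.shape, []):
        if candidate not in used:
            param.data.copy_(old_state[candidate])
            used.add(candidate)
            break

# The renamed weights were reused despite the name change
print(torch.equal(model[0].weight.data, torch.ones(3, 4)))  # True
print(torch.equal(model[0].bias.data, torch.zeros(3)))      # True
```

Note that with only shape information, the first match wins; this is exactly why the layer-name check discussed next is useful.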
✅ You still need to make sure that:
- The model architecture has not changed too much (e.g. the layer count is the same), otherwise shape matches can be spurious.
- Parameters with matching shapes really are interchangeable between the old and new models (semantically consistent).
With this in place you can keep using the original weights even after variable renames, preserving as much of the pretrained model's knowledge as possible.
Want me to integrate this into your original load_networks(), or is there a specific model whose name mapping you'd like help analyzing?
OK, here is the full integration, with a layer-name check added to avoid cases where shapes happen to match but the parameters belong to different layers (conv0 and conv5, say, should not swap weights).
✅ Notes:
- Shapes must match.
- The layer names must partially correspond (e.g. down.0.weight → downconv.weight: matching down → downconv is allowed because one name is a prefix of the other).
- A weight is loaded only on a successful match; otherwise the parameter is re-initialized.
- Already integrated into your load_networks().
🔧 Fully integrated version:
```python
import torch
import torch.nn as nn
import os


def extract_layer_name(name):
    """Extract the top-level layer name as the matching key, e.g. down.0.weight -> down."""
    parts = name.split('.')
    return parts[0] if parts else name


def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
    model_state = model.state_dict()
    used_keys = set()

    # Build the shape -> [names] lookup table, plus a name -> layer table
    shape_to_loaded_keys = {}
    name_to_layer = {}
    for k, v in loaded_state_dict.items():
        shape_to_loaded_keys.setdefault(v.shape, []).append(k)
        name_to_layer[k] = extract_layer_name(k)

    for name, param in model.named_parameters():
        full_name = name
        current_layer = extract_layer_name(full_name)
        print(f"  Loading param (name - shape): {model_name}.{full_name} - {param.shape}")
        if full_name in loaded_state_dict and param.shape == loaded_state_dict[full_name].shape:
            # Exact match: load directly
            param.data.copy_(loaded_state_dict[full_name])
            used_keys.add(full_name)
        else:
            matched = False
            candidate_keys = shape_to_loaded_keys.get(param.shape, [])
            for candidate in candidate_keys:
                if candidate in used_keys:
                    continue
                candidate_layer = name_to_layer.get(candidate)
                # Check that the layer names correspond: equal, or one is a
                # prefix of the other (so down -> downconv is allowed)
                if (candidate_layer == current_layer
                        or candidate_layer.startswith(current_layer)
                        or current_layer.startswith(candidate_layer)):
                    print(f"  --> Shape matched & same layer. Copying from {candidate}")
                    param.data.copy_(loaded_state_dict[candidate])
                    used_keys.add(candidate)
                    matched = True
                    break
            if not matched:
                print(f"  --> No match found. Re-initializing param: {model_name}.{full_name}")
                if "weight" in full_name and param.dim() >= 2:
                    nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
                elif "weight" in full_name:
                    # Kaiming init needs >= 2 dims; 1-D weights (e.g. norm scales) get 1.0
                    nn.init.constant_(param.data, 1)
                elif "bias" in full_name:
                    nn.init.constant_(param.data, 0)

    for name, buffer in model.named_buffers():
        if name not in loaded_state_dict or model_state[name].shape != loaded_state_dict[name].shape:
            print(f"  Re-initializing buffer (shape mismatch or missing): {model_name}.{name}")
            buffer.data.zero_()


def load_networks(self, step):
    loaded = False
    target_filepath_G = os.path.join(self.save_dir, f"{step}_net_G.pth")
    target_filepath_D = os.path.join(self.save_dir, f"{step}_net_D.pth")

    # --- Generator ---
    if os.path.exists(target_filepath_G):
        loaded = True
        try:
            state_dict_G = torch.load(target_filepath_G, map_location=self.device)
            self.netG.load_state_dict(state_dict_G, strict=False)
            self._initialize_unmatched_weights(self.netG, state_dict_G, model_name="netG")
        except Exception as e:
            print(f"  Error loading Generator: {e}")
    else:
        print(f"  Generator checkpoint not found: {target_filepath_G}")

    # --- Discriminator ---
    if os.path.exists(target_filepath_D):
        try:
            state_dict_D = torch.load(target_filepath_D, map_location=self.device)
            self.netD.load_state_dict(state_dict_D, strict=False)
            self._initialize_unmatched_weights(self.netD, state_dict_D, model_name="netD")
        except Exception as e:
            print(f"  Error loading Discriminator: {e}")
    else:
        print(f"  Discriminator checkpoint not found: {target_filepath_D}")

    if loaded:
        print(f"  Model {step} loaded successfully")
    return loaded
```
✅ Usage:
Just call load_networks(step) as usual and it will automatically:
- load parameters with identical names
- try to load weights whose shape matches and whose layer name is compatible
- re-initialize any parameter it cannot match
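To see what the layer-name check accepts and rejects, here is a standalone sketch of the prefix logic (the helper layers_compatible is re-defined here purely for illustration; the integrated code inlines the same comparison):

```python
def extract_layer_name(name):
    """Top-level component of a dotted parameter name."""
    return name.split('.')[0]

def layers_compatible(a, b):
    """Layers correspond if one top-level name is a prefix of the other."""
    return a.startswith(b) or b.startswith(a)

print(extract_layer_name("down.0.weight"))    # down
print(extract_layer_name("downconv.weight"))  # downconv
# down vs downconv: compatible, since "down" is a prefix of "downconv"
print(layers_compatible("down", "downconv"))  # True
# down vs up: incompatible, so even a shape match would be rejected
print(layers_compatible("down", "up"))        # False
```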
If you have a more complex renaming scheme (e.g. down.0.conv.weight → encoder.block0.conv.weight), we can also add a mapping dictionary or fuzzy matching; I can extend it for you if you want.
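For the explicit mapping-dictionary route, a minimal sketch of what that would look like; the rename table below uses the down.0.conv.weight → encoder.block0.conv.weight example, and the extra "lr" entry is a hypothetical unmapped key kept to show pass-through:

```python
import torch

# Hypothetical rename table: old checkpoint keys -> new model keys
NAME_MAP = {
    "down.0.conv.weight": "encoder.block0.conv.weight",
    "down.0.conv.bias": "encoder.block0.conv.bias",
}

def remap_state_dict(old_state, name_map):
    """Return a copy of old_state with keys renamed per name_map;
    keys not in the map pass through unchanged."""
    return {name_map.get(k, k): v for k, v in old_state.items()}

old = {
    "down.0.conv.weight": torch.zeros(8, 3, 3, 3),
    "lr": torch.tensor(0.1),  # unmapped key, kept as-is
}
new = remap_state_dict(old, NAME_MAP)
print(sorted(new))  # ['encoder.block0.conv.weight', 'lr']
```

The remapped dict can then be passed to load_state_dict(..., strict=False) before the shape-based fallback runs, so exact renames are handled deterministically.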
Would you like me to wrap this into a utility function or class method for easier reuse later?