Below is the UnetSkipConnectionBlock class with a smoothing conv layer fully integrated after PixelShuffle. All of your original logic is preserved; only the up_mode == 'pixelshuffle' branch is extended:
✅ Fully revised UnetSkipConnectionBlock
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=64,
use_dropout=False, self_attention=False, attention_type='linear',
blur=False, outermost=False, innermost=False, use_transformer=False,
attn_layers=None, up_mode='conv', freeze_downsample=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
self.innermost = innermost
self.layer = layer
self.attn_layers = attn_layers or []
self.up_mode = up_mode
self.freeze_downsample = freeze_downsample
use_bias = norm_layer != nn.BatchNorm2d
if input_nc is None:
input_nc = outer_nc
kernel_size = 3 if innermost else 4
stride = 1 if innermost else 2
padding = 1
self.downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=kernel_size, stride=stride, padding=padding, bias=use_bias)
nn.init.kaiming_normal_(self.downconv.weight, nonlinearity='leaky_relu')
self.downrelu = nn.SiLU(inplace=True)
self.downnorm = norm_layer(inner_nc)
self.uprelu = nn.SiLU(inplace=True)
self.upnorm = norm_layer(outer_nc)
# UP path construction
if outermost:
in_channels = inner_nc * 2
elif innermost:
in_channels = inner_nc
else:
in_channels = inner_nc * 2
if self.up_mode == 'conv':
self.upconv = nn.ConvTranspose2d(in_channels, outer_nc, kernel_size=4, stride=2, padding=1, output_padding=1, bias=use_bias)
nn.init.kaiming_normal_(self.upconv.weight)
elif self.up_mode == 'upsample':
self.upconv = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(in_channels, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
self.upnorm
)
nn.init.kaiming_normal_(self.upconv[1].weight)
elif self.up_mode == 'pixelshuffle':
self.upconv = nn.Sequential(
nn.Conv2d(in_channels, outer_nc * 4, kernel_size=3, stride=1, padding=1, bias=use_bias),
nn.PixelShuffle(2),
nn.Conv2d(outer_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias), # << smoothing conv
self.upnorm
)
nn.init.kaiming_normal_(self.upconv[0].weight)
nn.init.kaiming_normal_(self.upconv[2].weight)
else:
raise ValueError(f"Unsupported up_mode: {self.up_mode}. Choose 'conv', 'upsample', or 'pixelshuffle'.")
# Layer logic
if outermost:
self.down = nn.Sequential(self.downconv)
self.up = nn.Sequential(self.uprelu, self.upconv, nn.Tanh())
elif innermost:
self.down = nn.Sequential(self.downrelu, self.downconv)
self.up = nn.Sequential(self.uprelu, self.upconv, self.upnorm)
if use_transformer:
self.transformer_block = TransformerBlock(inner_nc)
self.film = FiLMModulation(inner_nc, embedding_dim)
else:
self.down = nn.Sequential(self.downrelu, self.downconv, self.downnorm)
self.up = nn.Sequential(self.uprelu, self.upconv, self.upnorm)
if use_dropout:
self.up.add_module("dropout", nn.Dropout(0.3))
self.submodule = submodule
if self_attention and self.layer in self.attn_layers:
self.attn_block = LinearAttention(inner_nc) if attention_type == 'linear' else SelfAttention(inner_nc)
else:
self.attn_block = None
if not outermost and not innermost and layer in [4, 5, 6, 7]:
self.res_skip = ResSkip(outer_nc, outer_nc)  # construct it here; forward() still checks for its presence
if self.freeze_downsample:
for param in self.downconv.parameters():
param.requires_grad = False
for param in self.downnorm.parameters():
param.requires_grad = False
def forward(self, x, style=None):
if hasattr(self, 'attn_block') and self.attn_block is not None:
x = self.attn_block(x)
encoded = self.down(x)
if self.innermost:
if hasattr(self, 'transformer_block'):
encoded = self.transformer_block(encoded)
if hasattr(self, 'film') and style is not None:
encoded = self.film(encoded, style)
decoded = self.up(encoded)
if decoded.shape[2:] != x.shape[2:]:
decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
if hasattr(self, 'res_skip') and self.res_skip is not None:
decoded = self.res_skip(decoded)
return torch.cat([x, decoded], 1), encoded.contiguous().view(x.shape[0], -1)
else:
sub_output, encoded_real_A = self.submodule(encoded, style)
decoded = self.up(sub_output)
if decoded.shape[2:] != x.shape[2:]:
decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
if hasattr(self, 'res_skip') and self.res_skip is not None:
decoded = self.res_skip(decoded)
if self.outermost:
return decoded, encoded_real_A
else:
return torch.cat([x, decoded], 1), encoded_real_A
And that's it: you can now use up_mode='pixelshuffle' with confidence and get PixelShuffle's efficiency together with the smoother output quality!
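For a quick smoke test, here is a minimal usage sketch. The channel sizes and the two-level nesting are illustrative assumptions; attention, transformer and ResSkip stay disabled, so it only needs the class above and the same imports it already uses (torch, torch.nn, torch.nn.functional):
inner = UnetSkipConnectionBlock(outer_nc=256, inner_nc=512, innermost=True, up_mode='pixelshuffle')
block = UnetSkipConnectionBlock(outer_nc=128, inner_nc=256, submodule=inner, up_mode='pixelshuffle')
x = torch.randn(1, 128, 64, 64)
out, encoded_flat = block(x)  # skip concat of x and the decoded features: (1, 256, 64, 64)
print(out.shape, encoded_flat.shape)  # encoded_flat is the flattened innermost encoding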
If you would also like me to update the whole UNetGenerator accordingly, or add a new parameter to switch this smoothing on and off, just say so~
Adding a convolution layer (nn.Conv2d(outer_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)) after PixelShuffle means the weights of that layer and of the layers after it cannot be carried over directly from the old model that used ConvTranspose2d.
The reasons are:
- Layer structure mismatch: you inserted a new convolution into the self.upconv Sequential container, which changes the model's layer structure. When you try to load the old weights, this extra convolution in the new model has no counterpart in the old model.
- Weight shape mismatch: even where layer names look similar (for example self.upnorm), the extra convolution in between changes their order and connectivity inside the model, so the shapes and meanings of the weights no longer correspond (see the quick key comparison below).
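A quick way to see the mismatch for yourself; old_state_dict and new_model are placeholder names for your old checkpoint and the modified generator:
old_keys = set(old_state_dict.keys())
new_keys = set(new_model.state_dict().keys())
print(sorted(new_keys - old_keys))  # keys only in the new model, e.g. the inserted smoothing conv
print(sorted(old_keys - new_keys))  # keys only in the old model, e.g. the old ConvTranspose2d weight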
Which weights can still be kept?
In theory, the weights of layers that precede the modification can still be loaded, as long as their names and shapes stay identical between the old and new models.
For example, in the outermost branch of UnetSkipConnectionBlock, the weights of self.downconv can still be loaded if their name and shape are unchanged.
To sum up:
- Loading all weights directly will fail because the model structure has changed. PyTorch's weight loading normally relies on an exact match of the keys (layer names) in the model's state_dict().
- You need to write custom weight-loading logic that can:
  - compare layer names between the old and new models;
  - load weights only for layers whose name and shape both match;
  - skip or re-initialize the newly inserted convolution (smoothing conv) and the layers after it (for example the self.upnorm inside this Sequential, because its input now comes from the smoothing conv's output rather than directly from PixelShuffle).
How to keep weights partially (custom code required):
You can iterate over the old model's state_dict() and, for each weight, check whether the new model has a layer with a matching name and shape. If it does, copy the old weight into the corresponding layer of the new model. The newly inserted layers then need to be initialized with one of PyTorch's init methods (for example nn.init.kaiming_normal_).
Example (simplified concept), in Python:
def load_partial_weights(model, pretrained_dict):
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].shape == v.shape}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print(f"Loaded {len(pretrained_dict)}/{len(model_dict)} parameters.")
# Assume old_state_dict holds the weights of the previous model
# and new_model is your modified model:
# old_state_dict = torch.load('path/to/old_weights.pth')
# new_model = UNetGenerator(...)  # your new model
# load_partial_weights(new_model, old_state_dict)
This load_partial_weights function only loads weights whose name and shape match in both models. After loading, you need to make sure that the layers which did not receive weights (including the new smoothing conv and the layers after it) are initialized correctly.
So although adding a convolution after PixelShuffle breaks direct weight reuse, custom loading logic lets you keep as many weights as possible from the parts of the model whose structure did not change.
OK, I have made the smoothing conv a special case: it is initialized as a near-identity convolution (center weight 1, everything else 0) and integrated into your _initialize_unmatched_weights(). That way, the first time the new model loads an old checkpoint, smoothing_conv barely affects the output behavior, and fine-tuning can proceed smoothly.
✅ Updated _initialize_unmatched_weights()
def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
model_state = model.state_dict()
def init_smoothing_conv_as_identity(conv_param, bias_param=None):
"""將 smoothing conv 初始化為接近 identity(中心為 1,其餘為 0)"""
if conv_param.shape[2:] == (3, 3) and conv_param.shape[0] == conv_param.shape[1]:
with torch.no_grad():
conv_param.zero_()
c = conv_param.shape[0]
for i in range(c):
conv_param[i, i, 1, 1] = 1.0  # 1 at the kernel center on the channel diagonal
if bias_param is not None:
bias_param.zero_()
return True
return False
for name, param in model.named_parameters():
need_init = (
name not in loaded_state_dict or
model_state[name].shape != loaded_state_dict[name].shape
)
if need_init:
print(f" ⚠️ Re-initializing param: {model_name}.{name}, shape: {param.shape}")
# Check whether this is the smoothing conv (adjust the keyword to your naming style)
if "smoothing" in name and "weight" in name:
# 嘗試找對應的 bias
bias_name = name.replace("weight", "bias")
bias_param = model_state.get(bias_name, None)
matched = init_smoothing_conv_as_identity(param, bias_param)
if matched:
print(f" ✅ Initialized {model_name}.{name} as identity smoothing conv")
continue  # skip the default initialization
# Fallback: generic initialization
if "weight" in name:
nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
elif "bias" in name:
nn.init.constant_(param.data, 0)
for name, buffer in model.named_buffers():
if name not in loaded_state_dict or model_state[name].shape != loaded_state_dict[name].shape:
print(f" ⚠️ Re-initializing buffer: {model_name}.{name}, shape: {buffer.shape}")
buffer.data.zero_()
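A rough usage sketch of the loading flow; the checkpoint path is a placeholder, and it assumes this runs inside the same class that owns _initialize_unmatched_weights():
loaded_state_dict = torch.load('path/to/old_weights.pth', map_location='cpu')
model_state = model.state_dict()
compatible = {k: v for k, v in loaded_state_dict.items()
              if k in model_state and model_state[k].shape == v.shape}
model.load_state_dict(compatible, strict=False)  # copy everything that still matches by name and shape
self._initialize_unmatched_weights(model, loaded_state_dict)  # identity-init the smoothing conv, re-init the rest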
💡 Note
This code automatically detects conv layers whose names contain the smoothing keyword, for example:
self.smoothing = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
If you wrap it inside the nn.Sequential in upconv instead, use add_module("smoothing", conv) to keep the naming explicit (so the state_dict keys contain smoothing and can be recognized), as in the sketch below.
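For instance, a sketch of the pixelshuffle branch with the smoothing conv registered under an explicit name (the other sub-module names here are just illustrative; only the naming differs from the version above):
self.upconv = nn.Sequential()
self.upconv.add_module("pre_shuffle", nn.Conv2d(in_channels, outer_nc * 4, kernel_size=3, stride=1, padding=1, bias=use_bias))
self.upconv.add_module("shuffle", nn.PixelShuffle(2))
self.upconv.add_module("smoothing", nn.Conv2d(outer_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias))
self.upconv.add_module("norm", self.upnorm)
# state_dict keys now look like "...upconv.smoothing.weight", which the "smoothing" check above can find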
If you would like me to also adjust UnetSkipConnectionBlock so the smoothing conv is wrapped up with a clear name, just tell me~
If the model change only adds a single layer, you can simply write it like this and the old weights migrate over to the new model perfectly:
def extract_keywords(self, name):
KEYWORD_MATCH_RULES = ["down", "up", "conv", "res", "encoder", "decoder", "self", "line"]
return set([k for k in KEYWORD_MATCH_RULES if k in name])
def extract_layer_name(self, name):
parts = name.split('.')
if parts:
return parts[0]
return name
def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
model_state = model.state_dict()
used_keys = set()
shape_to_loaded_keys = {}
name_to_layer = {}
name_to_keywords = {}
for k, v in loaded_state_dict.items():
shape_to_loaded_keys.setdefault(v.shape, []).append(k)
name_to_layer[k] = self.extract_layer_name(k)
name_to_keywords[k] = self.extract_keywords(k)
for name, param in model.named_parameters():
full_name = name
current_layer = self.extract_layer_name(full_name)
current_keywords = self.extract_keywords(full_name)
#print(f" Loading param (name - shape): {model_name}.{full_name} - {param.shape}")
if full_name in loaded_state_dict and param.shape == loaded_state_dict[full_name].shape:
param.data.copy_(loaded_state_dict[full_name])
used_keys.add(full_name)
else:
matched = False
candidate_keys = shape_to_loaded_keys.get(param.shape, [])
for candidate in candidate_keys:
if candidate in used_keys:
continue
candidate_layer = name_to_layer.get(candidate)
candidate_keywords = name_to_keywords.get(candidate, set())
# The top-level layer name and the semantic keywords must both agree
if candidate_layer == current_layer and current_keywords & candidate_keywords:
print(f" Loading param (name - shape): {model_name}.{full_name} - {param.shape}")
print(f" --> Shape & layer & keyword match. Copying from {candidate}")
param.data.copy_(loaded_state_dict[candidate])
used_keys.add(candidate)
matched = True
break
if not matched:
print(f" Loading param (name - shape): {model_name}.{full_name} - {param.shape}")
print(f" --> No suitable match found. Re-initializing param: {model_name}.{full_name}")
if "weight" in full_name:
# Temporary handling for the case where the model only gained an extra conv.
init_smoothing_conv = True
if init_smoothing_conv:
# Try to find the matching bias
bias_name = name.replace("weight", "bias")
bias_param = model_state.get(bias_name, None)
matched = self.init_smoothing_conv_as_identity(param, bias_param)
if matched:
print(f" ✅ Initialized {model_name}.{name} as identity smoothing conv")
continue  # skip the default initialization
nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
elif "bias" in full_name:
nn.init.constant_(param.data, 0)
for name, buffer in model.named_buffers():
if name not in loaded_state_dict or model_state[name].shape != loaded_state_dict[name].shape:
print(f" Re-initializing buffer (shape mismatch or missing): {model_name}.{name}")
buffer.data.zero_()
def init_smoothing_conv_as_identity(self, conv_param, bias_param=None):
"""將 smoothing conv 初始化為接近 identity(中心為 1,其餘為 0)"""
if conv_param.shape[2:] == (3, 3) and conv_param.shape[0] == conv_param.shape[1]:
with torch.no_grad():
conv_param.zero_()
c = conv_param.shape[0]
for i in range(c):
conv_param[i, i, 1, 1] = 1.0  # 1 at the kernel center on the channel diagonal
if bias_param is not None:
bias_param.zero_()
return True
return False
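To make the fallback matching concrete, here is a hypothetical pair of keys (the actual prefixes depend on how your generator nests its blocks) that this logic would still pair up after the inserted layer shifts a Sequential index:
# old checkpoint key : model.1.up.2.weight  shape (128, 128, 3, 3)
# new model key      : model.1.up.3.weight  same shape, index shifted by the inserted layer
# extract_layer_name() returns "model" for both keys and extract_keywords() returns {"up"} for both,
# so the shape/layer/keyword rule copies the old weight even though the key names differ.
# A genuinely new parameter such as the smoothing conv usually finds no unused candidate and falls
# through to init_smoothing_conv_as_identity() (or Kaiming init when the kernel is not a square 3x3).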