這個版本的 Discriminator
支援特徵輸出,讓你可以在訓練時使用 Feature Matching Loss 或進行可視化分析。
主要修改點:
- 回傳中間特徵層的輸出:
- 在
forward
方法中,將各個卷積層的輸出保留,最後回傳(pred, features)
。
- 在
- 最後一層 (
final_layer
) 分離:- 這樣可以分開提取特徵與分類結果。
- 允許靈活選擇特徵層輸出:
features
變數收集了各個層的中間輸出,供Feature Matching Loss
使用。
修改後的 Discriminator
import torch
import torch.nn as nn
import functools
class Discriminator(nn.Module):
def __init__(self, input_nc=1, ndf=64, norm_layer=nn.BatchNorm2d, final_channels=1):
"""PatchGAN Discriminator with Feature Extraction
Parameters:
input_nc (int) -- Number of input image channels
ndf (int) -- Number of filters in the first conv layer
norm_layer -- Normalization layer
"""
super(Discriminator, self).__init__()
if isinstance(norm_layer, functools.partial):
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4 # Kernel size
padw = 1 # Padding
# Feature extraction layers
self.conv1 = nn.Sequential(
nn.Conv2d(input_nc, ndf * 1, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(ndf * 1, ndf * 2, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * 2),
nn.LeakyReLU(0.2, True)
)
self.conv3 = nn.Sequential(
nn.Conv2d(ndf * 2, ndf * 4, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, True)
)
self.conv4 = nn.Sequential(
nn.Conv2d(ndf * 4, ndf * 8, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, True)
)
# Final classification layer
self.final_layer = nn.Sequential(
nn.Conv2d(ndf * 8, final_channels, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
nn.Sigmoid()
)
def forward(self, x):
features = []
x = self.conv1(x)
features.append(x)
x = self.conv2(x)
features.append(x)
x = self.conv3(x)
features.append(x)
x = self.conv4(x)
features.append(x)
pred = self.final_layer(x)
return pred, features # Return both the prediction and intermediate features
如何使用這個 Discriminator
?
計算 Feature Matching Loss:
real_pred, real_features = discriminator(real_images)
fake_pred, fake_features = discriminator(fake_images)
feature_matching_loss = sum(nn.L1Loss()(rf, ff) for rf, ff in zip(real_features, fake_features))
這樣就可以讓 Generator 更關注局部紋理細節,提高 Zen Maru Gothic 的字形品質。
這樣的 Discriminator
設計已經支援輸出預測結果與中間特徵,但有幾個可以進一步優化的地方:
- 統一 Normalization 層:
conv3
和conv4
目前是nn.BatchNorm2d
,但conv2
使用的是norm_layer
,最好保持一致。- 建議全部都用
norm_layer(ndf * X)
來統一,這樣可以更容易切換不同的正規化策略,例如InstanceNorm2d
。
- 最後一層的 Sigmoid:
- PatchGAN 的
Discriminator
一般不加Sigmoid()
,而是讓 Loss function (BCEWithLogitsLoss
) 自己處理。 - 這樣可以穩定訓練,因為
BCEWithLogitsLoss
內部已經包含 Sigmoid 運算。
- PatchGAN 的
- 回傳特徵的方式:
- 目前是
features.append(x)
,但x
是直接經過 LeakyReLU 後的結果,或許可以直接取x.clone()
來確保計算圖的穩定性。
- 目前是
改進後的 Discriminator
class Discriminator(nn.Module):
def __init__(self, input_nc=1, ndf=64, norm_layer=nn.BatchNorm2d, final_channels=1):
super(Discriminator, self).__init__()
if isinstance(norm_layer, functools.partial):
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4 # Kernel size
padw = 1 # Padding
# Feature extraction layers
self.conv1 = nn.Sequential(
nn.Conv2d(input_nc, ndf * 1, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(ndf * 1, ndf * 2, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * 2),
nn.LeakyReLU(0.2, True)
)
self.conv3 = nn.Sequential(
nn.Conv2d(ndf * 2, ndf * 4, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * 4), # 改成統一使用 norm_layer
nn.LeakyReLU(0.2, True)
)
self.conv4 = nn.Sequential(
nn.Conv2d(ndf * 4, ndf * 8, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * 8), # 改成統一使用 norm_layer
nn.LeakyReLU(0.2, True)
)
# Final classification layer
self.final_layer = nn.Conv2d(ndf * 8, final_channels, kernel_size=kw, stride=1, padding=padw, bias=use_bias)
def forward(self, x):
features = []
x = self.conv1(x)
features.append(x.clone()) # 避免計算圖影響
x = self.conv2(x)
features.append(x.clone())
x = self.conv3(x)
features.append(x.clone())
x = self.conv4(x)
features.append(x.clone())
pred = self.final_layer(x) # 不使用 Sigmoid()
return pred, features # Return both the prediction and intermediate features
優化點說明
- 不使用
Sigmoid()
- 讓
BCEWithLogitsLoss
來處理數值範圍,這樣可以避免梯度消失問題。
- 讓
- 統一
norm_layer
- 讓
conv3
和conv4
也使用norm_layer(ndf * X)
,以便未來切換InstanceNorm2d
更靈活。
- 讓
features.append(x.clone())
- 避免 PyTorch 計算圖過度連結,確保梯度不會不小心傳遞到不應該的地方。
這樣的 Discriminator
應該會更穩定,並且符合你的訓練需求!🚀