This code is already fairly complete and includes several advanced techniques such as Self-Attention, ResSkip (residual skip connections), and **gradient penalty**. To further improve regularization, I suggest the following enhancements:
Improvement directions

- Add Spectral Normalization
  - Improves the stability of the Discriminator and helps prevent exploding gradients.
  - Apply `torch.nn.utils.spectral_norm` to the Discriminator's convolutional layers (a minimal sketch follows this list).
- Add Instance Normalization (IN)
  - The current code uses `BatchNorm2d`, but since font style transfer is an image-generation problem, `InstanceNorm2d` usually adapts better to per-image style.
- Add Dropout as regularization
  - Add `nn.Dropout(p=0.3)` to the `up` path of `UnetSkipConnectionBlock`.
- Add a Feature Matching Loss
  - Comparing the features of real and generated images at several layers further improves font style consistency.
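As a quick sketch of the first point: `spectral_norm` is just a wrapper around each convolution, so it can also be retrofitted onto an existing module. The helper below is hypothetical (the name `add_spectral_norm` and the recursive walk are not part of the original code):

```python
import torch.nn as nn

def add_spectral_norm(module):
    """Recursively wrap every Conv2d inside `module` with spectral normalization.
    Hypothetical helper for retrofitting an already-built discriminator."""
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d):
            setattr(module, name, nn.utils.spectral_norm(child))
        else:
            add_spectral_norm(child)
    return module
```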
Improved Discriminator (with Spectral Normalization)
```python
import math

import torch.nn as nn
import torchvision.transforms as T


class Discriminator(nn.Module):
    def __init__(self, input_nc, embedding_num, ndf=64, norm_layer=nn.InstanceNorm2d,
                 image_size=256, final_channels=1, blur=False):
        super(Discriminator, self).__init__()
        use_bias = norm_layer != nn.BatchNorm2d
        kw = 5
        padw = 2
        # First stride-2 conv, wrapped with spectral normalization
        sequence = [
            nn.utils.spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult = 1
        # Two more stride-2 blocks, each with spectral norm + instance norm
        for n in range(1, 3):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                                                 kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        sequence += [
            nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult, final_channels,
                                             kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
            norm_layer(final_channels),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [nn.Conv2d(final_channels, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

        # Three stride-2 convs downsample the input by a factor of 8
        image_size = math.ceil(image_size / 8)
        final_features = final_channels * image_size * image_size
        self.binary = nn.Linear(final_features, 1)                 # real/fake head
        self.category = nn.Linear(final_features, embedding_num)   # font-category head
        self.blur = blur
        self.gaussian_blur = T.GaussianBlur(kernel_size=1, sigma=1.0)  # blur strength (used only when blur=True)

    def forward(self, input):
        features = self.model(input)
        if self.blur:
            features = self.gaussian_blur(features)
        features = features.view(input.shape[0], -1)
        binary_logits = self.binary(features)
        category_logits = self.category(features)
        return binary_logits, category_logits
```
Improvements
✅ Spectral Normalization keeps the discriminator from over-converging and improves gradient stability.
✅ Instance Normalization helps align styles across different fonts.
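A quick shape check for the modified Discriminator (a hedged sketch; the 2-channel paired input and the 40 font categories are illustrative assumptions, not values from the original code):

```python
import torch

netD = Discriminator(input_nc=2, embedding_num=40)   # paired (source, target) 1-channel glyphs
x = torch.randn(4, 2, 256, 256)                      # batch of 4 image pairs
binary_logits, category_logits = netD(x)
print(binary_logits.shape)     # torch.Size([4, 1])
print(category_logits.shape)   # torch.Size([4, 40])
```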
Improved UnetSkipConnectionBlock (with Dropout & InstanceNorm)
```python
import torch.nn as nn

# SelfAttention and ResSkip are the modules already defined in the existing code.

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False,
                 outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            # the style embedding is concatenated at the bottleneck, hence the extra channels
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_dropout:
                # dropout in the decoder path for extra regularization
                self.up.add_module("dropout", nn.Dropout(0.3))
        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None
```
Improvements
✅ Instance Normalization improves the stability of style transfer.
✅ Dropout reduces the risk of overfitting.
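For reference, this is roughly how the blocks nest into a U-Net generator with dropout enabled on the intermediate decoder levels. It is a hedged sketch: the depth, channel widths, 1-channel glyph input/output, and the `layer` numbering are assumptions in the style of a pix2pix U-Net, and construction relies on the `ResSkip` module from the existing code:

```python
ngf = 64  # base channel width (assumed)

# innermost bottleneck block (receives the style embedding)
block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, layer=8, innermost=True)

# intermediate ngf*8 blocks with dropout for regularization
for layer in (7, 6, 5):
    block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=block,
                                    layer=layer, use_dropout=True)

# progressively wider levels toward the image resolution
block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=block, layer=4)
block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=block, layer=3)
block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=block, layer=2)

# outermost block maps back to a 1-channel glyph through Tanh
netG_unet = UnetSkipConnectionBlock(1, ngf, input_nc=1, submodule=block,
                                    layer=1, outermost=True)
```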
Adding a Feature Matching Loss
Add the feature matching loss inside `Zi2ZiModel`:
```python
class Zi2ZiModel:
    def __init__(self, ...):
        ...
        self.feature_matching_loss = nn.L1Loss()

    def forward(self):
        self.fake_B, self.encoded_real_A = self.netG(self.real_A, self.labels)
        self.encoded_fake_B = self.netG.encode(self.fake_B, self.labels)

    def compute_feature_matching_loss(self, real_AB, fake_AB):
        # netD returns (binary_logits, category_logits); the binary logits of
        # the real and generated pairs are compared here
        real_features, _ = self.netD(real_AB)
        fake_features, _ = self.netD(fake_AB)
        return self.feature_matching_loss(real_features, fake_features)

    def backward_G(self):
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        real_AB = torch.cat([self.real_A, self.real_B], 1)

        # Feature Matching Loss
        fm_loss = self.compute_feature_matching_loss(real_AB, fake_AB)

        # L1 reconstruction loss
        l1_loss = self.l1_loss(self.fake_B, self.real_B) * self.L1_penalty

        self.g_loss = l1_loss + fm_loss * 10.0  # feature-matching loss weight
        self.g_loss.backward()
```
Improvements
✅ The Feature Matching Loss pushes the font style produced by the `Generator` closer to `real_B`.
✅ The weight factor (`10.0`) keeps the training of `G` more stable.
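Note that the snippet above matches only the discriminator's binary logits. The earlier bullet about comparing features "at different layers" can be approximated with forward hooks; below is a hedged sketch (it assumes `netD.model` is the `nn.Sequential` from the Discriminator above, and hooking every `LeakyReLU` is an illustrative choice). In `backward_G` it could replace the call to `compute_feature_matching_loss`:

```python
import torch.nn as nn

def multi_layer_fm_loss(netD, real_AB, fake_AB, criterion=nn.L1Loss()):
    """Match intermediate discriminator activations instead of only the final logits."""
    layers = [m for m in netD.model if isinstance(m, nn.LeakyReLU)]

    def capture(store):
        # forward hook that records each hooked layer's output feature map
        return lambda module, inputs, output: store.append(output)

    real_feats, fake_feats = [], []

    hooks = [m.register_forward_hook(capture(real_feats)) for m in layers]
    netD(real_AB)
    for h in hooks:
        h.remove()

    hooks = [m.register_forward_hook(capture(fake_feats)) for m in layers]
    netD(fake_AB)
    for h in hooks:
        h.remove()

    # average L1 distance between corresponding real/fake feature maps;
    # real activations are detached so gradients flow only to the generator
    return sum(criterion(f, r.detach()) for f, r in zip(fake_feats, real_feats)) / len(layers)
```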
Summary

- Discriminator
  - ✅ Added Spectral Normalization
  - ✅ Switched to `InstanceNorm2d`
- Generator
  - ✅ Added Dropout
  - ✅ Switched to `InstanceNorm2d`
- Training loss
  - ✅ Added a Feature Matching Loss to bring the generated style closer to `real_B`.
These changes should further improve the accuracy of the font style transfer and reduce the risk of overfitting. Try them out and see whether they bring a noticeable improvement!