Adding Spectral Normalization to the Discriminator

This code is already fairly complete and includes several advanced techniques such as Self-Attention, ResSkip (residual skip connections), and **Gradient Penalty**. To further improve regularization, I suggest the following enhancements:

Suggested Improvements

  1. Add Spectral Normalization
    • Improves the stability of the Discriminator and helps prevent exploding gradients.
    • Apply torch.nn.utils.spectral_norm to the Discriminator's convolutional layers (a minimal drop-in example follows right after this list).
  2. Add Instance Normalization (IN)
    • The current code uses BatchNorm2d, but since font style transfer is an image-generation problem, InstanceNorm2d usually adapts better to per-image style.
  3. Add Dropout as regularization
    • Add nn.Dropout(p=0.3) to the up path of UnetSkipConnectionBlock.
  4. Add a Feature Matching Loss
    • Comparing real and generated images at different feature layers further improves the consistency of the font style.
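
As a quick illustration of point 1, spectral_norm is a drop-in wrapper: it reparametrizes the layer's weight so that its largest singular value stays at 1, re-estimated by power iteration on every forward pass. A minimal standalone sketch (the shapes here are arbitrary):

import torch
import torch.nn as nn

# Minimal sketch: wrap a single conv layer with spectral normalization.
conv = nn.utils.spectral_norm(nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=2))
x = torch.randn(8, 1, 256, 256)
out = conv(x)                     # each forward pass refreshes the power-iteration estimate
print(out.shape)                  # torch.Size([8, 64, 128, 128])
print(hasattr(conv, "weight_u"))  # True: power-iteration buffer registered by spectral_norm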

Improved Discriminator (with Spectral Normalization)

import math

import torch.nn as nn
import torchvision.transforms as T


class Discriminator(nn.Module):
    def __init__(self, input_nc, embedding_num, ndf=64, norm_layer=nn.InstanceNorm2d, 
                 image_size=256, final_channels=1, blur=False):
        super(Discriminator, self).__init__()
        
        use_bias = norm_layer != nn.BatchNorm2d
        kw = 5
        padw = 2

        sequence = [
            nn.utils.spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        for n in range(1, 3):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        sequence += [
            nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult, final_channels, kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
            norm_layer(final_channels),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(final_channels, 1, kernel_size=kw, stride=1, padding=padw)]
        
        self.model = nn.Sequential(*sequence)
        image_size = math.ceil(image_size / 8)
        final_features = image_size * image_size  # the last conv above outputs a single channel
        self.binary = nn.Linear(final_features, 1)
        self.category = nn.Linear(final_features, embedding_num)
        self.blur = blur
        self.gaussian_blur = T.GaussianBlur(kernel_size=1, sigma=1.0)  # blur strength; kernel_size=1 leaves the features unchanged, raise it (e.g. 3) for an actual blur

    def forward(self, input):
        features = self.model(input)
        if self.blur:
            features = self.gaussian_blur(features)
        features = features.view(input.shape[0], -1)
        binary_logits = self.binary(features)
        category_logits = self.category(features)
        return binary_logits, category_logits

Key improvements:
  • Spectral Normalization prevents the discriminator from over-converging and improves gradient stability.
  • Instance Normalization helps align the style across different fonts.
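
As a sanity check, here is a hedged smoke test of the discriminator above; input_nc=2 assumes single-channel glyphs concatenated as torch.cat([real_A, fake_B], dim=1) (as in backward_G further down), and embedding_num=10 is just a placeholder:

import torch

# Hypothetical smoke test; adjust input_nc / embedding_num to your dataset.
netD = Discriminator(input_nc=2, embedding_num=10, ndf=64, image_size=256)
x = torch.randn(4, 2, 256, 256)          # e.g. torch.cat([real_A, fake_B], dim=1)
binary_logits, category_logits = netD(x)
print(binary_logits.shape)    # torch.Size([4, 1])
print(category_logits.shape)  # torch.Size([4, 10])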


Improved UnetSkipConnectionBlock (with Dropout & InstanceNorm)

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()

        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d

        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)  # the style embedding is concatenated at the bottleneck
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.3))  # add dropout for extra regularization

        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None

Key improvements:
  • Instance Normalization makes the style transfer more stable.
  • Dropout reduces the risk of overfitting.
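
For reference, a hedged construction example showing how the new options are passed to a single block; the channel sizes are placeholders, the surrounding U-Net nesting stays as in your existing generator, and SelfAttention / ResSkip come from your current code:

import torch.nn as nn

# Hypothetical middle-level block with the new regularization options enabled.
block = UnetSkipConnectionBlock(
    outer_nc=256, inner_nc=512,
    norm_layer=nn.InstanceNorm2d,   # InstanceNorm2d instead of BatchNorm2d
    layer=4, self_attention=True,   # SelfAttention is only attached at layers 4 and 6
    use_dropout=True,               # enables the nn.Dropout(0.3) in the up path
)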


Adding a Feature Matching Loss

Add a feature matching loss inside Zi2ZiModel:

class Zi2ZiModel:
    def __init__(self, ...):
        ...
        self.feature_matching_loss = nn.L1Loss()

    def forward(self):
        self.fake_B, self.encoded_real_A = self.netG(self.real_A, self.labels)
        self.encoded_fake_B = self.netG.encode(self.fake_B, self.labels)

    def compute_feature_matching_loss(self, real_AB, fake_AB):
        # With the Discriminator above, netD returns (binary_logits, category_logits),
        # so this matches the final binary logits; the real side is detached so it only serves as the target.
        real_features, _ = self.netD(real_AB)
        fake_features, _ = self.netD(fake_AB)
        return self.feature_matching_loss(fake_features, real_features.detach())

    def backward_G(self):
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        real_AB = torch.cat([self.real_A, self.real_B], 1)

        # Feature Matching Loss
        fm_loss = self.compute_feature_matching_loss(real_AB, fake_AB)

        # L1 Loss
        l1_loss = self.l1_loss(self.fake_B, self.real_B) * self.L1_penalty

        # any existing adversarial / category terms from the original backward_G stay in place
        self.g_loss = l1_loss + fm_loss * 10.0  # weight for the feature matching loss
        self.g_loss.backward()

Key improvements:
  • The Feature Matching Loss pushes the font style produced by the Generator closer to real_B.
  • The weight of 10.0 keeps G's training stable.
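
Note that with the Discriminator above, the loss as written compares the final binary logits. If you prefer matching intermediate activations (the usual formulation of feature matching), one hedged option is to walk self.netD.model layer by layer; the choice of layers and the L1 criterion below are placeholders:

import torch.nn as nn

def compute_feature_matching_loss(self, real_AB, fake_AB):
    # Hypothetical variant: accumulate L1 distances over intermediate activations
    # of the discriminator's convolutional trunk instead of its final logits.
    fm_loss = 0.0
    real_feat, fake_feat = real_AB, fake_AB
    for layer in self.netD.model:
        real_feat = layer(real_feat)
        fake_feat = layer(fake_feat)
        if isinstance(layer, nn.LeakyReLU):
            # compare activations after every LeakyReLU; the real side is detached
            fm_loss = fm_loss + self.feature_matching_loss(fake_feat, real_feat.detach())
    return fm_loss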


Summary

  1. Discriminator
    • ✅ Add Spectral Normalization
    • ✅ Switch to InstanceNorm2d
  2. Generator
    • ✅ Add Dropout
    • ✅ Switch to InstanceNorm2d
  3. Training losses
    • ✅ Add a Feature Matching Loss so the generated style stays closer to real_B

These changes should further improve the accuracy of the font style transfer and reduce the risk of overfitting. Give them a try and see whether they yield a clear improvement!
