判別器 (Discriminator) 的權重套用 Spectral Normalization,以限制其權重的 Lipschitz 常數

Spectral Normalization 的作用:

  • Spectral Normalization 通過將權重矩陣除以其譜範數(spectral norm)來限制權重的 Lipschitz 常數。
  • 這有助於穩定訓練過程,特別是在 GANs 中,並可以提高生成樣本的質量。
  • 限制判別器的權重,可以避免判別器在訓練初期過度自信,進而導致生成器無法學習。

通過這些修改,判別器的權重將受到 Spectral Normalization 的約束,從而提高訓練的穩定性和生成樣本的質量。

主要修改:

  1. 導入 spectral_norm
    • from torch.nn.utils.spectral_norm import spectral_norm 導入了 Spectral Normalization 函數。
  2. 應用 spectral_norm 到卷積層:
    • nn.Conv2d 層的前面,我們使用 spectral_norm() 函數來包裝卷積層,例如:
      • spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw))
    • 這將 Spectral Normalization 應用於每個卷積層的權重。
  3. 應用 spectral_norm 到線性層:
    • 同樣地,我們也將 spectral_norm() 應用於線性層 nn.Linear
      • self.binary = spectral_norm(nn.Linear(image_size // 8 * image_size // 8, 1))

這裡是修改後的 trainer 腳本,對判別器 (Discriminator) 的權重套用了 Spectral Normalization,以限制其權重的 Lipschitz 常數:

import torch
import torch.nn as nn
import torch.nn.utils.spectral_norm as spectral_norm

class Discriminator(nn.Module):
    def __init__(self, input_nc, embedding_num, ndf=64, norm_layer=nn.BatchNorm2d, image_size=256):
        super(Discriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d
        kw = 5
        padw = 2
        sequence = [
            spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, 3):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = 8
        sequence += [
            spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw))]

        self.model = nn.Sequential(*sequence)
        final_channels = 1
        image_size = math.ceil(image_size / 2)
        image_size = math.ceil(image_size / 2)
        image_size = math.ceil(image_size / 2)
        final_features = final_channels * image_size * image_size
        self.binary = spectral_norm(nn.Linear(final_features, 1))
        self.catagory = spectral_norm(nn.Linear(final_features, embedding_num))

    def forward(self, input):
        features = self.model(input)
        features = features.view(input.shape[0], -1)
        binary_logits = self.binary(features)
        catagory_logits = self.catagory(features)
        return binary_logits, catagory_logits

這段程式碼已經正確地為判別器 (Discriminator) 中的所有卷積層和全連接層應用了 Spectral Normalization,確保權重的 Lipschitz 常數受限,提升對抗訓練的穩定性。你可以試試這個版本,看是否能解決訓練不穩定的問題!如果還有其他問題或需要進一步調整,歡迎討論。


gemini

好的,針對您提供的字型風格轉換任務,我將修改 Discriminator 類別,加入 Spectral Normalization,以限制判別器的權重 Lipschitz 常數。

以下是修改後的 Discriminator 類別:

import torch
import torch.nn as nn
import torch.nn.utils.spectral_norm as spectral_norm
import functools
import math

class Discriminator(nn.Module):
    def __init__(self, input_nc, embedding_num, ndf=64, norm_layer=nn.BatchNorm2d, image_size=256):
        super(Discriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d
        kw = 5
        padw = 2
        sequence = [
            spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, 3):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = 8
        sequence += [
            spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw))]

        self.model = nn.Sequential(*sequence)
        final_channels = 1
        image_size = math.ceil(image_size / 2)
        image_size = math.ceil(image_size / 2)
        image_size = math.ceil(image_size / 2)
        final_features = final_channels * image_size * image_size
        self.binary = nn.Linear(final_features, 1)
        self.catagory = nn.Linear(final_features, embedding_num)

    def forward(self, input):
        features = self.model(input)
        features = features.view(input.shape[0], -1)
        binary_logits = self.binary(features)
        catagory_logits = self.catagory(features)
        return binary_logits, catagory_logits

修改說明:

  • 引入 torch.nn.utils.spectral_norm
    • 首先,我們引入了 torch.nn.utils.spectral_norm 模組,它提供了 Spectral Normalization 功能。
  • 在卷積層應用 Spectral Normalization:
    • Discriminator 的每個 nn.Conv2d 層,我們都使用 spectral_norm() 函數進行包裹,例如:
      • spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw))
    • 這樣,每個卷積層的權重都會經過 Spectral Normalization,從而限制其 Lipschitz 常數。

使用 Spectral Normalization 的優點:

  • 穩定 GAN 訓練:
    • Spectral Normalization 有助於穩定 GAN 的訓練過程,防止判別器過度自信,避免梯度消失或爆炸。
  • 提高生成圖像品質:
    • 通過穩定訓練,可以提高生成圖像的品質和真實感。
  • 減少超參數調整:
    • Spectral Normalization 可以減少對超參數的敏感性,簡化模型調整過程。

其他建議:

  • 梯度懲罰 (Gradient Penalty):
    • 除了 Spectral Normalization,您還可以考慮在判別器的損失函數中加入梯度懲罰,例如 Wasserstein GAN-GP (WGAN-GP)。
  • 學習率調整:
    • 使用學習率衰減策略,例如線性衰減或餘弦退火,有助於模型收斂到更好的局部最小值。
  • AdamW optimizer:
    • 將optimizer從Adam改為AdamW。

希望這些修改和建議對您有所幫助!

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *