In a font style transfer task, should the Generator and Discriminator use different norm_layer choices?

In a font style transfer task, using different norm_layer choices for the Generator and the Discriminator can work better. Here is why:

  1. Generator (uses InstanceNorm2d)
    • InstanceNorm2d works well for style transfer because it removes the influence of batch-wise statistics and keeps each sample's own style features, which helps produce more consistently stylized results.
  2. Discriminator (BatchNorm2d or InstanceNorm2d?)
    • If the Discriminator also uses InstanceNorm2d, it may attend more to local texture features than to batch-level distributions, which can weaken the style constraint.
    • With BatchNorm2d, the Discriminator learns more global statistics, which can make it better at separating real samples from generated ones and thereby puts stronger learning pressure on the Generator, improving style consistency (see the sketch below).
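To make the batch-wise vs. per-sample distinction concrete, here is a minimal sketch (not from the original conversation) that feeds the same batch through both layers; the tensor sizes are arbitrary:

Python

import torch
import torch.nn as nn

x = torch.randn(8, 64, 32, 32)  # a batch of 8 feature maps with 64 channels

bn = nn.BatchNorm2d(64)        # statistics pooled per channel across the whole batch
inorm = nn.InstanceNorm2d(64)  # statistics per (sample, channel) plane

y_bn, y_in = bn(x), inorm(x)
# InstanceNorm normalizes each sample independently, so nothing leaks between
# samples; BatchNorm couples every output to the batch-pooled mean/variance.
print(y_in.mean(dim=(2, 3)).abs().max().item())  # ~0: each plane is normalized alone
print(y_bn.mean(dim=(2, 3)).abs().max().item())  # > 0: only the batch-pooled mean is zeroed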

Recommendations

  • If you want the Generator to drive the style transformation, give the Discriminator BatchNorm2d. The Generator then concentrates on learning the style mapping, while the Discriminator focuses on separating real glyph samples from fake ones.
  • If the Discriminator turns out to be too strong and the Generator struggles to learn, switch it to InstanceNorm2d so the Discriminator attends to local features rather than batch-level statistics.

You can start with BatchNorm2d in the Discriminator, check whether it improves style adaptation, and adjust from there.


OK. For the UNetGenerator code you provided, here is a version with BatchNorm2d replaced by InstanceNorm2d to better suit the font style transfer task:

Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import math

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()

        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d  # BatchNorm normalizes away a conv bias, so bias is kept only for other norm layers

        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.3))

        self.submodule = submodule
        # ... (SelfAttention and ResSkip code, if any)

    def forward(self, x, style=None):
        encoded = self.down(x)
        # ... (SelfAttention and ResSkip code, if any)

        if self.innermost:
            if style is not None:
                encoded = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encoded], dim=1)
            decoded = self.up(encoded)
            # ... (interpolation and ResSkip code, if any)
            return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)

        sub_output, encoded_real_A = self.submodule(encoded, style) if self.submodule else (encoded, None)
        decoded = self.up(sub_output)
        # ... (interpolation and ResSkip code, if any)

        if self.outermost:
            return decoded, encoded_real_A
        else:
            return torch.cat([x, decoded], 1), encoded_real_A

class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, embedding_num=40, embedding_dim=128,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False, self_attention=False, blur=False):
        super(UNetGenerator, self).__init__()
        
        # Innermost block, responsible for injecting the style embedding
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim, 
                                             self_attention=self_attention, blur=blur, innermost=True)

        # Middle blocks
        for index in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, 
                                                 norm_layer=norm_layer, layer=index+2, use_dropout=use_dropout, 
                                                 self_attention=self_attention, blur=blur)

        # Upper blocks (restore the image resolution)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, 
                                             norm_layer=norm_layer, layer=5, self_attention=self_attention, blur=blur)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, 
                                             norm_layer=norm_layer, layer=6, self_attention=self_attention, blur=blur)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, 
                                             norm_layer=norm_layer, layer=7, self_attention=self_attention, blur=blur)

        # Outermost block
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, 
                                             norm_layer=norm_layer, layer=8, self_attention=self_attention, blur=blur, outermost=True)

        self.embedder = nn.Embedding(embedding_num, embedding_dim)

    def _prepare_style(self, style_or_label):
        # Integer style labels are looked up in the embedding table;
        # already-embedded style vectors pass through unchanged.
        if style_or_label is not None and 'LongTensor' in style_or_label.type():
            return self.embedder(style_or_label)
        else:
            return style_or_label

    def forward(self, x, style_or_label=None):
        style = self._prepare_style(style_or_label)
        fake_B, encoded_real_A = self.model(x, style)
        return fake_B, encoded_real_A

    def encode(self, x, style_or_label=None):
        style = self._prepare_style(style_or_label)
        _, encoded_real_A = self.model(x, style)
        return encoded_real_A

Key changes:

  • In the UNetGenerator constructor, the default norm_layer=nn.BatchNorm2d is changed to norm_layer=nn.InstanceNorm2d.
  • Inside UnetSkipConnectionBlock, norm_layer likewise defaults to nn.InstanceNorm2d.

With this, every normalization layer in UNetGenerator uses InstanceNorm2d, strengthening the model's ability to adapt to font styles.
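As a quick sanity check on the modified generator, here is a minimal smoke test (hypothetical, not part of the original code) using the default 256×256, num_downs=8 configuration; the expected shapes follow from the defaults above:

Python

G = UNetGenerator(input_nc=1, output_nc=1, num_downs=8, ngf=64,
                  embedding_num=40, embedding_dim=128)
x = torch.randn(2, 1, 256, 256)      # two grayscale source glyphs
labels = torch.randint(0, 40, (2,))  # integer style labels -> embedding lookup
fake_B, encoded_real_A = G(x, labels)
print(fake_B.shape)          # torch.Size([2, 1, 256, 256])
print(encoded_real_A.shape)  # torch.Size([2, 640]): inner_nc (512) + embedding_dim (128)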


OK. For the font style transfer task (Noto Sans → Zen Maru Gothic), with InstanceNorm2d in the Discriminator replaced by BatchNorm2d, here is the modified code:

Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import math

class Discriminator(nn.Module):
    def __init__(self, input_nc, embedding_num, ndf=64, norm_layer=nn.BatchNorm2d,
                 image_size=256, final_channels=1, blur=False):
        super(Discriminator, self).__init__()
        
        use_bias = norm_layer != nn.BatchNorm2d
        kw = 5
        padw = 2

        sequence = [
            nn.utils.spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        for n in range(1, 3):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        sequence += [
            nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult, final_channels, kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
            norm_layer(final_channels),
            nn.LeakyReLU(0.2, True)
        ]

        self.model = nn.Sequential(*sequence)
        image_size = math.ceil(image_size / 8)
        final_features = final_channels * image_size * image_size
        self.binary = nn.Linear(final_features, 1)
        self.category = nn.Linear(final_features, embedding_num)
        self.blur = blur
        self.gaussian_blur = T.GaussianBlur(kernel_size=1, sigma=1.0)  # blur strength; note: kernel_size=1 is effectively a no-op, use a larger odd kernel for an actual blur

    def forward(self, input):
        features = self.model(input)
        if self.blur:
            features = self.gaussian_blur(features)
        features = features.view(input.shape[0], -1)
        binary_logits = self.binary(features)
        category_logits = self.category(features)
        return binary_logits, category_logits

class CategoryLoss(nn.Module):
    def __init__(self, category_num):
        super(CategoryLoss, self).__init__()
        emb = nn.Embedding(category_num, category_num)
        emb.weight.data = torch.eye(category_num)
        self.emb = emb
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, category_logits, labels):
        target = self.emb(labels)
        return self.loss(category_logits, target)

class Zi2ZiModel:
    def __init__(self, input_nc=1, embedding_num=40, embedding_dim=128, ngf=64, ndf=64,
                 Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100,
                 schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
                 image_size=256, self_attention=False, residual_block=False,
                 weight_decay=1e-5, final_channels=1, beta1=0.5, g_blur=False, d_blur=False, epoch=40):

        self.use_dropout = is_training  # enable dropout only while training

        self.Lconst_penalty = Lconst_penalty
        self.Lcategory_penalty = Lcategory_penalty
        self.L1_penalty = L1_penalty

        self.schedule = schedule

        self.save_dir = save_dir
        self.gpu_ids = gpu_ids

        self.input_nc = input_nc
        self.embedding_dim = embedding_dim
        self.embedding_num = embedding_num
        self.ngf = ngf
        self.ndf = ndf
        self.lr = lr
        self.beta1 = beta1
        self.weight_decay = weight_decay
        self.is_training = is_training
        self.image_size = image_size
        self.self_attention = self_attention
        self.residual_block = residual_block
        self.final_channels = final_channels
        self.epoch = epoch
        self.g_blur = g_blur
        self.d_blur = d_blur

        self.scaler_G = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
        self.scaler_D = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
        device = torch.device("cuda" if self.gpu_ids and torch.cuda.is_available() else "cpu")
        self.device = device
        self.feature_matching_loss = nn.L1Loss()

    def setup(self):
        # ... (UNetGenerator setup, keeping InstanceNorm2d)
        self.netG = UNetGenerator(
            input_nc=self.input_nc,
            output_nc=self.input_nc,
            ngf=self.ngf,
            use_dropout=self.use_dropout,
            embedding_num=self.embedding_num,
            embedding_dim=self.embedding_dim,
            self_attention=self.self_attention,
            blur=self.g_blur,
            norm_layer=nn.InstanceNorm2d
        )
        # Discriminator setup, using BatchNorm2d
        self.netD = Discriminator(
            input_nc=2 * self.input_nc,
            embedding_num=self.embedding_num,
            ndf=self.ndf,
            final_channels=self.final_channels,
            image_size=self.image_size,
            blur=self.d_blur,
            norm_layer=nn.BatchNorm2d
        )

        init_net(self.netG, gpu_ids=self.gpu_ids)
        init_net(self.netD, gpu_ids=self.gpu_ids)

        self.optimizer_G = torch.optim.AdamW(self.netG.parameters(), lr=self.lr, betas=(self.beta1, 0.999), weight_decay=self.weight_decay)
        self.optimizer_D = torch.optim.AdamW(self.netD.parameters(), lr=self.lr, betas=(self.beta1, 0.999), weight_decay=self.weight_decay)
        
        eta_min = 1e-6
        self.scheduler_G = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_G, T_max=self.epoch, eta_min=eta_min)
        self.scheduler_D = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_D, T_max=self.epoch, eta_min=eta_min)

        self.category_loss = CategoryLoss(self.embedding_num)
        self.l1_loss = nn.L1Loss()
        self.mse = nn.MSELoss()

        if self.gpu_ids:
            self.category_loss.cuda()
            self.l1_loss.cuda()
            self.mse.cuda()

        if self.is_training:
            self.netD.train()
            self.netG.train()
        else:
            self.netD.eval()
            self.netG.eval()
            
    # ... (the remaining Zi2ZiModel methods, unchanged)

Key changes:

  • In Discriminator.__init__, the default norm_layer is now nn.BatchNorm2d.
  • In Zi2ZiModel.setup, the initialization of self.netD explicitly sets norm_layer=nn.BatchNorm2d.
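As a quick check, here is a minimal smoke test (hypothetical, not part of the original code) that pushes a batch of source/target glyph pairs through the BatchNorm2d Discriminator and computes the category loss; shapes assume image_size=256 and embedding_num=40:

Python

D = Discriminator(input_nc=2, embedding_num=40, ndf=64,
                  norm_layer=nn.BatchNorm2d, image_size=256)
pair = torch.randn(4, 2, 256, 256)  # source + target glyphs stacked on the channel dim
binary_logits, category_logits = D(pair)
print(binary_logits.shape)    # torch.Size([4, 1]):  real/fake score
print(category_logits.shape)  # torch.Size([4, 40]): per-style logits

loss_fn = CategoryLoss(category_num=40)
labels = torch.randint(0, 40, (4,))
print(loss_fn(category_logits, labels).item())  # BCE against one-hot style targets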
