In font style transfer tasks, using different norm_layers for the Generator and the Discriminator can give better results. This is because:

- Generator (InstanceNorm2d): InstanceNorm2d suits style transfer well because it removes batch-wise statistics and preserves each sample's own style features, which helps produce more consistently stylized results.
- Discriminator (BatchNorm2d or InstanceNorm2d?):
  - If the Discriminator also uses InstanceNorm2d, it may focus more on local texture features rather than batch-level distributions, which can weaken the style constraint.
  - With BatchNorm2d, the Discriminator learns more global statistics, which can make it better at telling real samples from generated ones, putting stronger learning pressure on the Generator and improving style consistency.

Suggestions

- If you want the Generator to drive the style transformation, have the Discriminator use BatchNorm2d. The Generator can then concentrate on learning the style mapping while the Discriminator focuses on distinguishing real from fake glyph samples.
- If the Discriminator becomes too strong and the Generator struggles to learn, switch the Discriminator to InstanceNorm2d so it attends to local features rather than batch-level statistics.

You can first try BatchNorm2d in the Discriminator to see whether style adaptation improves, then adjust from there!
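To make the distinction concrete, here is a minimal sketch (standard PyTorch; tensor shapes are arbitrary) showing that InstanceNorm2d normalizes each sample independently while BatchNorm2d pools statistics across the batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8) * 5 + 2  # batch of 4 three-channel feature maps

# InstanceNorm2d: each (sample, channel) plane is normalized on its own,
# so every plane ends up with ~zero mean regardless of the other samples.
y_inst = nn.InstanceNorm2d(3)(x)
print(y_inst.mean(dim=(2, 3)))  # ~0 for every sample/channel pair

# BatchNorm2d: each channel is normalized with statistics pooled over the
# whole batch, so individual samples keep their offsets relative to the batch.
y_bn = nn.BatchNorm2d(3)(x)
print(y_bn.mean(dim=(0, 2, 3)))  # ~0 per channel
print(y_bn.mean(dim=(2, 3)))     # per-sample means differ from zero
```

This per-sample behavior is exactly why InstanceNorm2d preserves individual style statistics in the Generator.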
OK. Taking the UNetGenerator code you provided, here it is with BatchNorm2d replaced by InstanceNorm2d to better fit the font style transfer task. The modified code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import math


class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False,
                 outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        # BatchNorm2d already has a learnable shift, so the conv bias is only
        # needed for other norm layers such as InstanceNorm2d.
        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            # The style embedding is concatenated at the bottleneck, hence the extra channels.
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.3))
        self.submodule = submodule
        # ... (SelfAttention and ResSkip code, if any)

    def forward(self, x, style=None):
        encoded = self.down(x)
        # ... (SelfAttention and ResSkip code, if any)
        if self.innermost:
            if style is not None:
                encoded = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encoded], dim=1)
            decoded = self.up(encoded)
            # ... (interpolation and ResSkip code, if any)
            return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)
        sub_output, encoded_real_A = self.submodule(encoded, style) if self.submodule else (encoded, None)
        decoded = self.up(sub_output)
        # ... (interpolation and ResSkip code, if any)
        if self.outermost:
            return decoded, encoded_real_A
        else:
            return torch.cat([x, decoded], 1), encoded_real_A


class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, embedding_num=40, embedding_dim=128,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False, self_attention=False, blur=False):
        super(UNetGenerator, self).__init__()
        # Innermost block, handles the style embedding.
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur, innermost=True)
        # Middle blocks.
        for index in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, layer=index + 2, use_dropout=use_dropout,
                                                 self_attention=self_attention, blur=blur)
        # Upper blocks (restore image resolution).
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=5, self_attention=self_attention, blur=blur)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=6, self_attention=self_attention, blur=blur)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=7, self_attention=self_attention, blur=blur)
        # Outermost block.
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             norm_layer=norm_layer, layer=8, self_attention=self_attention,
                                             blur=blur, outermost=True)
        self.embedder = nn.Embedding(embedding_num, embedding_dim)

    def _prepare_style(self, style_or_label):
        # Accept either an integer label tensor (looked up in the embedding
        # table) or a precomputed style vector.
        if style_or_label is not None and 'LongTensor' in style_or_label.type():
            return self.embedder(style_or_label)
        else:
            return style_or_label

    def forward(self, x, style_or_label=None):
        style = self._prepare_style(style_or_label)
        fake_B, encoded_real_A = self.model(x, style)
        return fake_B, encoded_real_A

    def encode(self, x, style_or_label=None):
        style = self._prepare_style(style_or_label)
        _, encoded_real_A = self.model(x, style)
        return encoded_real_A
```
Main changes:

- The norm_layer default in the UNetGenerator initializer is changed from nn.BatchNorm2d to nn.InstanceNorm2d, and norm_layer inside UnetSkipConnectionBlock is likewise nn.InstanceNorm2d.

With this, every normalization layer in UNetGenerator uses InstanceNorm2d, strengthening the model's adaptation to font styles.
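A quick smoke test of the modified generator (run in the same session as the code above; shapes are illustrative):

```python
G = UNetGenerator(input_nc=1, output_nc=1, num_downs=8, ngf=64,
                  embedding_num=40, embedding_dim=128,
                  norm_layer=nn.InstanceNorm2d)
x = torch.randn(2, 1, 256, 256)  # two source glyph images
labels = torch.tensor([3, 17])   # style labels, embedded by _prepare_style
fake_B, encoded_real_A = G(x, labels)
print(fake_B.shape)          # torch.Size([2, 1, 256, 256])
print(encoded_real_A.shape)  # torch.Size([2, 640]): ngf*8 + embedding_dim at the 1x1 bottleneck
```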
OK. For the font style transfer task (Noto Sans → Zen Maru Gothic), here is the code with the Discriminator's InstanceNorm2d replaced by BatchNorm2d:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import math


class Discriminator(nn.Module):
    def __init__(self, input_nc, embedding_num, ndf=64, norm_layer=nn.BatchNorm2d,
                 image_size=256, final_channels=1, blur=False):
        super(Discriminator, self).__init__()
        # BatchNorm2d already has a learnable shift, so the conv bias is redundant with it.
        use_bias = norm_layer != nn.BatchNorm2d
        kw = 5
        padw = 2
        sequence = [
            nn.utils.spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult = 1
        for n in range(1, 3):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                                                 kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        sequence += [
            nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult, final_channels,
                                             kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
            norm_layer(final_channels),
            nn.LeakyReLU(0.2, True)
        ]
        self.model = nn.Sequential(*sequence)
        # Three stride-2 convolutions downsample the input by a factor of 8.
        image_size = math.ceil(image_size / 8)
        final_features = final_channels * image_size * image_size
        self.binary = nn.Linear(final_features, 1)                # real/fake head
        self.category = nn.Linear(final_features, embedding_num)  # style-category head
        self.blur = blur
        # NOTE: kernel_size=1 makes this blur an identity; use e.g. kernel_size=3 for an actual blur.
        self.gaussian_blur = T.GaussianBlur(kernel_size=1, sigma=1.0)

    def forward(self, input):
        features = self.model(input)
        if self.blur:
            features = self.gaussian_blur(features)
        features = features.view(input.shape[0], -1)
        binary_logits = self.binary(features)
        category_logits = self.category(features)
        return binary_logits, category_logits


class CategoryLoss(nn.Module):
    def __init__(self, category_num):
        super(CategoryLoss, self).__init__()
        # A frozen identity embedding turns integer labels into one-hot targets.
        emb = nn.Embedding(category_num, category_num)
        emb.weight.data = torch.eye(category_num)
        self.emb = emb
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, category_logits, labels):
        target = self.emb(labels)
        return self.loss(category_logits, target)


class Zi2ZiModel:
    def __init__(self, input_nc=1, embedding_num=40, embedding_dim=128, ngf=64, ndf=64,
                 Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100,
                 schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
                 image_size=256, self_attention=False, residual_block=False,
                 weight_decay=1e-5, final_channels=1, beta1=0.5, g_blur=False, d_blur=False, epoch=40):
        self.use_dropout = is_training
        self.Lconst_penalty = Lconst_penalty
        self.Lcategory_penalty = Lcategory_penalty
        self.L1_penalty = L1_penalty
        self.schedule = schedule
        self.save_dir = save_dir
        self.gpu_ids = gpu_ids
        self.input_nc = input_nc
        self.embedding_dim = embedding_dim
        self.embedding_num = embedding_num
        self.ngf = ngf
        self.ndf = ndf
        self.lr = lr
        self.beta1 = beta1
        self.weight_decay = weight_decay
        self.is_training = is_training
        self.image_size = image_size
        self.self_attention = self_attention
        self.residual_block = residual_block
        self.final_channels = final_channels
        self.epoch = epoch
        self.g_blur = g_blur
        self.d_blur = d_blur
        # Separate AMP scalers for the generator and discriminator updates.
        self.scaler_G = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
        self.scaler_D = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
        self.device = torch.device("cuda" if self.gpu_ids and torch.cuda.is_available() else "cpu")
        self.feature_matching_loss = nn.L1Loss()

    def setup(self):
        # UNetGenerator setup: keeps InstanceNorm2d.
        self.netG = UNetGenerator(
            input_nc=self.input_nc,
            output_nc=self.input_nc,
            ngf=self.ngf,
            use_dropout=self.use_dropout,
            embedding_num=self.embedding_num,
            embedding_dim=self.embedding_dim,
            self_attention=self.self_attention,
            blur=self.g_blur,
            norm_layer=nn.InstanceNorm2d
        )
        # Discriminator setup: uses BatchNorm2d.
        self.netD = Discriminator(
            input_nc=2 * self.input_nc,
            embedding_num=self.embedding_num,
            ndf=self.ndf,
            final_channels=self.final_channels,
            image_size=self.image_size,
            blur=self.d_blur,
            norm_layer=nn.BatchNorm2d
        )
        # init_net (weight init + device placement) is assumed to be defined elsewhere in the project.
        init_net(self.netG, gpu_ids=self.gpu_ids)
        init_net(self.netD, gpu_ids=self.gpu_ids)
        self.optimizer_G = torch.optim.AdamW(self.netG.parameters(), lr=self.lr, betas=(self.beta1, 0.999), weight_decay=self.weight_decay)
        self.optimizer_D = torch.optim.AdamW(self.netD.parameters(), lr=self.lr, betas=(self.beta1, 0.999), weight_decay=self.weight_decay)
        eta_min = 1e-6
        self.scheduler_G = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_G, T_max=self.epoch, eta_min=eta_min)
        self.scheduler_D = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_D, T_max=self.epoch, eta_min=eta_min)
        self.category_loss = CategoryLoss(self.embedding_num)
        self.l1_loss = nn.L1Loss()
        self.mse = nn.MSELoss()
        if self.gpu_ids:
            self.category_loss.cuda()
            self.l1_loss.cuda()
            self.mse.cuda()
        if self.is_training:
            self.netD.train()
            self.netG.train()
        else:
            self.netD.eval()
            self.netG.eval()

    # ... (the other Zi2ZiModel methods stay unchanged)
```
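A quick shape check for the discriminator (run alongside the classes above; the batch of 2 keeps BatchNorm2d happy in training mode):

```python
D = Discriminator(input_nc=2, embedding_num=40, ndf=64,
                  norm_layer=nn.BatchNorm2d, image_size=256)
pair = torch.randn(2, 2, 256, 256)  # source glyph concatenated with a (real or fake) target
binary_logits, category_logits = D(pair)
print(binary_logits.shape)    # torch.Size([2, 1])
print(category_logits.shape)  # torch.Size([2, 40])
```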
Main changes:

- In Discriminator.__init__, the default value of norm_layer is changed to nn.BatchNorm2d.
- In Zi2ZiModel.setup, the norm_layer argument for self.netD is explicitly set to nn.BatchNorm2d.
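Since Zi2ZiModel creates separate GradScalers for G and D, an update step would typically use them as below. This is only a hypothetical sketch of the usual two-scaler AMP pattern; the model's actual optimize method is elided above and is not shown here, and the function name and loss terms are illustrative:

```python
def train_step_sketch(model, real_A, real_B, labels):
    # --- Discriminator update (uses scaler_D) ---
    model.optimizer_D.zero_grad()
    with torch.cuda.amp.autocast():
        fake_B, _ = model.netG(real_A, labels)
        d_real, _ = model.netD(torch.cat([real_A, real_B], dim=1))
        d_fake, _ = model.netD(torch.cat([real_A, fake_B.detach()], dim=1))
        d_loss = (F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real))
                  + F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake)))
    model.scaler_D.scale(d_loss).backward()
    model.scaler_D.step(model.optimizer_D)
    model.scaler_D.update()

    # --- Generator update (uses scaler_G) ---
    model.optimizer_G.zero_grad()
    with torch.cuda.amp.autocast():
        d_fake, _ = model.netD(torch.cat([real_A, fake_B], dim=1))
        g_loss = (F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake))
                  + model.L1_penalty * model.l1_loss(fake_B, real_B))
    model.scaler_G.scale(g_loss).backward()
    model.scaler_G.step(model.optimizer_G)
    model.scaler_G.update()
```

This assumes CUDA is available (the scalers are None on CPU) and that setup() has been called.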