Below is a full review of the code, pointing out logic errors / potential bugs / loss-design problems / the inconsistency between VGG16 features and your own encoder embedding, together with concrete fixes and better hyperparameter settings.
✅ Part 1: Your training loop / forward / loss design — clear errors and recommendations
❗ Major issue 1: Perceptual loss with layer_index = 15 (relu3_3) is a poor fit for font / black-and-white images
You currently use:
PerceptualLoss(layer_index=15) # relu3_3
This is a high-level semantic layer of VGG16 and carries almost no useful signal for fonts / line drawings.
❗ The problem:
VGG16 is trained on ImageNet; from relu3_3 onward its features focus on:
- dogs, cars, textures, edges with semantic context
- natural object structure
Your task (Zi2Zi, fonts, black-and-white):
- depends heavily on stroke geometry
- contains no natural-image semantics
→ relu3_3 makes the loss very noisy and sends strong, misleading gradients to the generator.
✔ Recommendation
Use only low-level VGG features:
| Layer | Index | Recommendation |
|---|---|---|
| relu1_2 | 3 | ✔ preferred |
| relu2_2 | 8 | ✔ acceptable |
| relu3_3 | 15 | ❌ not recommended (semantic layer) |
➡ Set layer_index to 3 or 8.
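If you want to double-check the index-to-layer mapping yourself, here is a quick sketch (no pretrained weights needed just to inspect the layer layout; the indices are positions inside vgg16().features):
from torchvision.models import vgg16

# index 3 is relu1_2 (ReLU after conv1_2), index 8 is relu2_2 (ReLU after conv2_2)
features = vgg16(weights=None).features
for idx, layer in enumerate(list(features)[:9]):
    print(idx, layer)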
❗ Major issue 2: the VGG16 feature space and the UNet embedding space are completely different
Your current const_loss:
fake_B_emb, real_B_emb = UNet return_feat()
self.loss_const = self.l1_loss(fake_B_emb, real_B_emb)
while the perceptual loss is:
self.perceptual_loss(fake_B, real_B)
→ The two operate in completely different spaces:
| Loss | Source | Semantics | Feature scale | Relevance to the image |
|---|---|---|---|---|
| Const loss | UNet bottleneck embedding | spatial features of the glyph | low-level / geometry-based | precise for glyph shape |
| VGG16 perceptual | ImageNet VGG | natural-image semantics | high-level | can misguide training |
➡ The gradients of the two losses can conflict, especially on a font task.
❗ Major issue 3: the generator loss stacks four heavily weighted terms → unstable gradients
Your current weights:
L1_penalty = 100
Lconst_penalty = 15
Lperceptual_penalty = 1 (or higher?)
GAN = 1
Category_loss ~1
→ L1 * 100 is already very strong
→ const_loss * 15 is also strong
→ plus perceptual_loss * (unknown weight?)
→ the GAN term is completely drowned out
Under this setup:
❗ the generator will almost entirely ignore the GAN signal and just do regression
(which makes the results blurry and lacking sharp edges).
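To see how badly the adversarial term gets drowned out, here is a back-of-the-envelope sketch; the per-term magnitudes are made-up but plausible early-training values, not numbers measured from your run:
# Illustrative magnitudes only (assumed, not measured)
l1_loss    = 0.30   # pixel L1 on images in [-1, 1]
const_loss = 0.05   # bottleneck embedding distance
perceptual = 1.50   # deep VGG (relu3_3) features can be large and noisy
gan_loss   = 0.70   # BCE near ln(2) at equilibrium

g_total = 100 * l1_loss + 15 * const_loss + 1.0 * perceptual + 1.0 * gan_loss
print(g_total)              # ~33.0
print(gan_loss / g_total)   # ~0.02 -> the GAN term is ~2% of the total objective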
❗ Major issue 4: PerceptualLoss moves mean/std with .to() inside forward (and may break the intended no-grad / buffer behavior)
You wrote:
if input_img.is_cuda and not self.mean.is_cuda:
self.mean = self.mean.to(input_img.device)
This:
- mutates a module buffer
- can change the autograd graph
- calling .to() inside forward like this is not recommended
✔ The correct approach:
Move everything to the GPU at initialization instead of mutating buffers inside forward.
❗ Major issue 5: category loss weighting is inconsistent between D and G (D is scaled by 0.5, G is not)
Look at this:
self.category_loss_D = (real_category_loss + fake_category_loss_d) * self.Lcategory_penalty
self.d_loss = (loss_D_real + loss_D_fake) * 0.5 + self.category_loss_D * 0.5
and later:
fake_category_loss_G = self.category_loss(fake_category_logits, labels) * self.Lcategory_penalty
This means:
- D's category loss is scaled down by 0.5
- G's category loss is applied at full weight
→ asymmetric
I recommend keeping them consistent.
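A self-contained sketch of the symmetric version, using nn.CrossEntropyLoss as a stand-in for your CategoryLoss, dummy logits/labels, and embedding_num assumed to be 40:
import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss()          # stand-in for CategoryLoss
Lcategory_penalty = 1.0
real_category_logits = torch.randn(4, 40)
fake_category_logits = torch.randn(4, 40)
labels = torch.randint(0, 40, (4,))

# D side: real + fake category losses, scaled once by the penalty (no extra * 0.5)
category_loss_D = (ce(real_category_logits, labels) + ce(fake_category_logits, labels)) * Lcategory_penalty
# G side: fake category loss, scaled by the same penalty
category_loss_G = ce(fake_category_logits, labels) * Lcategory_penalty
print(category_loss_D.item(), category_loss_G.item())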
❗ Major issue 6: pred_fake for the G update uses the already-updated D (not the only standard GAN convention)
You use:
pred_fake = self.netD(fake_AB)
where D has already been updated in the same step.
Both conventions appear in standard GAN training, but:
- if you use the non-saturating GAN loss,
- pred_fake is usually computed with the updated D (which is what you do now).
So this is not a bug, as long as you are aware of it.
❗ Major issue 7: perceptual_loss is likely too strong, destabilizing gradients
For a font task, I suggest:
L1_penalty = 100
Lconst_penalty = 5–10
Lperceptual_penalty = 0.01–0.05
GAN = 1
Category = 1
🔍 Overall conclusion: the biggest problems are around VGG16 (space mismatch + excessive weight)
| Issue | Severity |
|---|---|
| VGG16 layer too deep (relu3_3) | ❗❗❗ |
| VGG space conflicts with your UNet embedding space | ❗❗❗ |
| Loss weights out of balance | ❗❗ |
| Perceptual loss ill-suited to fonts | ❗❗ |
| mean/std moved dynamically in forward | ❗ |
| D/G category loss weights inconsistent | ❗ |
⭐ Suggested improvements
✔ 1. Change PerceptualLoss to use a low-level layer:
PerceptualLoss(layer_index=3) # relu1_2
or use 8 (relu2_2).
✔ 2. Rebalance the loss weights:
Suggested:
L1_penalty = 50–100
Lconst_penalty = 5
Lperceptual_penalty = 0.02
GAN = 1
Category = 1
✔ 3. If you keep const_loss, avoid a deep perceptual loss
Because these coexist:
- const_loss (UNet bottleneck feature)
- perceptual loss (ImageNet VGG features)
→ when both are present, only low-level VGG layers (relu1_2 or relu2_2) are appropriate.
✔ 4. Fix the mean/std movement in PerceptualLoss.forward:
Move them to the device in __init__ instead:
self.register_buffer('mean', mean.to(device))
self.register_buffer('std', std.to(device))
✔ 5. A more stable G-loss mix:
G_loss =
GAN_loss * 1.0 +
L1_loss * 50 +
const_loss * 5 +
perceptual_loss * 0.02 +
category_loss * 1
🎁 If you would like
I can provide:
✔ a fully corrected PerceptualLoss class
(multi-layer features + font-friendly)
✔ a fully revised Zi2ZiModel (correct loss flow)
✔ a unified set of loss-weight recommendations
✔ a perceptual loss design that does not conflict with the return_feat embedding
OK, here is the integrated, corrected version you asked for: a more robust PerceptualLoss (multi-layer, low-level defaults, device-safe) and the revised key parts of Zi2ZiModel (initialization, forward, optimize_parameters). It also fixes the bugs / bad practices pointed out earlier and adds the usual safety measures (feature pooling, weight suggestions, symmetric category loss). You can paste the snippets below directly into your project, replacing the corresponding classes / methods.
Key changes (summary)
- PerceptualLoss: defaults to the low-level VGG layers [3, 8] (relu1_2, relu2_2), moves the model and the mean/std buffers to the device at initialization, and returns a weighted L1 over the selected feature maps.
- Zi2ZiModel:
  - fixes duplicated / incorrect attribute assignments (e.g. fake_binary_loss assigned twice);
  - builds the perceptual module in setup() and places it on the correct device;
  - in forward(), the const_loss features go through adaptive_avg_pool2d(..., (1, 1)) and are flattened into vectors, which is more stable (avoids spatial mismatch);
  - makes the category loss handling symmetric between D and G;
  - suggests and demonstrates reasonable default weights (tune as needed).
1) PerceptualLoss (full class)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16, VGG16_Weights
class PerceptualLoss(nn.Module):
"""
Perceptual loss using VGG features.
- layers: list of feature indices to extract (indices in torchvision vgg.features)
defaults to [3,8] -> relu1_2, relu2_2 (good for line/texture tasks like fonts)
- weights: list or scalar weight per layer (same length as layers) or single scalar broadcasted
- device: place the vgg and buffers on this device at init
"""
def __init__(self, layers=(3, 8), weights=1.0, device=None):
super().__init__()
# Accept tuple/list or single int
if isinstance(layers, int):
layers = (layers,)
self.layers = tuple(layers)
# weights handling
if isinstance(weights, (int, float)):
self.layer_weights = [float(weights)] * len(self.layers)
else:
assert len(weights) == len(self.layers)
self.layer_weights = [float(w) for w in weights]
# Load pretrained VGG features
vgg = vgg16(weights=VGG16_Weights.DEFAULT)
self.max_idx = max(self.layers)
# we'll iterate features up to max_idx and collect outputs at indices in self.layers
self.vgg_features = vgg.features[: self.max_idx + 1].eval()
# freeze params
for p in self.vgg_features.parameters():
p.requires_grad = False
# device placement
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
self.vgg_features = self.vgg_features.to(self.device)
# ImageNet mean/std (buffers)
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
self.register_buffer("mean", mean.to(self.device))
self.register_buffer("std", std.to(self.device))
self.criterion = nn.L1Loss()
def _extract_features(self, x):
"""
run through vgg_features and collect outputs at indicated indices
returns list of tensors for selected layers
"""
features = []
cur = x
for idx, layer in enumerate(self.vgg_features):
cur = layer(cur)
if idx in self.layers:
                # clone so a later in-place layer (e.g. VGG's inplace ReLU) cannot modify the collected map
                features.append(cur.clone())
return features
def forward(self, input_img, target_img):
"""
input_img, target_img: expected in range [-1, 1] (GAN convention). Channels can be 1 or 3.
Returns weighted sum of L1 distances between selected VGG feature maps.
"""
# move inputs to device
input_img = input_img.to(self.device)
target_img = target_img.to(self.device)
# convert [-1,1] -> [0,1]
input_img = (input_img + 1.0) * 0.5
target_img = (target_img + 1.0) * 0.5
# if grayscale -> replicate to 3 channels
if input_img.shape[1] == 1:
input_img = input_img.repeat(1, 3, 1, 1)
if target_img.shape[1] == 1:
target_img = target_img.repeat(1, 3, 1, 1)
# normalize with ImageNet mean/std (buffers already on device)
input_img = (input_img - self.mean) / self.std
target_img = (target_img - self.mean) / self.std
# Extract features
inp_feats = self._extract_features(input_img)
tgt_feats = self._extract_features(target_img)
assert len(inp_feats) == len(self.layer_weights)
loss = 0.0
for f_inp, f_tgt, w in zip(inp_feats, tgt_feats, self.layer_weights):
loss = loss + w * self.criterion(f_inp, f_tgt)
return loss
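A minimal usage sketch of the class above; the random tensors just stand in for your generator output and target (in training you would pass fake_B and real_B):
# Grayscale images in [-1, 1], as produced by a tanh generator
fake_B = torch.rand(2, 1, 256, 256) * 2 - 1
real_B = torch.rand(2, 1, 256, 256) * 2 - 1

loss_fn = PerceptualLoss(layers=(3, 8), weights=(1.0, 0.5))  # weight relu1_2 more heavily
loss = loss_fn(fake_B, real_B)
print(loss.item())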
2) Zi2ZiModel (key parts: __init__ / setup() / forward() / optimize_parameters())
Replace the corresponding parts of your Zi2ZiModel class with the block below (or merge the key changes as a patch); it contains the fixes above and a safer overall flow.
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
class Zi2ZiModel:
def __init__(self, input_nc=1, embedding_num=40, embedding_dim=128,
ngf=64, ndf=64,
Lconst_penalty=5.0, Lcategory_penalty=1.0, L1_penalty=100.0, Lperceptual_penalty=0.02,
schedule=10, lr=2e-4, gpu_ids=None, save_dir='.', is_training=True,
image_size=256, self_attention=False, d_spectral_norm=False, norm_type="instance"):
# basic attrs
self.gpu_ids = gpu_ids
self.device = torch.device("cuda" if self.gpu_ids and torch.cuda.is_available() else "cpu")
self.use_dropout = bool(is_training)
# loss weights
self.Lconst_penalty = float(Lconst_penalty)
self.Lcategory_penalty = float(Lcategory_penalty)
self.L1_penalty = float(L1_penalty)
self.Lperceptual_penalty = float(Lperceptual_penalty)
self.schedule = schedule
self.save_dir = save_dir
self.input_nc = input_nc
self.embedding_dim = embedding_dim
self.embedding_num = embedding_num
self.ngf = ngf
self.ndf = ndf
self.lr = lr
self.is_training = is_training
self.image_size = image_size
self.self_attention = self_attention
self.d_spectral_norm = d_spectral_norm
self.norm_type = norm_type
# placeholders set in setup
self.netG = None
self.netD = None
self.setup()
def setup(self):
# choose norm
if self.norm_type == 'batch':
norm_layer = nn.BatchNorm2d
elif self.norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
else:
raise NotImplementedError('normalization layer [%s] is not found' % self.norm_type)
# build nets (assumes UNetGenerator and Discriminator are defined and imported)
num_downs = 8 if self.image_size != 384 else 7
self.netG = UNetGenerator(
input_nc=self.input_nc,
output_nc=self.input_nc,
embedding_num=self.embedding_num,
embedding_dim=self.embedding_dim,
ngf=self.ngf,
norm_layer=norm_layer,
use_dropout=self.use_dropout,
num_downs=num_downs,
self_attention=self.self_attention
).to(self.device)
self.netD = Discriminator(
input_nc=2 * self.input_nc,
embedding_num=self.embedding_num,
ndf=self.ndf,
norm_layer=norm_layer,
image_size=self.image_size,
use_spectral_norm=self.d_spectral_norm
).to(self.device)
init_net(self.netG, gpu_ids=self.gpu_ids)
init_net(self.netD, gpu_ids=self.gpu_ids)
# optimizers (lower lr recommended from 1e-3 -> 2e-4)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.lr, betas=(0.5, 0.999))
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.lr, betas=(0.5, 0.999))
# losses
self.category_loss = CategoryLoss(self.embedding_num).to(self.device)
self.real_binary_loss = BinaryLoss(True).to(self.device)
self.fake_binary_loss = BinaryLoss(False).to(self.device)
self.l1_loss = nn.L1Loss().to(self.device)
# perceptual loss (use low-level default layers [3,8])
self.perceptual_loss = PerceptualLoss(layers=(3,8), weights=1.0, device=self.device)
self.mse = nn.MSELoss().to(self.device)
# training mode
if self.is_training:
self.netD.train()
self.netG.train()
else:
self.netD.eval()
self.netG.eval()
def set_input(self, data):
self.model_input_data = data
self.labels = data['label'].to(self.device)
self.real_A = data['A'].to(self.device)
self.real_B = data['B'].to(self.device)
def forward(self):
"""
Do single forward pass:
- produce fake_B, fake_B_emb from netG(A)
- produce real_B_emb from netG(real_B)
- compute l1, const, perceptual (perceptual only computed if Lperceptual_penalty > 0)
Note: const loss uses adaptive avg pooling to convert feature maps to vectors
"""
# produce fake and fake embedding in one forward
self.fake_B, fake_B_emb = self.netG(self.real_A, self.labels, return_feat=True)
# produce real embedding (one forward on real_B)
_, real_B_emb = self.netG(self.real_B, self.labels, return_feat=True)
# L1 reconstruction (pixel)
self.loss_l1 = self.l1_loss(self.fake_B, self.real_B)
# const loss: pool to vectors to avoid spatial mismatch
# adaptive avg pool -> (B, C, 1, 1) -> flatten
if fake_B_emb is None or real_B_emb is None:
# fallback (should not happen if netG return_feat True)
self.loss_const = torch.tensor(0.0, device=self.device)
else:
f_vec = F.adaptive_avg_pool2d(fake_B_emb, (1,1)).view(fake_B_emb.size(0), -1)
r_vec = F.adaptive_avg_pool2d(real_B_emb, (1,1)).view(real_B_emb.size(0), -1)
self.loss_const = self.mse(f_vec, r_vec) # MSE for embedding constancy
# perceptual loss (on pixels) using low-level VGG features
if self.Lperceptual_penalty > 0.0:
self.loss_perceptual = self.perceptual_loss(self.fake_B, self.real_B)
else:
self.loss_perceptual = torch.tensor(0.0, device=self.device)
# store for optimize step
self.fake_B_emb = fake_B_emb
self.real_B_emb = real_B_emb
return
def optimize_parameters(self):
# forward first
self.forward()
real_A = self.real_A
real_B = self.real_B
fake_B = self.fake_B
labels = self.labels
fake_AB = torch.cat([real_A, fake_B], dim=1)
real_AB = torch.cat([real_A, real_B], dim=1)
# ---------- Update D ----------
self.set_requires_grad(self.netD, True)
self.optimizer_D.zero_grad(set_to_none=True)
pred_fake_d, fake_category_logits_d = self.netD(fake_AB.detach())
pred_real, real_category_logits = self.netD(real_AB)
loss_D_real = self.real_binary_loss(pred_real)
loss_D_fake = self.fake_binary_loss(pred_fake_d)
# category loss symmetric (no extra half-scaling)
category_loss_real = self.category_loss(real_category_logits, labels)
category_loss_fake = self.category_loss(fake_category_logits_d, labels)
self.category_loss_D = (category_loss_real + category_loss_fake) * self.Lcategory_penalty
self.d_loss = 0.5 * (loss_D_real + loss_D_fake) + self.category_loss_D
self.d_loss.backward()
self.optimizer_D.step()
# ---------- Update G ----------
self.set_requires_grad(self.netD, False)
self.optimizer_G.zero_grad(set_to_none=True)
pred_fake, fake_category_logits = self.netD(fake_AB) # evaluation of fake with current D
        # GAN loss for G: non-saturating form, i.e. score the fake against the "real" target (real_binary_loss)
self.loss_G_GAN = self.real_binary_loss(pred_fake)
fake_category_loss_G = self.category_loss(fake_category_logits, labels) * self.Lcategory_penalty
# compose total G loss with suggested scaling
self.g_loss = (
self.loss_G_GAN * 1.0 +
self.loss_l1 * (self.L1_penalty) +
self.loss_const * (self.Lconst_penalty) +
self.loss_perceptual * (self.Lperceptual_penalty) +
fake_category_loss_G
)
self.g_loss.backward()
# optional grad clipping (uncomment to use)
# torch.nn.utils.clip_grad_norm_(self.netG.parameters(), 1.0)
self.optimizer_G.step()
# Return logging scalars (as floats)
return {
"d_loss": self.d_loss.detach().item(),
"g_loss": self.g_loss.detach().item(),
"loss_l1": self.loss_l1.detach().item(),
"loss_const": self.loss_const.detach().item() if isinstance(self.loss_const, torch.Tensor) else float(self.loss_const),
"loss_perceptual": self.loss_perceptual.detach().item() if isinstance(self.loss_perceptual, torch.Tensor) else 0.0,
"loss_G_GAN": self.loss_G_GAN.detach().item()
}
def set_requires_grad(self, nets, requires_grad=False):
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
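For reference, a rough training-loop sketch showing how the class is meant to be driven; it assumes your project provides UNetGenerator, Discriminator, CategoryLoss, BinaryLoss and init_net, and that dataloader yields dicts with keys 'A', 'B', 'label' (dataloader and num_epochs are placeholders here):
model = Zi2ZiModel(input_nc=1, image_size=256, gpu_ids=[0],
                   L1_penalty=100.0, Lconst_penalty=5.0,
                   Lperceptual_penalty=0.02, Lcategory_penalty=1.0, lr=2e-4)

for epoch in range(num_epochs):                 # num_epochs: placeholder
    for batch in dataloader:                    # dataloader: your font dataset loader
        model.set_input(batch)                  # batch = {'A': ..., 'B': ..., 'label': ...}
        logs = model.optimize_parameters()      # one D step + one G step, returns loss scalars
    print(epoch, logs)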
3) Suggested starting hyperparameters
- lr = 2e-4 (more stable than 1e-3)
- L1_penalty = 100 (if you really need pixel fidelity; otherwise start at 50)
- Lconst_penalty = 5.0 (avoids inflating the whole g_loss)
- Lperceptual_penalty = 0.01 ~ 0.05 (with low-level VGG layers)
- Lcategory_penalty = 1.0
When experimenting, change one thing at a time (e.g. first disable the perceptual loss to get a baseline, then enable it at 0.02).
4) Why the embedding const loss uses MSE + adaptive_avg_pool2d
- Your deepest_feat & second_feat may differ in spatial size (you ran into a concat error before). Applying adaptive_avg_pool2d(..., (1, 1)) yields a stable (B, C, 1, 1) tensor, which is then flattened into a vector; this is more robust and cannot fail on an H/W mismatch (a small demo follows after this list).
- Using MSE on the embedding vectors is reasonable (L1 also works), but MSE is the more common choice for feature similarity (switch back to L1 if you prefer).
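Small demo of why the pooling step is robust: even if two feature maps had different spatial sizes (the shapes below are hypothetical), pooling to (1, 1) and flattening yields comparable vectors:
import torch
import torch.nn.functional as F

# Hypothetical embeddings with mismatched H/W (2x2 vs 1x1)
fake_emb = torch.randn(4, 512, 2, 2)
real_emb = torch.randn(4, 512, 1, 1)

f_vec = F.adaptive_avg_pool2d(fake_emb, (1, 1)).view(fake_emb.size(0), -1)  # (4, 512)
r_vec = F.adaptive_avg_pool2d(real_emb, (1, 1)).view(real_emb.size(0), -1)  # (4, 512)
print(F.mse_loss(f_vec, r_vec).item())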
5) Other minor suggestions (optional)
- If D becomes too strong, add spectral_norm to D (you already support d_spectral_norm), or adjust the D-vs-G update ratio (e.g. update D 1 or 2 times before each G update).
- If the perceptual loss is too slow, you can compute it only every N steps (not recommended as a permanent setting, only for debugging).
- If you want to keep a high-level VGG layer (relu3_3) as a supplement, give it a very small weight (e.g. 0.001) and first verify that it actually improves the output.
- During the first few epochs, slightly lower the L1 and const loss weights so the GAN can first learn the high-level distribution, then gradually raise the pixel terms; this sometimes helps quality (see the small ramp sketch below).
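A minimal sketch of such a ramp; warmup_epochs and start_frac are assumed knobs, not part of your current code:
def ramped_weight(base_weight, epoch, warmup_epochs=5, start_frac=0.3):
    """Linearly ramp a loss weight from start_frac * base_weight up to base_weight."""
    if epoch >= warmup_epochs:
        return base_weight
    frac = start_frac + (1.0 - start_frac) * (epoch / warmup_epochs)
    return base_weight * frac

# Per epoch, before training, e.g.:
#   model.L1_penalty = ramped_weight(100.0, epoch)
#   model.Lconst_penalty = ramped_weight(5.0, epoch)
print([round(ramped_weight(100.0, e), 1) for e in range(7)])  # [30.0, 44.0, 58.0, 72.0, 86.0, 100.0, 100.0]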