是的,你可以透過 增加 category_loss
和 binary_loss
來改進 Discriminator
的設計,並直接從 Discriminator
取得 g_loss
來訓練 Generator
,這樣 G
的學習方向會更清楚,使生成的字體更符合 Zen Maru Gothic
風格。
📌 方法概述
目前你的 Generator
主要依賴:
L1 loss
(像素級對齊)Feature Matching loss
(透過D
提取特徵,使G
生成更接近真實字體)
這些方法雖然有效,但 它們沒有提供強烈的語義約束,例如:
- 生成的字體應該與
Zen Maru Gothic
在字形結構上相似 (而非單純 L1 距離小) Discriminator
的分類能力未充分發揮,只用來區分真偽 (real/fake)
解決方案:
- 加入
category_loss
(分類損失)- 讓
D
學習區分不同的字體風格 (Noto Sans
vsZen Maru Gothic
) G
需要生成更符合Zen Maru Gothic
類別的字體
- 讓
- 加入
binary_loss
(對抗損失)G
需要騙過D
,讓D
判斷生成的字體是真實的Zen Maru Gothic
- 加入
const_loss
(內容一致性損失)- 讓
G
生成的特徵與Noto Sans
保持某種對應關係,防止過度偏移
- 讓
- 最終
g_loss
來自Discriminator
🔧 具體實作
🟢 修改 Discriminator
讓 Discriminator
除了輸出真實/假 (binary classification),還額外輸出類別分類結果:
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.conv_layers = nn.Sequential(
nn.Conv2d(1, 64, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, 4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
)
self.binary_head = nn.Conv2d(256, 1, kernel_size=4) # 判斷真假
self.category_head = nn.Conv2d(256, 2, kernel_size=4) # 判斷字體類別 (Noto Sans vs Zen Maru Gothic)
def forward(self, x):
features = self.conv_layers(x)
binary_output = torch.sigmoid(self.binary_head(features)) # 判斷真/假
category_output = torch.softmax(self.category_head(features), dim=1) # 分類
return binary_output, category_output, features
🟢 修改 Generator
訓練 (g_loss
來自 D
)
# 訓練 Generator
optimizer_G.zero_grad()
output = generator(noto)
# 取得 D 的判斷
fake_binary_logits, fake_category_logits, encoded_fake_B = discriminator(output)
real_binary_logits, real_category_logits, encoded_real_A = discriminator(zenmaru)
# 內容一致性損失 (const loss)
const_loss = 10.0 * criterion_mse(encoded_real_A, encoded_fake_B)
# L1 損失 (影像層級)
l1_loss = 100.0 * criterion_l1(output, zenmaru)
# Category loss (希望 G 生成的字體能接近 Zen Maru Gothic)
fake_category_loss = 5.0 * criterion_ce(fake_category_logits, torch.ones_like(fake_category_logits[:, 0], dtype=torch.long))
# Binary loss (希望 G 騙過 D)
cheat_loss = criterion_bce(fake_binary_logits, torch.ones_like(fake_binary_logits))
# 最終 Generator loss
g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss
g_loss.backward()
optimizer_G.step()
🔹 超參數權重說明:
10.0 * const_loss
:保證特徵相似,但不過度約束100.0 * l1_loss
:確保生成結果與Zen Maru Gothic
相似5.0 * category_loss
:引導G
生成更符合Zen Maru Gothic
類別cheat_loss
:讓G
騙過D
🟢 修改 Discriminator
訓練
# 訓練 Discriminator
optimizer_D.zero_grad()
real_binary_logits, real_category_logits, _ = discriminator(zenmaru)
fake_binary_logits, fake_category_logits, _ = discriminator(output.detach())
# Binary loss (D 判斷真假)
real_loss = criterion_bce(real_binary_logits, torch.ones_like(real_binary_logits))
fake_loss = criterion_bce(fake_binary_logits, torch.zeros_like(fake_binary_logits))
binary_loss = (real_loss + fake_loss) / 2
# Category loss (D 判斷字體類別)
real_category_loss = criterion_ce(real_category_logits, torch.ones_like(real_category_logits[:, 0], dtype=torch.long))
fake_category_loss = criterion_ce(fake_category_logits, torch.zeros_like(fake_category_logits[:, 0], dtype=torch.long))
category_loss = (real_category_loss + fake_category_loss) / 2
# 總損失
d_loss = binary_loss + category_loss
d_loss.backward()
optimizer_D.step()
📌 這樣修改後的好處
✅ 更清楚的學習方向
G
不只是學 L1 距離,還需要讓D
認為它生成的字體是Zen Maru Gothic
D
需要學會區分Noto Sans
和Zen Maru Gothic
,進一步提升G
的生成質量
✅ 增加穩定性,避免梯度爆炸
const_loss
確保G
生成的字體不會偏移太遠category_loss
確保G
的結果保持Zen Maru Gothic
風格,而不是生成隨機結果
✅ Discriminator 的作用更強
D
現在不只是判斷真假,還學會分類不同字體G
不能只靠欺騙D
,還要符合Zen Maru Gothic
的風格
🔍 總結
✔ 在 Discriminator
增加 category_loss
(分類)與 binary_loss
(對抗)
✔ 讓 G
的損失來自 D
,而不只是 L1 Loss
✔ 使用 const_loss
確保字形特徵一致
✔ 解決 resume
訓練後效果變差的問題
這樣的 Feature Matching + Category Loss + Binary Loss
設計,能讓 G
產生更符合 Zen Maru Gothic
風格的字體 🎯🚀
trainer.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from dataset import DatasetFromObj # 確保有這個 dataset loader
from model import UNetGenerator, Discriminator # 確保有這些模型
def load_checkpoint(device, generator, discriminator, checkpoint_dir, resume):
""" 載入 checkpoint """
g_path = os.path.join(checkpoint_dir, f"{resume}_net_G.pth")
d_path = os.path.join(checkpoint_dir, f"{resume}_net_D.pth")
if os.path.exists(g_path):
generator.load_state_dict(torch.load(g_path, map_location=device))
print(f"✅ Generator({resume}) 模型已載入")
if os.path.exists(d_path):
discriminator.load_state_dict(torch.load(d_path, map_location=device))
print(f"✅ Discriminator({resume}) 模型已載入")
def train(args):
# 設定 device
device = torch.device("cuda" if args.gpu_ids and torch.cuda.is_available() else "cpu")
# 影像轉換
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
# 載入資料集
train_dataset = DatasetFromObj(os.path.join(args.data_dir, 'train.obj'), transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
# 初始化 Generator 和 Discriminator
generator = UNetGenerator().to(device)
discriminator = Discriminator().to(device)
# 續訓
if args.resume:
load_checkpoint(device, generator, discriminator, args.checkpoint_dir, args.resume)
# 損失函數
criterion_l1 = nn.L1Loss()
criterion_mse = nn.MSELoss()
criterion_bce = nn.BCELoss()
criterion_ce = nn.CrossEntropyLoss()
# 優化器
optimizer_G = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
# 訓練迴圈
for epoch in range(args.epoch):
for bid, (_, zenmaru, noto) in enumerate(train_dataloader):
noto, zenmaru = noto.to(device), zenmaru.to(device)
# 訓練 Generator
optimizer_G.zero_grad()
output = generator(noto)
# 取得 D 的判斷
fake_binary_logits, fake_category_logits, encoded_fake_B = discriminator(output)
real_binary_logits, real_category_logits, encoded_real_A = discriminator(zenmaru)
# Content consistency loss
const_loss = 10.0 * criterion_mse(encoded_real_A, encoded_fake_B)
# L1 損失 (影像層級)
l1_loss = 100.0 * criterion_l1(output, zenmaru)
# Category loss (希望 G 生成 Zen Maru Gothic)
fake_category_loss = 5.0 * criterion_ce(fake_category_logits, torch.ones_like(fake_category_logits[:, 0], dtype=torch.long))
# Binary loss (G 希望騙過 D)
cheat_loss = criterion_bce(fake_binary_logits, torch.ones_like(fake_binary_logits))
# 最終 Generator loss
g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss
g_loss.backward()
optimizer_G.step()
# 訓練 Discriminator
optimizer_D.zero_grad()
real_binary_logits, real_category_logits, _ = discriminator(zenmaru)
fake_binary_logits, fake_category_logits, _ = discriminator(output.detach())
# Binary loss (D 判斷真假)
real_loss = criterion_bce(real_binary_logits, torch.ones_like(real_binary_logits))
fake_loss = criterion_bce(fake_binary_logits, torch.zeros_like(fake_binary_logits))
binary_loss = (real_loss + fake_loss) / 2
# Category loss (D 判斷字體類別)
real_category_loss = criterion_ce(real_category_logits, torch.ones_like(real_category_logits[:, 0], dtype=torch.long))
fake_category_loss = criterion_ce(fake_category_logits, torch.zeros_like(fake_category_logits[:, 0], dtype=torch.long))
category_loss = (real_category_loss + fake_category_loss) / 2
# 總 Discriminator loss
d_loss = binary_loss + category_loss
d_loss.backward()
optimizer_D.step()
print(f"Epoch [{epoch+1}/{args.epoch}] G_loss: {g_loss.item():.4f} D_loss: {d_loss.item():.4f}")
# 儲存模型
torch.save(generator.state_dict(), os.path.join(args.checkpoint_dir, "final_net_G.pth"))
torch.save(discriminator.state_dict(), os.path.join(args.checkpoint_dir, "final_net_D.pth"))
print("✅ 訓練完成,模型已儲存!")