在 Discriminator 增加 category_loss(分類)與 binary_loss(對抗)

是的,你可以透過 增加 category_lossbinary_loss 來改進 Discriminator 的設計,並直接從 Discriminator 取得 g_loss 來訓練 Generator,這樣 G 的學習方向會更清楚,使生成的字體更符合 Zen Maru Gothic 風格。


📌 方法概述

目前你的 Generator 主要依賴:

  1. L1 loss (像素級對齊)
  2. Feature Matching loss (透過 D 提取特徵,使 G 生成更接近真實字體)

這些方法雖然有效,但 它們沒有提供強烈的語義約束,例如:

  • 生成的字體應該與 Zen Maru Gothic字形結構上相似 (而非單純 L1 距離小)
  • Discriminator 的分類能力未充分發揮,只用來區分真偽 (real/fake)

解決方案

  1. 加入 category_loss(分類損失)
    • D 學習區分不同的字體風格 (Noto Sans vs Zen Maru Gothic)
    • G 需要生成更符合 Zen Maru Gothic 類別的字體
  2. 加入 binary_loss(對抗損失)
    • G 需要騙過 D,讓 D 判斷生成的字體是真實的 Zen Maru Gothic
  3. 加入 const_loss(內容一致性損失)
    • G 生成的特徵與 Noto Sans 保持某種對應關係,防止過度偏移
  4. 最終 g_loss 來自 Discriminator

🔧 具體實作

🟢 修改 Discriminator

Discriminator 除了輸出真實/假 (binary classification),還額外輸出類別分類結果:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.binary_head = nn.Conv2d(256, 1, kernel_size=4)  # 判斷真假
        self.category_head = nn.Conv2d(256, 2, kernel_size=4)  # 判斷字體類別 (Noto Sans vs Zen Maru Gothic)

    def forward(self, x):
        features = self.conv_layers(x)
        binary_output = torch.sigmoid(self.binary_head(features))  # 判斷真/假
        category_output = torch.softmax(self.category_head(features), dim=1)  # 分類
        return binary_output, category_output, features

🟢 修改 Generator 訓練 (g_loss 來自 D)

# 訓練 Generator
optimizer_G.zero_grad()
output = generator(noto)

# 取得 D 的判斷
fake_binary_logits, fake_category_logits, encoded_fake_B = discriminator(output)
real_binary_logits, real_category_logits, encoded_real_A = discriminator(zenmaru)

# 內容一致性損失 (const loss)
const_loss = 10.0 * criterion_mse(encoded_real_A, encoded_fake_B)

# L1 損失 (影像層級)
l1_loss = 100.0 * criterion_l1(output, zenmaru)

# Category loss (希望 G 生成的字體能接近 Zen Maru Gothic)
fake_category_loss = 5.0 * criterion_ce(fake_category_logits, torch.ones_like(fake_category_logits[:, 0], dtype=torch.long))

# Binary loss (希望 G 騙過 D)
cheat_loss = criterion_bce(fake_binary_logits, torch.ones_like(fake_binary_logits))

# 最終 Generator loss
g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss
g_loss.backward()
optimizer_G.step()

🔹 超參數權重說明

  • 10.0 * const_loss :保證特徵相似,但不過度約束
  • 100.0 * l1_loss :確保生成結果與 Zen Maru Gothic 相似
  • 5.0 * category_loss :引導 G 生成更符合 Zen Maru Gothic 類別
  • cheat_loss :讓 G 騙過 D

🟢 修改 Discriminator 訓練

# 訓練 Discriminator
optimizer_D.zero_grad()

real_binary_logits, real_category_logits, _ = discriminator(zenmaru)
fake_binary_logits, fake_category_logits, _ = discriminator(output.detach())

# Binary loss (D 判斷真假)
real_loss = criterion_bce(real_binary_logits, torch.ones_like(real_binary_logits))
fake_loss = criterion_bce(fake_binary_logits, torch.zeros_like(fake_binary_logits))
binary_loss = (real_loss + fake_loss) / 2

# Category loss (D 判斷字體類別)
real_category_loss = criterion_ce(real_category_logits, torch.ones_like(real_category_logits[:, 0], dtype=torch.long))
fake_category_loss = criterion_ce(fake_category_logits, torch.zeros_like(fake_category_logits[:, 0], dtype=torch.long))

category_loss = (real_category_loss + fake_category_loss) / 2

# 總損失
d_loss = binary_loss + category_loss
d_loss.backward()
optimizer_D.step()

📌 這樣修改後的好處

更清楚的學習方向

  • G 不只是學 L1 距離,還需要讓 D 認為它生成的字體是 Zen Maru Gothic
  • D 需要學會區分 Noto SansZen Maru Gothic,進一步提升 G 的生成質量

增加穩定性,避免梯度爆炸

  • const_loss 確保 G 生成的字體不會偏移太遠
  • category_loss 確保 G 的結果保持 Zen Maru Gothic 風格,而不是生成隨機結果

Discriminator 的作用更強

  • D 現在不只是判斷真假,還學會分類不同字體
  • G 不能只靠欺騙 D,還要符合 Zen Maru Gothic 的風格

🔍 總結

Discriminator 增加 category_loss(分類)與 binary_loss(對抗)
G 的損失來自 D,而不只是 L1 Loss
使用 const_loss 確保字形特徵一致
解決 resume 訓練後效果變差的問題

這樣的 Feature Matching + Category Loss + Binary Loss 設計,能讓 G 產生更符合 Zen Maru Gothic 風格的字體 🎯🚀

trainer.py

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from dataset import DatasetFromObj  # 確保有這個 dataset loader
from model import UNetGenerator, Discriminator  # 確保有這些模型

def load_checkpoint(device, generator, discriminator, checkpoint_dir, resume):
    """ 載入 checkpoint """
    g_path = os.path.join(checkpoint_dir, f"{resume}_net_G.pth")
    d_path = os.path.join(checkpoint_dir, f"{resume}_net_D.pth")
    if os.path.exists(g_path):
        generator.load_state_dict(torch.load(g_path, map_location=device))
        print(f"✅ Generator({resume}) 模型已載入")
    if os.path.exists(d_path):
        discriminator.load_state_dict(torch.load(d_path, map_location=device))
        print(f"✅ Discriminator({resume}) 模型已載入")

def train(args):
    # 設定 device
    device = torch.device("cuda" if args.gpu_ids and torch.cuda.is_available() else "cpu")
    
    # 影像轉換
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    # 載入資料集
    train_dataset = DatasetFromObj(os.path.join(args.data_dir, 'train.obj'), transform=transform)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    # 初始化 Generator 和 Discriminator
    generator = UNetGenerator().to(device)
    discriminator = Discriminator().to(device)
    
    # 續訓
    if args.resume:
        load_checkpoint(device, generator, discriminator, args.checkpoint_dir, args.resume)

    # 損失函數
    criterion_l1 = nn.L1Loss()
    criterion_mse = nn.MSELoss()
    criterion_bce = nn.BCELoss()
    criterion_ce = nn.CrossEntropyLoss()

    # 優化器
    optimizer_G = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    # 訓練迴圈
    for epoch in range(args.epoch):
        for bid, (_, zenmaru, noto) in enumerate(train_dataloader):
            noto, zenmaru = noto.to(device), zenmaru.to(device)

            # 訓練 Generator
            optimizer_G.zero_grad()
            output = generator(noto)

            # 取得 D 的判斷
            fake_binary_logits, fake_category_logits, encoded_fake_B = discriminator(output)
            real_binary_logits, real_category_logits, encoded_real_A = discriminator(zenmaru)

            # Content consistency loss
            const_loss = 10.0 * criterion_mse(encoded_real_A, encoded_fake_B)

            # L1 損失 (影像層級)
            l1_loss = 100.0 * criterion_l1(output, zenmaru)

            # Category loss (希望 G 生成 Zen Maru Gothic)
            fake_category_loss = 5.0 * criterion_ce(fake_category_logits, torch.ones_like(fake_category_logits[:, 0], dtype=torch.long))

            # Binary loss (G 希望騙過 D)
            cheat_loss = criterion_bce(fake_binary_logits, torch.ones_like(fake_binary_logits))

            # 最終 Generator loss
            g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss
            g_loss.backward()
            optimizer_G.step()

            # 訓練 Discriminator
            optimizer_D.zero_grad()
            real_binary_logits, real_category_logits, _ = discriminator(zenmaru)
            fake_binary_logits, fake_category_logits, _ = discriminator(output.detach())

            # Binary loss (D 判斷真假)
            real_loss = criterion_bce(real_binary_logits, torch.ones_like(real_binary_logits))
            fake_loss = criterion_bce(fake_binary_logits, torch.zeros_like(fake_binary_logits))
            binary_loss = (real_loss + fake_loss) / 2

            # Category loss (D 判斷字體類別)
            real_category_loss = criterion_ce(real_category_logits, torch.ones_like(real_category_logits[:, 0], dtype=torch.long))
            fake_category_loss = criterion_ce(fake_category_logits, torch.zeros_like(fake_category_logits[:, 0], dtype=torch.long))
            category_loss = (real_category_loss + fake_category_loss) / 2

            # 總 Discriminator loss
            d_loss = binary_loss + category_loss
            d_loss.backward()
            optimizer_D.step()

        print(f"Epoch [{epoch+1}/{args.epoch}] G_loss: {g_loss.item():.4f} D_loss: {d_loss.item():.4f}")

    # 儲存模型
    torch.save(generator.state_dict(), os.path.join(args.checkpoint_dir, "final_net_G.pth"))
    torch.save(discriminator.state_dict(), os.path.join(args.checkpoint_dir, "final_net_D.pth"))
    print("✅ 訓練完成,模型已儲存!")

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *