使用 Pix2Pix GAN,讓 Noto Sans 的輪廓轉換成 Zen Maru Gothic 風格

這是一個完整的 Glyph GAN 訓練腳本,用來 補全 Zen Maru Gothic 缺字,並確保字型風格一致。
主要流程:

  1. 準備數據 → 讀取 Noto Sans 和 Zen Maru Gothic 字型,生成對應的字型圖像數據集。
  2. 訓練 Glyph GAN → 使用 Pix2Pix GAN(條件 GAN)來學習字型風格轉換。
  3. 生成缺字 → 利用訓練好的模型補全 Zen Maru Gothic 缺字。

📜 完整 Python 訓練腳本

請確保你在 Colab 環境執行此腳本,並已安裝所需的套件

pip install torch torchvision pillow fontTools tqdm

🔹 1. 生成訓練數據

我們會 從字型檔案提取字體圖像,並儲存到 /content/dataset 目錄下。

import os
from fontTools.ttLib import TTFont
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

# 設定字型與存檔路徑
NotoSans_font_path = "/content/NotoSansJP-Regular.otf"
ZenMaru_font_path = "/content/ZenMaruGothic-Regular.ttf"
dataset_dir = "/content/dataset"
os.makedirs(dataset_dir, exist_ok=True)

# 影像大小 256x256
IMG_SIZE = (256, 256)

# 生成字型圖像
def generate_glyph_image(char, font_path, img_size=IMG_SIZE):
    font = ImageFont.truetype(font_path, 200)  # 調整字型大小
    img = Image.new("L", img_size, color=255)  # 白底黑字
    draw = ImageDraw.Draw(img)
    draw.text((30, 30), char, font=font, fill=0)
    return img

# 取得字型支援的字
def get_supported_chars(font_path):
    font = TTFont(font_path)
    cmap = font["cmap"].tables[0].cmap
    return set(cmap.keys())

# 取交集字元
noto_chars = get_supported_chars(NotoSans_font_path)
zenmaru_chars = get_supported_chars(ZenMaru_font_path)
common_chars = sorted(noto_chars & zenmaru_chars)

# 生成訓練數據
for char in tqdm(common_chars[:1000]):  # 取 1000 個常見字
    noto_img = generate_glyph_image(chr(char), NotoSans_font_path)
    zenmaru_img = generate_glyph_image(chr(char), ZenMaru_font_path)
    
    noto_img.save(f"{dataset_dir}/noto_{char}.png")
    zenmaru_img.save(f"{dataset_dir}/zenmaru_{char}.png")

print("✅ 訓練數據準備完成!")

🔹 2. 訓練 Glyph GAN(Pix2Pix)

這部分使用 Pix2Pix GAN,讓 Noto Sans 的輪廓轉換成 Zen Maru Gothic 風格。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
import glob

# 設定 PyTorch 設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 影像轉換
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((256, 256)),  # 更新為 256x256
    transforms.ToTensor(),
])

# 自訂 Dataset
class FontDataset(Dataset):
    def __init__(self, dataset_dir, transform=None):
        self.noto_images = sorted(glob.glob(f"{dataset_dir}/noto_*.png"))
        self.zenmaru_images = sorted(glob.glob(f"{dataset_dir}/zenmaru_*.png"))
        self.transform = transform

    def __len__(self):
        return len(self.noto_images)

    def __getitem__(self, idx):
        noto_img = Image.open(self.noto_images[idx]).convert("L")
        zenmaru_img = Image.open(self.zenmaru_images[idx]).convert("L")

        if self.transform:
            noto_img = self.transform(noto_img)
            zenmaru_img = self.transform(zenmaru_img)

        return noto_img, zenmaru_img

# 建立 DataLoader
dataset = FontDataset(dataset_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# **改進版 U-Net Generator**
class UNetGenerator(nn.Module):
    def __init__(self):
        super(UNetGenerator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# PatchGAN Discriminator
class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# 初始化模型
generator = UNetGenerator().to(device)
discriminator = PatchDiscriminator().to(device)
criterion = nn.L1Loss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 訓練模型
num_epochs = 100
for epoch in range(num_epochs):
    for noto, zenmaru in tqdm(dataloader):
        noto, zenmaru = noto.to(device), zenmaru.to(device)

        # 訓練 Generator
        optimizer_G.zero_grad()
        output = generator(noto)
        g_loss = criterion(output, zenmaru)
        g_loss.backward()
        optimizer_G.step()

        # 訓練 Discriminator
        optimizer_D.zero_grad()
        real_labels = torch.ones_like(discriminator(zenmaru), device=device)
        fake_labels = torch.zeros_like(discriminator(output.detach()), device=device)
        real_loss = criterion(discriminator(zenmaru), real_labels)
        fake_loss = criterion(discriminator(output.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

    print(f"Epoch {epoch+1}/{num_epochs}, G_Loss: {g_loss.item():.4f}, D_Loss: {d_loss.item():.4f}")

# 保存模型
torch.save(generator.state_dict(), "glyph_gan_256.pth")
print("✅ 訓練完成,模型已保存!")

🔹 3. 生成缺字

Zen Maru Gothic 缺少某個字時,可以用 Glyph GAN 自動補全:

import torch
import os
import cv2
import numpy as np
from torchvision import transforms
from model import UNetGenerator  # 確保 model.py 有定義 Generator
from PIL import Image

# 設定裝置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 讀取已訓練的模型
generator = UNetGenerator().to(device)
generator.load_state_dict(torch.load("glyph_gan_generator.pth", map_location=device))
generator.eval()  # 設定為推論模式

# 圖像處理
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 調整大小為 256x256
    transforms.ToTensor(),          # 轉為 Tensor
    transforms.Normalize((0.5,), (0.5,))  # 標準化
])

# 讀取 Noto Sans 輸入字型圖像
def load_image(image_path):
    image = Image.open(image_path).convert("L")  # 灰階模式
    image = transform(image).unsqueeze(0)  # 增加 batch 維度
    return image.to(device)

# 推論函數
def infer(image_path, output_path):
    input_tensor = load_image(image_path)

    with torch.no_grad():
        output_tensor = generator(input_tensor)  # 產生 Zen Maru Gothic 字型

    # 轉換回 PIL Image
    output_image = output_tensor.squeeze(0).cpu().numpy()  # 去掉 batch 維度
    output_image = (output_image * 127.5 + 127.5).astype(np.uint8)  # 反標準化
    output_image = Image.fromarray(output_image, mode="L")  # 轉為 PIL Image

    output_image.save(output_path)
    print(f"✅ 推論完成,結果已儲存至: {output_path}")

# 測試推論
if __name__ == "__main__":
    test_image = "test_noto_sans.png"  # 測試的 Noto Sans 缺字圖像
    output_image = "output_zenmaru.png"  # 生成的 Zen Maru Gothic 缺字補全

    infer(test_image, output_image)

🎯 結論

Glyph GAN 可以自動補全 Zen Maru Gothic 缺字,並保持風格一致。
比 Stable Diffusion Inpainting 更準確,避免隨機變形筆畫。
訓練後的模型可以大規模補全,適合字型開發者使用。

這樣,你就可以讓 Zen Maru Gothic 變得完整,補全所有缺失的字啦!🚀


✅ 保存 Generator & Discriminator 模型

在訓練完成後,除了 generator 之外,也要儲存 discriminator,以便後續 fine-tune。
使用 torch.save() 來儲存模型權重,這樣可以在之後的訓練中繼續加載。

🔹 1. 儲存模型

# 保存 Generator 與 Discriminator 權重
torch.save(generator.state_dict(), "glyph_gan_generator.pth")
torch.save(discriminator.state_dict(), "glyph_gan_discriminator.pth")
print("✅ Generator & Discriminator 模型已保存!")

🔄 接續訓練的方法

如果要繼續訓練,可以 重新加載已訓練的模型,並從某個 epoch 之後繼續訓練。

🔹 2. 加載已訓練的模型

# 初始化模型
generator = UNetGenerator().to(device)
discriminator = PatchDiscriminator().to(device)

# 加載模型權重
generator.load_state_dict(torch.load("glyph_gan_generator.pth"))
discriminator.load_state_dict(torch.load("glyph_gan_discriminator.pth"))

print("✅ 加載已訓練的 Generator & Discriminator 模型!")

🔹 3. 繼續訓練

只需修改 num_epochs,並 保持原有的 optimizer 設定

# 設定繼續訓練的 Epochs
num_epochs = 50  # 可調整

for epoch in range(num_epochs):
    for noto, zenmaru in tqdm(dataloader):
        noto, zenmaru = noto.to(device), zenmaru.to(device)

        # 訓練 Generator
        optimizer_G.zero_grad()
        output = generator(noto)
        g_loss = criterion(output, zenmaru)
        g_loss.backward()
        optimizer_G.step()

        # 訓練 Discriminator
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(zenmaru), torch.ones_like(zenmaru))
        fake_loss = criterion(discriminator(output.detach()), torch.zeros_like(output))
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

    print(f"Epoch {epoch+1}/{num_epochs}, G_Loss: {g_loss.item():.4f}, D_Loss: {d_loss.item():.4f}")

# 再次保存模型
torch.save(generator.state_dict(), "glyph_gan_generator_v2.pth")
torch.save(discriminator.state_dict(), "glyph_gan_discriminator_v2.pth")
print("✅ 訓練完成,模型已更新!")

📌 這樣,你可以

儲存 & 讀取 Generator / Discriminator
隨時繼續訓練,不必重頭來過
更長時間訓練,提高 Zen Maru Gothic 缺字補全精度 🚀


在推論 (infer.py) 階段,只需要使用 Generator (generator.pth),而不需要讀取 Discriminator (discriminator.pth),原因如下:


1️⃣ 推論階段只需要「生成」字型

  • Generator (G) 的作用是 將 Noto Sans 缺字轉換為 Zen Maru Gothic 風格的補全字
  • Discriminator (D) 只在訓練時使用,它的任務是幫助 G 學習,使其輸出更接近真實字型風格。

💡 在推論時,我們只需要 G 來「產生」新字,而不需要 D 來判斷真假。


2️⃣ Discriminator 主要用於訓練

在 Glyph GAN 訓練時:

  1. G 會嘗試生成 Zen Maru Gothic 風格的字型。
  2. D 會判斷 G 生成的字是否與真實 Zen Maru Gothic 字型一致。
  3. G 根據 D 的回饋來學習,使生成的字更接近真實 Zen Maru Gothic。

但在推論階段:

  • G 已經訓練完成,並且能夠獨立產生高品質的字型,不再需要 D 來評估真假
  • D 沒有生成能力,它的輸出只是判斷結果 (real/fake),對於推論沒有實際用途。

3️⃣ 省去不必要的計算,提升推論速度

  • 讀取 D 會佔用額外的 GPU 記憶體,但在推論時沒有幫助。
  • 只讀取 G 可以讓推論更輕量、更快速

📌 總結

Generator (G)Discriminator (D)
訓練時產生 Zen Maru Gothic 字判斷真假,幫助 G 學習
推論時直接產生 Zen Maru Gothic 字不需要,因為 G 已經訓練好

所以,在 推論 (infer.py) 階段,只需要讀取 generator.pth,不需要 discriminator.pth!🚀

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *