
Trained for just one epoch at 512×512 resolution.
Here is the bare-bones training script:
trainer.py
import glob
import os

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from model import PatchDiscriminator, UNetGenerator

# Training hyperparameters
batch_size = 8
lr = 0.0002
epochs = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image transforms
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Custom dataset of paired glyph images: Noto Sans (input) and Zen Maru Gothic (target).
# Both file lists are sorted so that index i in one matches index i in the other.
class FontDataset(Dataset):
    def __init__(self, dataset_dir, transform=None):
        self.noto_images = sorted(glob.glob(f"{dataset_dir}/notosans_dataset/instance_images/*.png"))
        self.zenmaru_images = sorted(glob.glob(f"{dataset_dir}/zenmaru_dataset/instance_images/*.png"))
        self.transform = transform

    def __len__(self):
        return len(self.noto_images)

    def __getitem__(self, idx):
        noto_img = Image.open(self.noto_images[idx]).convert("L")
        zenmaru_img = Image.open(self.zenmaru_images[idx]).convert("L")
        if self.transform:
            noto_img = self.transform(noto_img)
            zenmaru_img = self.transform(zenmaru_img)
        return noto_img, zenmaru_img

# Build the DataLoader
dataset_dir = "c:/AI/datasets_512"
#dataset_dir = "c:/AI/datasets_256"
dataset = FontDataset(dataset_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the models
generator = UNetGenerator(
    norm_layer=nn.BatchNorm2d
    #norm_layer=nn.InstanceNorm2d
).to(device)
discriminator = PatchDiscriminator().to(device)

# Optimizers and loss function (L1 is used both for reconstruction and for the
# discriminator's real/fake targets in this bare-bones version)
criterion = nn.L1Loss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Resume from existing checkpoints if present
if os.path.exists("generator.pth"):
    generator.load_state_dict(torch.load("generator.pth"))
    print("✅ Generator model loaded")
if os.path.exists("discriminator.pth"):
    discriminator.load_state_dict(torch.load("discriminator.pth"))
    print("✅ Discriminator model loaded")

# ===========================
# Training loop
# ===========================
for epoch in range(epochs):
    for i, (noto_sans, zenmaru) in enumerate(dataloader):
        noto_sans = noto_sans.to(device)
        zenmaru = zenmaru.to(device)

        # Train the generator (pure L1 reconstruction loss; the discriminator's
        # signal does not reach the generator in this version)
        optimizer_G.zero_grad()
        output = generator(noto_sans)
        g_loss = criterion(output, zenmaru)
        g_loss.backward()
        optimizer_G.step()

        # Train the discriminator (detach() keeps the generator out of this update)
        optimizer_D.zero_grad()
        pred_real = discriminator(zenmaru)
        pred_fake = discriminator(output.detach())
        real_loss = criterion(pred_real, torch.ones_like(pred_real))
        fake_loss = criterion(pred_fake, torch.zeros_like(pred_fake))
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Progress output
        print(f"Epoch [{epoch+1}/{epochs}] Step [{i}/{len(dataloader)}] | G Loss: {g_loss.item():.4f} | D Loss: {d_loss.item():.4f}")

    # Save the models every epoch
    torch.save(generator.state_dict(), "generator.pth")
    torch.save(discriminator.state_dict(), "discriminator.pth")
    print(f"✅ Models saved: Epoch {epoch+1}")

print("🎉 Training complete!")
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

# Self-Attention module (SAGAN-style, but without the learnable gamma scale;
# the attention output is added directly to the input as a residual)
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.shape
        query = self.query(x).view(B, -1, H * W).permute(0, 2, 1)  # [B, H*W, C/8]
        key = self.key(x).view(B, -1, H * W)                       # [B, C/8, H*W]
        value = self.value(x).view(B, -1, H * W).permute(0, 2, 1)  # [B, H*W, C]
        attention = self.softmax(torch.bmm(query, key))            # [B, H*W, H*W]
        out = torch.bmm(attention, value)                          # [B, H*W, C]
        out = out.permute(0, 2, 1).view(B, C, H, W)                # reshape back to the original layout
        return out + x

# UNet skip-connection block
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, use_attention=False, norm_layer=nn.BatchNorm2d, layer=0):
        """Construct a UNet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            use_attention (bool) -- apply self-attention after downsampling
            norm_layer -- normalization layer
            layer (int) -- depth index: 1 is the innermost block, 8 the outermost
        """
        super(UnetSkipConnectionBlock, self).__init__()
        outermost = (layer == 8)
        innermost = (layer == 1)
        self.outermost = outermost
        self.innermost = innermost
        self.layer = layer
        self.use_attention = use_attention
        use_bias = norm_layer == nn.InstanceNorm2d  # use bias only with InstanceNorm2d
        self.norm_layer = norm_layer
        if input_nc is None:
            input_nc = outer_nc

        # Downsampling
        self.down = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.down_norm = norm_layer(inner_nc) if not outermost else nn.Identity()
        # Optional adaptive pooling to keep feature maps from shrinking to 1x1
        #self.down_pool = nn.AdaptiveAvgPool2d((4, 4)) if not outermost else nn.Identity()

        # Upsampling
        if outermost:
            self.up = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            self.up_norm = nn.Identity()
        elif innermost:
            self.up = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.up_norm = norm_layer(outer_nc)
        else:
            self.up = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.up_norm = norm_layer(outer_nc)
        if use_attention:
            self.attn = SelfAttention(inner_nc)
        self.submodule = submodule

    def forward(self, x):
        down_x = self.down(x)
        #down_x = self.down_pool(down_x)  # keep a minimum size of 4x4 so InstanceNorm cannot fail
        if not self.outermost:
            if self.layer == 1:
                # Skip InstanceNorm on 1x1 feature maps, where it would raise an error
                if self.norm_layer == nn.InstanceNorm2d and down_x.shape[2] == 1 and down_x.shape[3] == 1:
                    pass
                else:
                    down_x = self.down_norm(down_x)
                down_x = F.leaky_relu(down_x, 0.2, inplace=True)
            else:
                down_x = self.down_norm(down_x)
                down_x = F.leaky_relu(down_x, 0.2, inplace=True)

        # Apply the attention mechanism
        if self.use_attention and not self.outermost:
            down_x = self.attn(down_x)

        if self.submodule is not None:
            down_x = self.submodule(down_x)
        up_x = self.up(down_x)
        up_x = self.up_norm(up_x)
        if self.outermost:
            return up_x
        else:
            up_x = F.relu(up_x, inplace=True)
            # Pad up_x to match x in H and W if needed
            #diffY = x.size(2) - up_x.size(2)
            #diffX = x.size(3) - up_x.size(3)
            #up_x = F.pad(up_x, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
            return torch.cat([up_x, x], dim=1)

# UNet generator
class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, use_attention=True, norm_layer=nn.InstanceNorm2d):
        """Construct a UNet generator.

        Parameters:
            input_nc (int) -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in the UNet; for example,
                               if num_downs == 7, a 128x128 image becomes 1x1 at the bottleneck
            ngf (int) -- the number of filters in the last conv layer
            use_attention (bool) -- enable self-attention in the layer-6 block
            norm_layer -- normalization layer
        """
        super(UNetGenerator, self).__init__()
        # Innermost block (bottleneck)
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, layer=1)
        # Intermediate blocks (layers 2..4 when num_downs == 8); self-attention is switched on at layer 4
        for index in range(num_downs - 5):
            loop_use_attention = (index + 2 == 4)
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=unet_block, norm_layer=norm_layer, use_attention=loop_use_attention, layer=index + 2)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block, norm_layer=norm_layer, layer=5)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block, norm_layer=norm_layer, use_attention=use_attention, layer=6)
        unet_block = UnetSkipConnectionBlock(ngf * 1, ngf * 2, submodule=unet_block, norm_layer=norm_layer, layer=7)
        # Outermost block
        self.model = UnetSkipConnectionBlock(output_nc, ngf * 1, input_nc=input_nc, submodule=unet_block, norm_layer=norm_layer, layer=8)

    def forward(self, x):
        output = self.model(x)
        output = torch.tanh(output)  # Tanh scales the output to [-1, 1]
        # Force the output to (512, 512)
        #output = F.interpolate(output, size=(512, 512), mode='bilinear', align_corners=False)
        return output

# ===========================
# Discriminator
# ===========================
class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, True),
            #SelfAttention(128),  # self-attention at 64x64
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 1, 4, 1, 0), nn.Sigmoid()  # per-patch real/fake probability
        )

    def forward(self, x):
        return self.model(x)
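A quick sanity check for the architecture is to push a dummy 512×512 batch through both networks and confirm the shapes; this is a standalone snippet, not part of model.py:

# Shape check: the generator should return a 512x512 output for a 512x512 input,
# and the PatchDiscriminator reduces an image to a grid of patch scores.
import torch
from model import PatchDiscriminator, UNetGenerator

g = UNetGenerator(norm_layer=torch.nn.BatchNorm2d).eval()
d = PatchDiscriminator().eval()
x = torch.randn(1, 1, 512, 512)
with torch.no_grad():
    y = g(x)
    p = d(y)
print(y.shape)  # torch.Size([1, 1, 512, 512])
print(p.shape)  # torch.Size([1, 1, 61, 61]) -- one score per receptive-field patch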
infer.py
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
from PIL import Image

from model import UNetGenerator

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained generator (strict=False tolerates minor state-dict mismatches)
generator = UNetGenerator(
    norm_layer=nn.BatchNorm2d
).to(device)
generator.load_state_dict(torch.load("generator.pth", map_location=device), strict=False)
generator.eval()  # inference mode

# Image preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),                # convert to a tensor
    transforms.Normalize((0.5,), (0.5,))  # normalize to [-1, 1]
])

# Load a Noto Sans input glyph image
def load_image(image_path):
    image = Image.open(image_path).convert("L")  # grayscale mode
    image = transform(image).unsqueeze(0)        # add a batch dimension
    return image.to(device)

# Inference function
def infer(image_path, output_path):
    input_tensor = load_image(image_path)
    with torch.no_grad():
        output_tensor = generator(input_tensor)  # generate the Zen Maru Gothic glyph
    # Convert back to a PIL image
    output_image = output_tensor.squeeze().cpu().numpy()            # drop the extra dimensions
    output_image = (output_image * 127.5 + 127.5).astype(np.uint8)  # denormalize from [-1, 1] to [0, 255]
    output_image = Image.fromarray(output_image, mode="L")
    output_image.save(output_path)
    print(f"✅ Inference done; result saved to: {output_path}")

# Test inference
if __name__ == "__main__":
    #test_image = "test_noto_sans_256.png"  # test image of a missing Noto Sans glyph
    test_image = "test_noto_sans_512.png"  # test image of a missing Noto Sans glyph
    #output_image = "output_zenmaru_256.png"  # generated Zen Maru Gothic glyph
    output_image = "output_zenmaru_512.png"  # generated Zen Maru Gothic glyph
    infer(test_image, output_image)
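To run the same script over a whole folder of missing-glyph images, a minimal loop works; the folder names here are hypothetical:

# Batch inference over a folder of glyph images (paths are placeholders).
import glob
import os

os.makedirs("out_zenmaru", exist_ok=True)
for path in sorted(glob.glob("missing_glyphs/*.png")):
    infer(path, os.path.join("out_zenmaru", os.path.basename(path)))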
With only 4 GB of GPU memory and batch size 16, training still fits with self-attention at layer 4, but moving it down to layer 3 blows the memory budget; switching to batch size 8 solves it. Currently, with batch size 8, self-attention at layers 3 and 5, and 512×512 resolution, training uses 3.8 GB of GPU memory.
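To check where the memory goes on your own card, here is a small measurement sketch (the numbers will vary with GPU, driver, and PyTorch version):

# Measure peak GPU memory for one training-shaped forward/backward pass.
import torch
from model import UNetGenerator

device = torch.device("cuda")
g = UNetGenerator(norm_layer=torch.nn.BatchNorm2d).to(device)
x = torch.randn(8, 1, 512, 512, device=device)  # batch size 8 at 512x512

torch.cuda.reset_peak_memory_stats(device)
g(x).mean().backward()
peak_gb = torch.cuda.max_memory_allocated(device) / 1024**3
print(f"Peak allocated: {peak_gb:.2f} GB")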