Results with BatchNorm2d and the Self-Attention module placed at different layers

Trained for only one epoch, at 512×512 resolution.
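
A quick way to confirm that the generator really maps a 512×512 glyph to a 512×512 output is a throwaway forward pass on a dummy tensor. This is only a smoke test and assumes the model.py listed further down:

import torch
import torch.nn as nn
from model import UNetGenerator

generator = UNetGenerator(norm_layer=nn.BatchNorm2d)
with torch.no_grad():
    out = generator(torch.zeros(1, 1, 512, 512))  # one dummy grayscale glyph
print(out.shape)  # torch.Size([1, 1, 512, 512]); values lie in [-1, 1] because of the final tanh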


Attached is a bare-bones version of the training script:


trainer.py

import glob
import os

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm

from model import PatchDiscriminator, UNetGenerator

# Training hyperparameters
batch_size = 8
lr = 0.0002
epochs = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image transforms
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Custom dataset
class FontDataset(Dataset):
    def __init__(self, dataset_dir, transform=None):
        self.noto_images = sorted(glob.glob(f"{dataset_dir}/notosans_dataset/instance_images/*.png"))
        self.zenmaru_images = sorted(glob.glob(f"{dataset_dir}/zenmaru_dataset/instance_images/*.png"))
        self.transform = transform

    def __len__(self):
        return len(self.noto_images)

    def __getitem__(self, idx):
        noto_img = Image.open(self.noto_images[idx]).convert("L")
        zenmaru_img = Image.open(self.zenmaru_images[idx]).convert("L")

        if self.transform:
            noto_img = self.transform(noto_img)
            zenmaru_img = self.transform(zenmaru_img)

        return noto_img, zenmaru_img

# Build the DataLoader
dataset_dir = "c:/AI/datasets_512"
#dataset_dir = "c:/AI/datasets_256"
dataset = FontDataset(dataset_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the models
generator = UNetGenerator(
    norm_layer=nn.BatchNorm2d
    #norm_layer=nn.InstanceNorm2d
    ).to(device)
discriminator = PatchDiscriminator().to(device)

# Optimizers and loss function
criterion = nn.L1Loss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Resume from existing checkpoints if present
if os.path.exists("generator.pth"):
    generator.load_state_dict(torch.load("generator.pth"))
    print("✅ Generator weights loaded")
if os.path.exists("discriminator.pth"):
    discriminator.load_state_dict(torch.load("discriminator.pth"))
    print("✅ Discriminator weights loaded")

# ===========================
#       Training loop
# ===========================
for epoch in range(epochs):
    for i, (noto_sans, zenmaru) in enumerate(dataloader):
        noto_sans = noto_sans.to(device)
        zenmaru = zenmaru.to(device)

        # Train the generator
        optimizer_G.zero_grad()
        output = generator(noto_sans)
        g_loss = criterion(output, zenmaru)
        g_loss.backward()
        optimizer_G.step()        

        # Train the discriminator
        optimizer_D.zero_grad()
        real_labels = torch.ones_like(discriminator(zenmaru), device=device)
        fake_labels = torch.zeros_like(discriminator(output.detach()), device=device)
        real_loss = criterion(discriminator(zenmaru), real_labels)
        fake_loss = criterion(discriminator(output.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

    # Per-epoch progress output
    print(f"Epoch [{epoch+1}/{epochs}] Step [{i+1}/{len(dataloader)}] | G Loss: {g_loss.item():.4f} | D Loss: {d_loss.item():.4f}")

# Save the models once training finishes
torch.save(generator.state_dict(), "generator.pth")
torch.save(discriminator.state_dict(), "discriminator.pth")
print(f"✅ 模型已儲存: Epoch {epoch+1}")
print("🎉 訓練完成!")

model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import functools

# Self-Attention module
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.shape
        query = self.query(x).view(B, -1, H * W).permute(0, 2, 1)  # [B, H*W, C/8]
        key = self.key(x).view(B, -1, H * W)  # [B, C/8, H*W]
        value = self.value(x).view(B, -1, H * W).permute(0, 2, 1)  # [B, H*W, C]

        attention = self.softmax(torch.bmm(query, key))  # [B, H*W, H*W]
        out = torch.bmm(attention, value)  # [B, H*W, C]
        out = out.permute(0, 2, 1).view(B, C, H, W)  # reshape back to the original dimensions
        return out + x

# UNet skip-connection block
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, use_attention=False, norm_layer=nn.BatchNorm2d, layer=0):
        """Construct a UNet submodule with skip connections.
        Parameters:
            outer_nc (int)       -- the number of filters in the outer conv layer
            inner_nc (int)       -- the number of filters in the inner conv layer
            input_nc (int)       -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            use_attention (bool) -- if True, apply SelfAttention after the down conv
            norm_layer           -- normalization layer
            layer (int)          -- depth index (1 = innermost, 8 = outermost)
        """
        super(UnetSkipConnectionBlock, self).__init__()
        outermost = (layer == 8)
        innermost = (layer == 1)
        self.outermost = outermost
        self.innermost = innermost
        self.layer = layer
        self.use_attention = use_attention

        use_bias = norm_layer == nn.InstanceNorm2d  # use bias only with InstanceNorm2d
        self.norm_layer = norm_layer

        if input_nc is None:
            input_nc = outer_nc

        # Downsampling
        self.down = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.down_norm = norm_layer(inner_nc) if not outermost else nn.Identity()

        # ✅ Optional adaptive pooling to keep feature maps from collapsing to 1x1 (disabled)
        #self.down_pool = nn.AdaptiveAvgPool2d((4, 4)) if not outermost else nn.Identity()

        # Upsampling
        if outermost:
            self.up = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            self.up_norm = nn.Identity()
        elif innermost:
            self.up = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.up_norm = norm_layer(outer_nc)
        else:
            self.up = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.up_norm = norm_layer(outer_nc)

        if use_attention:
            self.attn = SelfAttention(inner_nc)

        self.submodule = submodule

    def forward(self, x):
        down_x = self.down(x)
        #down_x = self.down_pool(down_x)  # ✅ keep at least 4x4 so InstanceNorm does not fail

        if not self.outermost:
            if self.layer==1:
                # skip the norm when InstanceNorm would fail on a (1,1) feature map
                if self.norm_layer == nn.InstanceNorm2d and down_x.shape[2] == 1 and down_x.shape[3] == 1:
                    pass
                else:
                    down_x = self.down_norm(down_x)
                    down_x = F.leaky_relu(down_x, 0.2, inplace=True)
            else:
                down_x = self.down_norm(down_x)
                down_x = F.leaky_relu(down_x, 0.2, inplace=True)

        # Apply self-attention
        if self.use_attention and not self.outermost:
            down_x = self.attn(down_x)

        if self.submodule is not None:
            down_x = self.submodule(down_x)

        up_x = self.up(down_x)
        up_x = self.up_norm(up_x)

        if self.outermost:
            return up_x
        else:
            up_x = F.relu(up_x, inplace=True)

            # Pad up_x to match x in H and W (disabled)
            #diffY = x.size(2) - up_x.size(2)
            #diffX = x.size(3) - up_x.size(3)
            #up_x = F.pad(up_x, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))

            return torch.cat([up_x, x], dim=1)


# UNet generator
class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, use_attention=True, norm_layer=nn.InstanceNorm2d):
        """
        Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
                                image of size 128x128 will become of size 1x1 # at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
            self_attention  -- self attention status
            self_attention_layer -- append to layer
            residual_block  -- residual block status
        """
        super(UNetGenerator, self).__init__()

        # Innermost layer (bottleneck)
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, layer=1)

        # Intermediate layers (ngf*8); SelfAttention is attached at layer 4 here
        for index in range(num_downs - 5):
            loop_use_attention = (index + 2 == 4)
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=unet_block, norm_layer=norm_layer, use_attention=loop_use_attention, layer=index+2)

        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block, norm_layer=norm_layer, layer=5)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block, norm_layer=norm_layer, use_attention=use_attention, layer=6)
        unet_block = UnetSkipConnectionBlock(ngf * 1, ngf * 2, submodule=unet_block, norm_layer=norm_layer, layer=7)

        # Outermost layer
        self.model = UnetSkipConnectionBlock(output_nc, ngf * 1, input_nc=input_nc, submodule=unet_block, norm_layer=norm_layer, layer=8)

    def forward(self, x):
        output = self.model(x)
        output = torch.tanh(output)  # Tanh scales the output to [-1, 1]

        # Force the output to (512, 512) (disabled)
        #output = F.interpolate(output, size=(512, 512), mode='bilinear', align_corners=False)
        return output

# ===========================
#       Discriminator
# ===========================
class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, True),
            #SelfAttention(128),  # Self-Attention at 64x64
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 1, 4, 1, 0), nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)
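
In the code above the placement of SelfAttention is hard-coded (layer 4 inside the loop, plus layer 6 via use_attention), so comparing placements means editing UNetGenerator between runs. Below is a minimal sketch of one way to make the placement a parameter instead; build_unet and its attention_layers argument are hypothetical, not part of model.py:

import torch.nn as nn
from model import UnetSkipConnectionBlock

def build_unet(attention_layers=(4, 6), input_nc=1, output_nc=1, num_downs=8, ngf=64,
               norm_layer=nn.BatchNorm2d):
    # innermost block (layer 1)
    block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer,
                                    use_attention=(1 in attention_layers), layer=1)
    # intermediate ngf*8 blocks (layers 2..4 when num_downs == 8)
    for layer in range(2, num_downs - 3):
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=block, norm_layer=norm_layer,
                                        use_attention=(layer in attention_layers), layer=layer)
    # layers 5..7 step the channel count back down
    block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=block, norm_layer=norm_layer,
                                    use_attention=(5 in attention_layers), layer=5)
    block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=block, norm_layer=norm_layer,
                                    use_attention=(6 in attention_layers), layer=6)
    block = UnetSkipConnectionBlock(ngf * 1, ngf * 2, submodule=block, norm_layer=norm_layer,
                                    use_attention=(7 in attention_layers), layer=7)
    # outermost block (layer 8); UNetGenerator.forward would still apply tanh on top of this stack
    return UnetSkipConnectionBlock(output_nc, ngf * 1, input_nc=input_nc, submodule=block,
                                   norm_layer=norm_layer, layer=8)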

infer.py

import torch
import torch.nn as nn
import os
import cv2
import numpy as np
from torchvision import transforms
from model import UNetGenerator
from PIL import Image

# Select the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained generator
generator = UNetGenerator(
    norm_layer=nn.BatchNorm2d
    ).to(device)
generator.load_state_dict(torch.load("generator.pth", map_location=device), strict=False)
generator.eval()  # switch to inference mode

# Image preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),          # convert to a tensor
    transforms.Normalize((0.5,), (0.5,))  # normalize to [-1, 1]
])

# Load a Noto Sans input glyph image
def load_image(image_path):
    image = Image.open(image_path).convert("L")  # grayscale
    image = transform(image).unsqueeze(0)  # add a batch dimension
    return image.to(device)

# Inference function
def infer(image_path, output_path):
    input_tensor = load_image(image_path)

    with torch.no_grad():
        output_tensor = generator(input_tensor)  # generate the Zen Maru Gothic glyph

    # Convert back to a PIL Image
    output_image = output_tensor.squeeze().cpu().numpy()  # drop the batch and channel dimensions
    output_image = (output_image * 127.5 + 127.5).astype(np.uint8)  # de-normalize to [0, 255]
    output_image = Image.fromarray(output_image, mode="L")  # to a PIL Image

    output_image.save(output_path)
    print(f"✅ 推論完成,結果已儲存至: {output_path}")

# Test inference
if __name__ == "__main__":
    #test_image = "test_noto_sans_256.png"  # Noto Sans missing-glyph test image
    test_image = "test_noto_sans_512.png"  # Noto Sans missing-glyph test image
    #output_image = "output_zenmaru_256.png"  # generated Zen Maru Gothic glyph
    output_image = "output_zenmaru_512.png"  # generated Zen Maru Gothic glyph
    infer(test_image, output_image)
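
To fill in a whole folder of missing glyphs instead of a single test image, the same infer() can simply be looped over a directory; the two folder paths below are placeholders, not paths used by this project:

import glob
import os

input_dir = "c:/AI/missing_noto"        # placeholder: folder of Noto Sans glyph PNGs
output_dir = "c:/AI/generated_zenmaru"  # placeholder: folder for the generated glyphs
os.makedirs(output_dir, exist_ok=True)

for path in sorted(glob.glob(f"{input_dir}/*.png")):
    infer(path, os.path.join(output_dir, os.path.basename(path)))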

With only 4 GB of GPU RAM and batch size = 16, training works with self-attention at layer 4, but going one level further down, at layer 3, memory blows up; switching to batch size = 8 solves it. Currently, with batch size = 8, self-attention at layers 3 and 5, and 512×512 resolution, training uses 3.8 GB of GPU RAM.
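
Most of that usage comes from the attention map itself: torch.bmm(query, key) in SelfAttention materializes a [B, H*W, H*W] tensor, so its size grows with the fourth power of the feature-map side. A rough float32 estimate for that single tensor (activations, gradients, and the rest of the network are ignored):

def attention_map_bytes(batch, h, w, dtype_bytes=4):
    """Size of the [B, H*W, H*W] matrix built in SelfAttention.forward (float32 by default)."""
    n = h * w
    return batch * n * n * dtype_bytes

# A 32x32 feature map at batch size 8 needs about 32 MiB for the attention map alone,
# while a 64x64 map at the same batch size already needs about 512 MiB.
print(attention_map_bytes(8, 32, 32) / 2**20, "MiB")
print(attention_map_bytes(8, 64, 64) / 2**20, "MiB")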
