criterion = nn.L1Loss() 與 criterion = nn.BCELoss() 區別?

🔹 nn.L1Loss() vs nn.BCELoss() 的區別

這兩者是 不同類型的損失函數,主要用於不同的學習目標。

損失函數適用場景計算方式典型輸入典型輸出
nn.L1Loss()迴歸(Regression)計算預測值與目標值的 絕對誤差(L1距離)任意實數範圍的值任意實數範圍的誤差值
nn.BCELoss()分類(Classification)用於二元分類問題,計算 交叉熵損失(Binary Cross Entropy, BCE)範圍為 [0,1] 的機率值範圍為 [0,1] 的損失值

🔹 nn.L1Loss()(L1 絕對誤差損失)

L1Loss() 是 L1 損失函數,也稱為 Mean Absolute Error (MAE)。適用於 回歸問題,用來衡量 預測值與真值之間的絕對差距

數學公式

L=1N∑i=1N∣yi−y^i∣L = \frac{1}{N} \sum_{i=1}^{N} | y_i – \hat{y}_i |

其中:

  • yiy_i 是 真實值
  • y^i\hat{y}_i 是 預測值
  • NN 是 樣本數

使用場景

  • 風格學習(如字型補全)→ 目標是讓生成的字與真實字的像素接近
  • 圖像處理(如去噪、自動色彩化)
  • 回歸問題(預測數值,如房價、溫度)

範例

import torch
import torch.nn as nn

criterion = nn.L1Loss()

y_pred = torch.tensor([0.2, 0.8, 0.3])  # 模型輸出
y_true = torch.tensor([0.0, 1.0, 0.5])  # 目標值

loss = criterion(y_pred, y_true)
print(loss.item())  # 輸出 L1 Loss 值

🔹 nn.BCELoss()(二元交叉熵損失)

BCELoss()Binary Cross Entropy(BCE),適用於 二元分類問題,通常用來訓練 判別器(Discriminator),讓其輸出接近 0(假)或 1(真)。

數學公式

L=−1N∑i=1N[yilog⁡(y^i)+(1−yi)log⁡(1−y^i)]L = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1 – y_i) \log(1 – \hat{y}_i) \right]

其中:

  • yiy_i 是 真實標籤(0 或 1)
  • y^i\hat{y}_i 是 模型的輸出(範圍為 [0,1])
  • NN 是 樣本數

⚠ 注意:BCE 需要 sigmoid() 處理輸出!

  • 如果 discriminator 最後一層是 nn.Sigmoid(),可以直接用 BCELoss()
  • 如果 discriminator 沒有 sigmoid(),應該改用 BCEWithLogitsLoss()

使用場景

  • 二元分類問題(判斷某個輸入是否屬於某一類別)
  • GAN 判別器(Discriminator)(讓假圖輸出 0,真圖輸出 1
  • 異常檢測(判斷某個輸入是否異常)

範例

criterion = nn.BCELoss()

y_pred = torch.tensor([0.8, 0.2, 0.6])  # 模型輸出(0~1 的機率)
y_true = torch.tensor([1.0, 0.0, 1.0])  # 真實標籤(0 或 1)

loss = criterion(y_pred, y_true)
print(loss.item())  # 輸出 BCE Loss 值

🔹 nn.L1Loss() vs nn.BCELoss() 在 GAN 應用

生成對抗網絡(GAN) 中:

  • Generator(生成器)常使用 L1Loss()L2Loss() 來讓輸出更貼近真實圖片。
  • Discriminator(判別器)通常使用 BCELoss() 來區分 真實圖片(label=1)生成圖片(label=0)

✅ 總結

損失函數適用於典型輸入典型輸出何時使用?
nn.L1Loss()迴歸、影像生成任意數值實數誤差生成器(Generator),讓輸出與目標圖片接近
nn.BCELoss()分類、GAN 判別器[0,1] 機率[0,1] 損失判別器(Discriminator),區分真實與生成數據

如果你的模型要 學習字型風格(例如 Zen Maru Gothic),L1Loss 更合適,因為它直接度量像素差異。

如果你是在訓練 GAN 的判別器,則應該使用 BCELoss 來判斷真實和生成數據的概率。

👉 總結來說

  • criterion = nn.L1Loss() 適用於 生成器(比較生成字與真實字的差異)
  • criterion = nn.BCELoss() 適用於 判別器(區分真實與生成的圖片)

🚀 希望這個解釋對你有幫助!如果有更多問題,請隨時問我! 🚀


這個版本的 trainer.py 使用 U-Net 結構 來強化 Generator,讓補全字型時能更好地保持字形細節。
Generator 改為 U-Net 架構 (UnetSkipConnectionBlock)
支援 batch_size 調整
解析度 256x256
加入 SelfAttention
L1Loss() 讓生成結果更接近 Zen Maru Gothic
每個 epoch 儲存模型,支援繼續訓練


📌 trainer.py (使用 U-Net Generator)

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import os

# 設定裝置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 超參數設定
batch_size = 16  # 可調整
lr = 0.0002
epochs = 100
image_size = 256

# 影像轉換
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # 轉換為單通道灰階
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# 資料集與 DataLoader
dataset_path = "./datasets/zenmaru"
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ===========================
#       Self Attention
# ===========================
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.shape
        query = self.query(x).view(B, -1, H * W)
        key = self.key(x).view(B, -1, H * W).permute(0, 2, 1)
        value = self.value(x).view(B, -1, H * W)

        attention = self.softmax(torch.bmm(query, key))
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)
        return out + x

# ===========================
#    UNet Skip Connection Block
# ===========================
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, submodule=None, use_attention=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.submodule = submodule
        self.use_attention = use_attention

        # Down-sampling
        self.down = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True)
        )

        # Self-Attention
        if self.use_attention:
            self.attn = SelfAttention(out_channels)

        # Up-sampling
        self.up = nn.Sequential(
            nn.ConvTranspose2d(out_channels * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU()
        )

    def forward(self, x):
        down_x = self.down(x)
        if self.use_attention:
            down_x = self.attn(down_x)

        if self.submodule is not None:
            down_x = self.submodule(down_x)

        up_x = self.up(torch.cat([down_x, x], dim=1))
        return up_x

# ===========================
#       Generator (U-Net)
# ===========================
class UNetGenerator(nn.Module):
    def __init__(self):
        super(UNetGenerator, self).__init__()
        self.model = UnetSkipConnectionBlock(1, 64,
                      UnetSkipConnectionBlock(64, 128, 
                      UnetSkipConnectionBlock(128, 256, use_attention=True)))  # 在 64x64 解析度加 SelfAttention

    def forward(self, x):
        return self.model(x)

# ===========================
#       Discriminator
# ===========================
class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, True),
            SelfAttention(128),  # Self-Attention at 64x64
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 1, 4, 1, 0), nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# 初始化模型
generator = UNetGenerator().to(device)
discriminator = PatchDiscriminator().to(device)

# 優化器與損失函數
criterion = nn.L1Loss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# 嘗試載入已有的模型
if os.path.exists("generator.pth"):
    generator.load_state_dict(torch.load("generator.pth"))
    print("✅ Generator 模型已載入")
if os.path.exists("discriminator.pth"):
    discriminator.load_state_dict(torch.load("discriminator.pth"))
    print("✅ Discriminator 模型已載入")

# ===========================
#       訓練模型
# ===========================
for epoch in range(epochs):
    for i, (noto_sans, _) in enumerate(dataloader):
        noto_sans = noto_sans.to(device)

        # 訓練 Generator
        optimizer_G.zero_grad()
        fake_zenmaru = generator(noto_sans)
        g_loss = criterion(fake_zenmaru, noto_sans)  # 目標是生成接近原始字的風格
        g_loss.backward()
        optimizer_G.step()

        # 訓練 Discriminator
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(noto_sans), torch.ones_like(discriminator(noto_sans)))
        fake_loss = criterion(discriminator(fake_zenmaru.detach()), torch.zeros_like(discriminator(fake_zenmaru.detach())))
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # 訓練進度輸出
        if i % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] Step [{i}/{len(dataloader)}] | G Loss: {g_loss.item():.4f} | D Loss: {d_loss.item():.4f}")

    # 每個 epoch 存模型
    torch.save(generator.state_dict(), "generator.pth")
    torch.save(discriminator.state_dict(), "discriminator.pth")
    print(f"✅ 模型已儲存: Epoch {epoch+1}")

print("🎉 訓練完成!")

📌 改進點

  1. Generator 改為 U-Net 架構 (UnetSkipConnectionBlock)
  2. 批次大小 (batch_size) 可調整
  3. 解析度 256x256
  4. 加入 SelfAttention
  5. L1Loss() 讓字型生成結果更接近 Zen Maru Gothic
  6. 每個 epoch 儲存模型,支援繼續訓練

🚀 這個版本的 trainer.py 已經最佳化!可以訓練 Zen Maru Gothic 缺字補全了! 🚀

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *