🔹 nn.L1Loss()
vs nn.BCELoss()
的區別
這兩者是 不同類型的損失函數,主要用於不同的學習目標。
損失函數 | 適用場景 | 計算方式 | 典型輸入 | 典型輸出 |
---|---|---|---|---|
nn.L1Loss() | 迴歸(Regression) | 計算預測值與目標值的 絕對誤差(L1距離) | 任意實數範圍的值 | 任意實數範圍的誤差值 |
nn.BCELoss() | 分類(Classification) | 用於二元分類問題,計算 交叉熵損失(Binary Cross Entropy, BCE) | 範圍為 [0,1] 的機率值 | 範圍為 [0,1] 的損失值 |
🔹 nn.L1Loss()
(L1 絕對誤差損失)
L1Loss()
是 L1 損失函數,也稱為 Mean Absolute Error (MAE)。適用於 回歸問題,用來衡量 預測值與真值之間的絕對差距。
數學公式
L=1N∑i=1N∣yi−y^i∣L = \frac{1}{N} \sum_{i=1}^{N} | y_i – \hat{y}_i |
其中:
- yiy_i 是 真實值
- y^i\hat{y}_i 是 預測值
- NN 是 樣本數
使用場景
- 風格學習(如字型補全)→ 目標是讓生成的字與真實字的像素接近
- 圖像處理(如去噪、自動色彩化)
- 回歸問題(預測數值,如房價、溫度)
範例
import torch
import torch.nn as nn
criterion = nn.L1Loss()
y_pred = torch.tensor([0.2, 0.8, 0.3]) # 模型輸出
y_true = torch.tensor([0.0, 1.0, 0.5]) # 目標值
loss = criterion(y_pred, y_true)
print(loss.item()) # 輸出 L1 Loss 值
🔹 nn.BCELoss()
(二元交叉熵損失)
BCELoss()
是 Binary Cross Entropy(BCE),適用於 二元分類問題,通常用來訓練 判別器(Discriminator),讓其輸出接近 0
(假)或 1
(真)。
數學公式
L=−1N∑i=1N[yilog(y^i)+(1−yi)log(1−y^i)]L = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1 – y_i) \log(1 – \hat{y}_i) \right]
其中:
- yiy_i 是 真實標籤(0 或 1)
- y^i\hat{y}_i 是 模型的輸出(範圍為 [0,1])
- NN 是 樣本數
⚠ 注意:BCE 需要
sigmoid()
處理輸出!
- 如果
discriminator
最後一層是nn.Sigmoid()
,可以直接用BCELoss()
- 如果
discriminator
沒有sigmoid()
,應該改用BCEWithLogitsLoss()
使用場景
- 二元分類問題(判斷某個輸入是否屬於某一類別)
- GAN 判別器(Discriminator)(讓假圖輸出
0
,真圖輸出1
) - 異常檢測(判斷某個輸入是否異常)
範例
criterion = nn.BCELoss()
y_pred = torch.tensor([0.8, 0.2, 0.6]) # 模型輸出(0~1 的機率)
y_true = torch.tensor([1.0, 0.0, 1.0]) # 真實標籤(0 或 1)
loss = criterion(y_pred, y_true)
print(loss.item()) # 輸出 BCE Loss 值
🔹 nn.L1Loss()
vs nn.BCELoss()
在 GAN 應用
在 生成對抗網絡(GAN) 中:
- Generator(生成器)常使用
L1Loss()
或L2Loss()
來讓輸出更貼近真實圖片。 - Discriminator(判別器)通常使用
BCELoss()
來區分 真實圖片(label=1) 和 生成圖片(label=0)。
✅ 總結
損失函數 | 適用於 | 典型輸入 | 典型輸出 | 何時使用? |
---|---|---|---|---|
nn.L1Loss() | 迴歸、影像生成 | 任意數值 | 實數誤差 | 生成器(Generator),讓輸出與目標圖片接近 |
nn.BCELoss() | 分類、GAN 判別器 | [0,1] 機率 | [0,1] 損失 | 判別器(Discriminator),區分真實與生成數據 |
如果你的模型要 學習字型風格(例如 Zen Maru Gothic),L1Loss 更合適,因為它直接度量像素差異。
如果你是在訓練 GAN 的判別器,則應該使用 BCELoss 來判斷真實和生成數據的概率。
👉 總結來說
criterion = nn.L1Loss()
適用於 生成器(比較生成字與真實字的差異)criterion = nn.BCELoss()
適用於 判別器(區分真實與生成的圖片)
🚀 希望這個解釋對你有幫助!如果有更多問題,請隨時問我! 🚀
這個版本的 trainer.py
使用 U-Net 結構 來強化 Generator,讓補全字型時能更好地保持字形細節。
✅ Generator 改為 U-Net 架構 (UnetSkipConnectionBlock)
✅ 支援 batch_size
調整
✅ 解析度 256x256
✅ 加入 SelfAttention
✅ L1Loss() 讓生成結果更接近 Zen Maru Gothic
✅ 每個 epoch 儲存模型,支援繼續訓練
📌 trainer.py
(使用 U-Net Generator)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import os
# 設定裝置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 超參數設定
batch_size = 16 # 可調整
lr = 0.0002
epochs = 100
image_size = 256
# 影像轉換
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=1), # 轉換為單通道灰階
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
# 資料集與 DataLoader
dataset_path = "./datasets/zenmaru"
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# ===========================
# Self Attention
# ===========================
class SelfAttention(nn.Module):
def __init__(self, in_channels):
super(SelfAttention, self).__init__()
self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
self.value = nn.Conv2d(in_channels, in_channels, 1)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
B, C, H, W = x.shape
query = self.query(x).view(B, -1, H * W)
key = self.key(x).view(B, -1, H * W).permute(0, 2, 1)
value = self.value(x).view(B, -1, H * W)
attention = self.softmax(torch.bmm(query, key))
out = torch.bmm(value, attention.permute(0, 2, 1))
out = out.view(B, C, H, W)
return out + x
# ===========================
# UNet Skip Connection Block
# ===========================
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, in_channels, out_channels, submodule=None, use_attention=False):
super(UnetSkipConnectionBlock, self).__init__()
self.submodule = submodule
self.use_attention = use_attention
# Down-sampling
self.down = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2, True)
)
# Self-Attention
if self.use_attention:
self.attn = SelfAttention(out_channels)
# Up-sampling
self.up = nn.Sequential(
nn.ConvTranspose2d(out_channels * 2, in_channels, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(in_channels),
nn.ReLU()
)
def forward(self, x):
down_x = self.down(x)
if self.use_attention:
down_x = self.attn(down_x)
if self.submodule is not None:
down_x = self.submodule(down_x)
up_x = self.up(torch.cat([down_x, x], dim=1))
return up_x
# ===========================
# Generator (U-Net)
# ===========================
class UNetGenerator(nn.Module):
def __init__(self):
super(UNetGenerator, self).__init__()
self.model = UnetSkipConnectionBlock(1, 64,
UnetSkipConnectionBlock(64, 128,
UnetSkipConnectionBlock(128, 256, use_attention=True))) # 在 64x64 解析度加 SelfAttention
def forward(self, x):
return self.model(x)
# ===========================
# Discriminator
# ===========================
class PatchDiscriminator(nn.Module):
def __init__(self):
super(PatchDiscriminator, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(1, 64, 4, 2, 1), nn.LeakyReLU(0.2, True),
nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, True),
SelfAttention(128), # Self-Attention at 64x64
nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, True),
nn.Conv2d(256, 1, 4, 1, 0), nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
# 初始化模型
generator = UNetGenerator().to(device)
discriminator = PatchDiscriminator().to(device)
# 優化器與損失函數
criterion = nn.L1Loss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
# 嘗試載入已有的模型
if os.path.exists("generator.pth"):
generator.load_state_dict(torch.load("generator.pth"))
print("✅ Generator 模型已載入")
if os.path.exists("discriminator.pth"):
discriminator.load_state_dict(torch.load("discriminator.pth"))
print("✅ Discriminator 模型已載入")
# ===========================
# 訓練模型
# ===========================
for epoch in range(epochs):
for i, (noto_sans, _) in enumerate(dataloader):
noto_sans = noto_sans.to(device)
# 訓練 Generator
optimizer_G.zero_grad()
fake_zenmaru = generator(noto_sans)
g_loss = criterion(fake_zenmaru, noto_sans) # 目標是生成接近原始字的風格
g_loss.backward()
optimizer_G.step()
# 訓練 Discriminator
optimizer_D.zero_grad()
real_loss = criterion(discriminator(noto_sans), torch.ones_like(discriminator(noto_sans)))
fake_loss = criterion(discriminator(fake_zenmaru.detach()), torch.zeros_like(discriminator(fake_zenmaru.detach())))
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# 訓練進度輸出
if i % 10 == 0:
print(f"Epoch [{epoch+1}/{epochs}] Step [{i}/{len(dataloader)}] | G Loss: {g_loss.item():.4f} | D Loss: {d_loss.item():.4f}")
# 每個 epoch 存模型
torch.save(generator.state_dict(), "generator.pth")
torch.save(discriminator.state_dict(), "discriminator.pth")
print(f"✅ 模型已儲存: Epoch {epoch+1}")
print("🎉 訓練完成!")
📌 改進點
- Generator 改為 U-Net 架構 (UnetSkipConnectionBlock)
- 批次大小 (
batch_size
) 可調整 - 解析度
256x256
- 加入 SelfAttention
- L1Loss() 讓字型生成結果更接近 Zen Maru Gothic
- 每個 epoch 儲存模型,支援繼續訓練
🚀 這個版本的 trainer.py
已經最佳化!可以訓練 Zen Maru Gothic 缺字補全了! 🚀