Discriminator 將 CategoryLoss 和 BinaryLoss 合併為一個損失函數,簡化程式碼並提高效率, 是否有助推論品質?

合併 CategoryLossBinaryLoss 成一個損失函數可以簡化程式碼並提高計算效率,特別是在 GPU 訓練時能減少額外的 kernel 啟動開銷。然而,是否能提升推論品質(字型風格轉換效果)則需要考慮以下幾點:

優勢

  1. 更好的梯度傳播:合併兩個損失可以確保梯度計算的穩定性,減少可能的 loss scaling 問題。
  2. 減少計算開銷:兩個 BCEWithLogitsLoss 可以合併計算,減少 batch-wise 計算時間。
  3. 統一權重調整:可以讓 Lcategory_penaltyBinaryLoss 權重更直觀地調整,減少超參數調整的複雜度。

可能的影響

  1. 損失值分佈變化:合併後,可能會導致某個 loss(例如 CategoryLoss)影響較大,使 BinaryLoss 變得較不重要,這可能會影響訓練的穩定性。
  2. 需要調整 Lcategory_penalty:原先的 category_lossBCEWithLogitsLoss,但合併後可能需要重新調整 Lcategory_penalty 的權重,確保分類準確度不下降。
  3. Discriminator 的影響:如果 Discriminator 同時學習兩種 loss,可能會造成競爭影響,需要確認合併後的 Discriminator 訓練仍然穩定。

如何合併?

可以定義一個新的 CombinedLoss

class CombinedLoss(nn.Module):
    def __init__(self, category_num, real=True, lambda_category=1.0, lambda_binary=1.0):
        super(CombinedLoss, self).__init__()
        self.category_num = category_num
        self.lambda_category = lambda_category
        self.lambda_binary = lambda_binary
        self.bce = nn.BCEWithLogitsLoss()
        self.emb = nn.Embedding(category_num, category_num)
        self.emb.weight.data = torch.eye(category_num)
        self.real = real

    def forward(self, logits, category_logits, labels):
        # 計算二分類 loss
        binary_labels = torch.ones(logits.shape[0], 1) if self.real else torch.zeros(logits.shape[0], 1)
        if logits.is_cuda:
            binary_labels = binary_labels.cuda()
        binary_loss = self.bce(logits, binary_labels)

        # 計算分類 loss
        target = self.emb(labels)
        category_loss = self.bce(category_logits, target)

        return self.lambda_binary * binary_loss + self.lambda_category * category_loss

這樣 Discriminator 只需要計算一次 CombinedLoss,避免額外的計算。

結論

這樣的改動應該可以提高計算效率,對訓練收斂速度也可能有幫助。不過,為了確保推論品質不下降,建議在合併後:

  1. 監控 Discriminator loss 的變化,確保 BinaryLossCategoryLoss 不會互相干擾。
  2. 調整 Lcategory_penalty 權重,確保分類準確度不下降。
  3. 觀察 G 的訓練穩定性,確保梯度更新仍然合理。

如果推論品質下降,可能需要重新拆分 loss,或調整權重參數來平衡分類與真假辨識的影響。


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# 假設已有 Generator 和 Discriminator 模型
class CombinedLoss(nn.Module):
    def __init__(self, category_num, lambda_category=1.0, lambda_binary=1.0):
        super(CombinedLoss, self).__init__()
        self.category_num = category_num
        self.lambda_category = lambda_category
        self.lambda_binary = lambda_binary
        self.bce = nn.BCEWithLogitsLoss()
        self.emb = nn.Embedding(category_num, category_num)
        self.emb.weight.data = torch.eye(category_num)

    def forward(self, logits, category_logits, labels, real=True):
        binary_labels = torch.ones_like(logits) if real else torch.zeros_like(logits)
        binary_loss = self.bce(logits, binary_labels)

        target = self.emb(labels)
        category_loss = self.bce(category_logits, target)

        return self.lambda_binary * binary_loss + self.lambda_category * category_loss


def train(generator, discriminator, dataloader, epochs=10, lr=0.0002, category_num=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)
    
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    loss_fn = CombinedLoss(category_num).to(device)
    
    for epoch in range(epochs):
        for batch in dataloader:
            real_images, labels = batch
            real_images, labels = real_images.to(device), labels.to(device)
            
            batch_size = real_images.size(0)
            noise = torch.randn(batch_size, 100, device=device)  # 100 為 latent_dim
            fake_images = generator(noise, labels)
            
            # 訓練 D
            optimizer_D.zero_grad()
            real_logits, real_cat_logits = discriminator(real_images)
            fake_logits, fake_cat_logits = discriminator(fake_images.detach())
            
            loss_D_real = loss_fn(real_logits, real_cat_logits, labels, real=True)
            loss_D_fake = loss_fn(fake_logits, fake_cat_logits, labels, real=False)
            loss_D = (loss_D_real + loss_D_fake) / 2
            
            loss_D.backward()
            optimizer_D.step()
            
            # 訓練 G
            optimizer_G.zero_grad()
            fake_logits, fake_cat_logits = discriminator(fake_images)
            loss_G = loss_fn(fake_logits, fake_cat_logits, labels, real=True)
            
            loss_G.backward()
            optimizer_G.step()
            
        print(f"Epoch [{epoch+1}/{epochs}] | Loss D: {loss_D.item():.4f} | Loss G: {loss_G.item():.4f}")
    
    print("Training Finished!")

# 假設你有 dataset
dataloader = DataLoader(...)  # 你的 DataLoader
# 假設你有 generator 和 discriminator 模型
generator = ...
discriminator = ...

train(generator, discriminator, dataloader)

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *