Merging CategoryLoss and BinaryLoss into a single loss function, and replacing BCEWithLogitsLoss with Hinge Loss for the binary term

Introduction to Hinge Loss:

Hinge loss is commonly used in GANs as a replacement for BCE loss. It is defined as:

  • For real samples: max(0, 1 - D(real))
  • For fake samples: max(0, 1 + D(fake))
  • Minimizing it drives D(real) >= 1 and D(fake) <= -1

This tends to be more stable than BCE loss, which is prone to vanishing gradients.
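
To make the gradient behavior concrete, here is a minimal sketch (logit values chosen purely for illustration) comparing the gradients the two losses produce on a few "real"-sample discriminator logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([-5.0, -1.0, 0.0, 1.0, 5.0], requires_grad=True)

# BCE-with-logits for target 1: the per-element gradient (up to the 1/N
# mean factor) is sigmoid(logit) - 1, which decays smoothly toward zero
# as the logit grows but never reaches it exactly.
F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits)).backward()
print("BCE grads:  ", logits.grad)

logits.grad = None
# Hinge loss for real samples: the per-element gradient is -1 inside the
# margin (logit < 1) and exactly 0 once the margin is satisfied, so D
# stops pushing samples it already classifies with enough margin.
torch.relu(1.0 - logits).mean().backward()
print("hinge grads:", logits.grad)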

The following suggestions have been integrated into this trainer.py:

  1. Replace BCEWithLogitsLoss with Hinge Loss.
  2. Merge BinaryLoss and CategoryLoss into a single DiscriminatorLoss.
  3. Fix the cheat_loss in G's loss computation so that G is trained with the hinge formulation.
import torch
import torch.nn as nn
import torch.optim as optim

class DiscriminatorLoss(nn.Module):
    """Merged discriminator loss: hinge loss for real/fake plus category cross-entropy."""
    def __init__(self, category_num):
        super(DiscriminatorLoss, self).__init__()
        self.category_num = category_num
        self.category_loss = nn.CrossEntropyLoss()

    def hinge_loss(self, logits, real):
        # max(0, 1 - D(real)) for real samples, max(0, 1 + D(fake)) for fake samples
        if real:
            return torch.mean(torch.relu(1.0 - logits))
        else:
            return torch.mean(torch.relu(1.0 + logits))

    def forward(self, binary_logits, category_logits, labels, real=True):
        binary_loss = self.hinge_loss(binary_logits, real)
        category_loss = self.category_loss(category_logits, labels)
        return binary_loss, category_loss

class Zi2ZiModel(nn.Module):
    def __init__(self, netG, netD, embedding_num, Lcategory_penalty=1.0, Lconst_penalty=15.0, L1_penalty=100.0):
        super(Zi2ZiModel, self).__init__()
        self.netG = netG
        self.netD = netD
        self.embedding_num = embedding_num
        self.Lcategory_penalty = Lcategory_penalty
        self.Lconst_penalty = Lconst_penalty
        self.L1_penalty = L1_penalty
        self.mse = nn.MSELoss()
        self.l1_loss = nn.L1Loss()
        self.discriminator_loss = DiscriminatorLoss(self.embedding_num)
    
    def set_input(self, real_A, real_B, labels):
        self.real_A = real_A
        self.real_B = real_B
        self.labels = labels

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def backward_D(self):
        real_AB = torch.cat([self.real_A, self.real_B], 1)
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)

        real_D_logits, real_category_logits = self.netD(real_AB)
        fake_D_logits, fake_category_logits = self.netD(fake_AB.detach())

        d_loss_real, real_category_loss = self.discriminator_loss(real_D_logits, real_category_logits, self.labels, real=True)
        d_loss_fake, fake_category_loss = self.discriminator_loss(fake_D_logits, fake_category_logits, self.labels, real=False)

        self.d_loss = d_loss_real + d_loss_fake + (real_category_loss + fake_category_loss) * self.Lcategory_penalty / 2.0
        self.d_loss.backward()
    
    def backward_G(self):
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        fake_D_logits, fake_category_logits = self.netD(fake_AB)

        # NOTE: the original zi2zi constancy loss compares encoder outputs (see the
        # second version below); since this netG returns only the image, it is
        # approximated here in pixel space.
        const_loss = self.Lconst_penalty * self.mse(self.real_A, self.fake_B)
        l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
        # Only the category part is used here; the hinge part is replaced by cheat_loss below.
        _, fake_category_loss = self.discriminator_loss(fake_D_logits, fake_category_logits, self.labels, real=True)

        # Hinge-based generator (cheat) loss: maximize D(fake), i.e. minimize -E[D(fake)].
        cheat_loss = -torch.mean(fake_D_logits)
        
        self.g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss
        self.g_loss.backward()
    
    def optimize_parameters(self, optimizer_G, optimizer_D):
        self.forward()
        
        optimizer_D.zero_grad()
        self.backward_D()
        optimizer_D.step()
        
        optimizer_G.zero_grad()
        self.backward_G()
        optimizer_G.step()
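
To see how this merged trainer is wired up end to end, here is a minimal smoke test. DummyG, DummyD, and every shape below are made-up stand-ins for illustration; the only contract assumed is that netD returns (binary_logits, category_logits) and netG maps real_A to fake_B:

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in networks, just to exercise the trainer above.
class DummyG(nn.Module):
    def __init__(self):
        super(DummyG, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, x):
        return torch.tanh(self.conv(x))

class DummyD(nn.Module):
    def __init__(self, category_num=10):
        super(DummyD, self).__init__()
        # 2 input channels: the discriminator sees real_A concatenated with real_B or fake_B.
        self.features = nn.Sequential(nn.Conv2d(2, 8, 4, stride=4), nn.Flatten())
        self.binary_head = nn.Linear(8 * 16 * 16, 1)
        self.category_head = nn.Linear(8 * 16 * 16, category_num)

    def forward(self, x):
        h = self.features(x)
        return self.binary_head(h), self.category_head(h)

netG, netD = DummyG(), DummyD()
model = Zi2ZiModel(netG, netD, embedding_num=10)
optimizer_G = optim.Adam(netG.parameters(), lr=2e-4)
optimizer_D = optim.Adam(netD.parameters(), lr=2e-4)

real_A = torch.randn(4, 1, 64, 64)   # source-font images
real_B = torch.randn(4, 1, 64, 64)   # target-font images
labels = torch.randint(0, 10, (4,))  # font-category labels

model.set_input(real_A, real_B, labels)
model.optimize_parameters(optimizer_G, optimizer_D)
print(model.d_loss.item(), model.g_loss.item())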

Gemini's version:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import functools
import math
import time
import os

# Assume args, checkpoint_dir, data_dir, DatasetFromObj, init_net, global_steps, start_time are defined elsewhere

class UNetGenerator(nn.Module):
    ...  # (same as the original code)

class UnetSkipConnectionBlock(nn.Module):
    ...  # (same as the original code)

class Discriminator(nn.Module):
    ...  # (same as the original code)

class CombinedLoss(nn.Module):
    def __init__(self, category_num, category_weight=1.0, hinge_weight=1.0):
        super(CombinedLoss, self).__init__()
        self.category_num = category_num
        self.category_weight = category_weight
        self.hinge_weight = hinge_weight
        # Identity-initialized embedding maps integer labels to one-hot targets.
        self.emb = nn.Embedding(category_num, category_num)
        self.emb.weight.data = torch.eye(category_num)
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, binary_logits, category_logits, labels, real):
        # Category Loss
        target = self.emb(labels)
        category_loss = self.bce(category_logits, target) * self.category_weight

        # Hinge Loss
        if real:
            hinge_loss = torch.relu(1 - binary_logits).mean() * self.hinge_weight
        else:
            hinge_loss = torch.relu(1 + binary_logits).mean() * self.hinge_weight

        return category_loss + hinge_loss
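
As a quick sanity check with made-up values, the identity-initialized embedding really does turn integer labels into one-hot BCE targets:

loss_fn = CombinedLoss(category_num=4)
labels = torch.tensor([0, 2])
print(loss_fn.emb(labels))
# tensor([[1., 0., 0., 0.],
#         [0., 0., 1., 0.]], ...)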

class Zi2ZiModel:
    def __init__(self, input_nc=3, embedding_num=40, embedding_dim=128,
                 ngf=64, ndf=64,
                 Lconst_penalty=15, Lcategory_penalty=1, L1_penalty=100,
                 schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
                 image_size=256):

        ...  # (same as the original code)

    def setup(self):
        # ... (same as the original code)
        self.combined_loss = CombinedLoss(self.embedding_num, category_weight=self.Lcategory_penalty)  # initialize the combined loss

        if self.gpu_ids:
            # ... (same as the original code)
            self.combined_loss.cuda()  # move the combined loss onto the GPU

    def set_input(self, labels, real_A, real_B):
        ...  # (same as the original code)

    def forward(self):
        self.fake_B, self.encoded_real_A = self.netG(self.real_A, self.labels)
        # Re-encode the generated image for the constancy loss, using the same call
        # signature as above (the original line dropped the labels argument and
        # tried to .view() the returned tuple).
        _, self.encoded_fake_B = self.netG(self.fake_B, self.labels)

    def backward_D(self, no_target_source=False):

        real_AB = torch.cat([self.real_A, self.real_B], 1)
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)

        real_D_logits, real_category_logits = self.netD(real_AB)
        fake_D_logits, fake_category_logits = self.netD(fake_AB.detach())

        real_loss = self.combined_loss(real_D_logits, real_category_logits, self.labels, real=True)
        fake_loss = self.combined_loss(fake_D_logits, fake_category_logits, self.labels, real=False)

        self.d_loss = real_loss + fake_loss
        self.d_loss.backward()
        return real_loss.item()  # returned for logging (reported as category_loss in the training loop)

    def backward_G(self, no_target_source=False):
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        fake_D_logits, fake_category_logits = self.netD(fake_AB)

        const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
        l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
        cheat_loss = self.combined_loss(fake_D_logits, fake_category_logits, self.labels, real=True)  # combined loss with real=True pushes G toward the real side

        self.g_loss = cheat_loss + l1_loss + const_loss
        self.g_loss.backward()
        return const_loss.item(), l1_loss.item(), cheat_loss.item()  # return plain scalars for logging

    def update_lr(self):
        ...  # (same as the original code)

    def optimize_parameters(self):
        self.forward()  # compute fake images: G(A)
        # update D
        self.set_requires_grad(self.netD, True)  # enable backprop for D
        self.optimizer_D.zero_grad()  # set D's gradients to zero
        category_loss = self.backward_D()  # calculate gradients for D
        self.optimizer_D.step()  # update D's weights
        # update G
        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()  # set G's gradients to zero
        const_loss, l1_loss, cheat_loss = self.backward_G()  # calculate gradients for G
        self.optimizer_G.step()  # update G's weights

        # G is updated a second time per D update, as in the original zi2zi training loop.
        self.forward()  # recompute fake images with the updated G
        self.optimizer_G.zero_grad()  # set G's gradients to zero
        const_loss, l1_loss, cheat_loss = self.backward_G()  # calculate gradients for G
        self.optimizer_G.step()  # update G's weights
        return const_loss, l1_loss, category_loss, cheat_loss

    def set_requires_grad(self, nets, requires_grad=False):
        ...  # (same as the original code)

model = Zi2ZiModel(
    input_nc=args.input_nc,
    embedding_num=args.embedding_num,
    embedding_dim=args.embedding_dim,
    Lconst_penalty=args.Lconst_penalty,
    Lcategory_penalty=args.Lcategory_penalty,
    save_dir=checkpoint_dir,
    gpu_ids=args.gpu_ids
)
model.setup()
if args.resume:
    model.load_networks(args.resume)
train_dataset = DatasetFromObj(os.path.join(data_dir, 'train.obj'))
total_batches = math.ceil(len(train_dataset) / args.batch_size)
for epoch in range(args.epoch):
    dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    for bid, batch in enumerate(dataloader):
        model.set_input(batch[0], batch[2], batch[1])
        const_loss, l1_loss, category_loss, cheat_loss = model.optimize_parameters()
        if bid % 100 == 0:
            passed = time.time() - start_time
            log_format = "Epoch: [%2d], [%4d/%4d] time: %4.2f, d_loss: %.5f, g_loss: %.5f, " + \
                         "category_loss: %.5f, cheat_loss: %.5f, const_loss: %.5f, l1_loss: %.5f"
            print(log_format % (epoch, bid, total_batches, passed, model.d_loss.item(), model.g_loss.item(),
                                 category_loss, cheat_loss, const_loss, l1_loss))
        if global_steps % args.checkpoint_steps == 0:
            model.save_networks(global_steps)
        global_steps += 1
    if (epoch + 1) % args.schedule == 0:
        model.update_lr()
