Hinge Loss overview:
In GANs, hinge loss is often used as a replacement for the BCE loss. It is defined as:
- For real samples: max(0, 1 - D(real))
- For fake samples: max(0, 1 + D(fake))
Minimizing these terms drives D(real) >= 1 and D(fake) <= -1. This is more stable than BCE loss, which easily suffers from vanishing gradients.
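As a quick standalone illustration (my own sketch, not part of trainer.py), the snippet below evaluates both hinge terms on example logits. Unlike the saturating BCE objective, the generator's hinge gradient is a constant per sample, so it does not vanish when D becomes confident:

import torch

def d_hinge_loss(real_logits, fake_logits):
    # D wants real logits >= 1 and fake logits <= -1.
    loss_real = torch.relu(1.0 - real_logits).mean()
    loss_fake = torch.relu(1.0 + fake_logits).mean()
    return loss_real + loss_fake

def g_hinge_loss(fake_logits):
    # Non-saturating generator objective: push D(fake) upward.
    return -fake_logits.mean()

real = torch.tensor([0.5, 2.0])                        # 2.0 is already past the margin
fake = torch.tensor([-0.5, -2.0], requires_grad=True)

print(d_hinge_loss(real, fake))   # 0.5: only in-margin samples contribute
g_hinge_loss(fake).backward()
print(fake.grad)                  # tensor([-0.5000, -0.5000]): constant -1/N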
This trainer.py integrates the suggested changes:
- Replace BCEWithLogitsLoss with hinge loss.
- Merge BinaryLoss and CategoryLoss into a single DiscriminatorLoss.
- Fix the cheat_loss term in G's loss so that G is also trained with the hinge objective.
import torch
import torch.nn as nn


class DiscriminatorLoss(nn.Module):
    """Hinge loss for the real/fake head plus cross-entropy for the category head."""

    def __init__(self, category_num):
        super(DiscriminatorLoss, self).__init__()
        self.category_num = category_num
        self.category_loss = nn.CrossEntropyLoss()

    def hinge_loss(self, logits, real):
        if real:
            # Penalize real samples whose logit falls below the +1 margin.
            return torch.mean(torch.relu(1.0 - logits))
        else:
            # Penalize fake samples whose logit rises above the -1 margin.
            return torch.mean(torch.relu(1.0 + logits))

    def forward(self, binary_logits, category_logits, labels, real=True):
        binary_loss = self.hinge_loss(binary_logits, real)
        category_loss = self.category_loss(category_logits, labels)
        return binary_loss, category_loss


class Zi2ZiModel(nn.Module):
    def __init__(self, netG, netD, embedding_num,
                 Lcategory_penalty=1.0, Lconst_penalty=15.0, L1_penalty=100.0):
        super(Zi2ZiModel, self).__init__()
        self.netG = netG
        self.netD = netD
        self.embedding_num = embedding_num
        self.Lcategory_penalty = Lcategory_penalty
        self.Lconst_penalty = Lconst_penalty
        self.L1_penalty = L1_penalty
        self.mse = nn.MSELoss()
        self.l1_loss = nn.L1Loss()
        self.discriminator_loss = DiscriminatorLoss(self.embedding_num)

    def set_input(self, real_A, real_B, labels):
        self.real_A = real_A
        self.real_B = real_B
        self.labels = labels

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def backward_D(self):
        # Condition D on the source image by concatenating it with the target.
        real_AB = torch.cat([self.real_A, self.real_B], 1)
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        real_D_logits, real_category_logits = self.netD(real_AB)
        # detach() keeps D's backward pass from flowing into G.
        fake_D_logits, fake_category_logits = self.netD(fake_AB.detach())
        d_loss_real, real_category_loss = self.discriminator_loss(
            real_D_logits, real_category_logits, self.labels, real=True)
        d_loss_fake, fake_category_loss = self.discriminator_loss(
            fake_D_logits, fake_category_logits, self.labels, real=False)
        self.d_loss = d_loss_real + d_loss_fake \
            + (real_category_loss + fake_category_loss) * self.Lcategory_penalty / 2.0
        self.d_loss.backward()

    def backward_G(self):
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        fake_D_logits, fake_category_logits = self.netD(fake_AB)
        # Simplified constancy term on raw images; the full zi2zi model
        # compares encoder features instead (see the version further below).
        const_loss = self.Lconst_penalty * self.mse(self.real_A, self.fake_B)
        l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
        _, fake_category_loss = self.discriminator_loss(
            fake_D_logits, fake_category_logits, self.labels, real=True)
        # Non-saturating hinge objective for G: raise D's logit on fakes.
        cheat_loss = -torch.mean(fake_D_logits)
        self.g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss
        self.g_loss.backward()

    def optimize_parameters(self, optimizer_G, optimizer_D):
        self.forward()
        optimizer_D.zero_grad()
        self.backward_D()
        optimizer_D.step()
        # Note: backward_G also deposits gradients in netD; they are cleared
        # by the next optimizer_D.zero_grad() call.
        optimizer_G.zero_grad()
        self.backward_G()
        optimizer_G.step()
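A minimal usage sketch for this trainer follows. ToyG and ToyD are hypothetical stand-ins invented for illustration; any generator returning an image and any discriminator returning (binary_logits, category_logits) fits the interface:

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical toy networks, only meant to exercise the interface.
class ToyG(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, x):
        return torch.tanh(self.conv(x))

class ToyD(nn.Module):
    def __init__(self, category_num):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(2, 8, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.binary = nn.Linear(8, 1)                # real/fake logit
        self.category = nn.Linear(8, category_num)   # category logits

    def forward(self, x):
        h = self.features(x)
        return self.binary(h), self.category(h)

netG, netD = ToyG(), ToyD(category_num=40)
model = Zi2ZiModel(netG, netD, embedding_num=40)
optimizer_G = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))

real_A = torch.randn(4, 1, 64, 64)   # source-font glyphs
real_B = torch.randn(4, 1, 64, 64)   # target-font glyphs
labels = torch.randint(0, 40, (4,))
model.set_input(real_A, real_B, labels)
model.optimize_parameters(optimizer_G, optimizer_D)
print(model.d_loss.item(), model.g_loss.item())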
Gemini's version:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import functools
import math
import time
import os

# Assumes args, checkpoint_dir, data_dir, DatasetFromObj, init_net,
# global_steps, and start_time are defined elsewhere.


class UNetGenerator(nn.Module):
    ...  # (same as the original code)


class UnetSkipConnectionBlock(nn.Module):
    ...  # (same as the original code)


class Discriminator(nn.Module):
    ...  # (same as the original code)


class CombinedLoss(nn.Module):
    def __init__(self, category_num, category_weight=1.0, hinge_weight=1.0):
        super(CombinedLoss, self).__init__()
        self.category_num = category_num
        self.category_weight = category_weight
        self.hinge_weight = hinge_weight
        # Identity embedding acts as a fixed one-hot lookup table,
        # equivalent to F.one_hot(labels, category_num).float().
        self.emb = nn.Embedding(category_num, category_num)
        self.emb.weight.data = torch.eye(category_num)
        self.emb.weight.requires_grad = False
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, binary_logits, category_logits, labels, real):
        # Category loss: BCE against one-hot targets.
        target = self.emb(labels)
        category_loss = self.bce(category_logits, target) * self.category_weight
        # Hinge loss on the real/fake head.
        if real:
            hinge_loss = torch.relu(1 - binary_logits).mean() * self.hinge_weight
        else:
            hinge_loss = torch.relu(1 + binary_logits).mean() * self.hinge_weight
        # Return the category component as well so callers can log it separately.
        return category_loss + hinge_loss, category_loss


class Zi2ZiModel:
    def __init__(self, input_nc=3, embedding_num=40, embedding_dim=128,
                 ngf=64, ndf=64,
                 Lconst_penalty=15, Lcategory_penalty=1, L1_penalty=100,
                 schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
                 image_size=256):
        ...  # (same as the original code)

    def setup(self):
        ...  # (same as the original code)
        # Initialize the combined loss.
        self.combined_loss = CombinedLoss(self.embedding_num,
                                          category_weight=self.Lcategory_penalty)
        if self.gpu_ids:
            ...  # (same as the original code)
            self.combined_loss.cuda()  # move the combined loss to the GPU

    def set_input(self, labels, real_A, real_B):
        ...  # (same as the original code)

    def forward(self):
        fake_B, encoded_real_A = self.netG(self.real_A, self.labels)
        self.fake_B = fake_B
        # Re-encode the generated image; assumes netG returns (image, encoding).
        _, encoded_fake_B = self.netG(self.fake_B, self.labels)
        # Flatten both encodings so the constancy MSE compares matching shapes.
        self.encoded_real_A = encoded_real_A.view(fake_B.shape[0], -1)
        self.encoded_fake_B = encoded_fake_B.view(fake_B.shape[0], -1)

    def backward_D(self, no_target_source=False):
        real_AB = torch.cat([self.real_A, self.real_B], 1)
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        real_D_logits, real_category_logits = self.netD(real_AB)
        fake_D_logits, fake_category_logits = self.netD(fake_AB.detach())
        real_loss, real_category_loss = self.combined_loss(
            real_D_logits, real_category_logits, self.labels, real=True)
        fake_loss, fake_category_loss = self.combined_loss(
            fake_D_logits, fake_category_logits, self.labels, real=False)
        self.d_loss = real_loss + fake_loss
        self.d_loss.backward()
        # Return the category component (averaged over real/fake) for logging.
        return ((real_category_loss + fake_category_loss) / 2).item()

    def backward_G(self, no_target_source=False):
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        fake_D_logits, fake_category_logits = self.netD(fake_AB)
        const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
        l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
        # cheat_loss asks D to score the fake as real. Note this uses the
        # margin form relu(1 - logits) for G, whereas the trainer above uses
        # the more common non-saturating -fake_D_logits.mean().
        cheat_loss, _ = self.combined_loss(
            fake_D_logits, fake_category_logits, self.labels, real=True)
        self.g_loss = cheat_loss + l1_loss + const_loss
        self.g_loss.backward()
        return const_loss.item(), l1_loss.item(), cheat_loss.item()

    def update_lr(self):
        ...  # (same as the original code)

    def optimize_parameters(self):
        self.forward()                             # compute fake images: G(A)
        # update D
        self.set_requires_grad(self.netD, True)    # enable backprop for D
        self.optimizer_D.zero_grad()               # set D's gradients to zero
        category_loss = self.backward_D()          # calculate gradients for D
        self.optimizer_D.step()                    # update D's weights
        # update G
        self.set_requires_grad(self.netD, False)   # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()               # set G's gradients to zero
        const_loss, l1_loss, cheat_loss = self.backward_G()
        self.optimizer_G.step()                    # update G's weights
        # Second G update on freshly generated fakes, as in the original
        # zi2zi training scheme.
        self.forward()
        self.optimizer_G.zero_grad()
        const_loss, l1_loss, cheat_loss = self.backward_G()
        self.optimizer_G.step()
        return const_loss, l1_loss, category_loss, cheat_loss

    def set_requires_grad(self, nets, requires_grad=False):
        ...  # (same as the original code)


model = Zi2ZiModel(
    input_nc=args.input_nc,
    embedding_num=args.embedding_num,
    embedding_dim=args.embedding_dim,
    Lconst_penalty=args.Lconst_penalty,
    Lcategory_penalty=args.Lcategory_penalty,
    save_dir=checkpoint_dir,
    gpu_ids=args.gpu_ids
)
model.setup()
if args.resume:
    model.load_networks(args.resume)

train_dataset = DatasetFromObj(os.path.join(data_dir, 'train.obj'))
total_batches = math.ceil(len(train_dataset) / args.batch_size)

for epoch in range(args.epoch):
    dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    for bid, batch in enumerate(dataloader):
        # set_input takes (labels, real_A, real_B).
        model.set_input(batch[0], batch[2], batch[1])
        const_loss, l1_loss, category_loss, cheat_loss = model.optimize_parameters()
        if bid % 100 == 0:
            passed = time.time() - start_time
            log_format = "Epoch: [%2d], [%4d/%4d] time: %4.2f, d_loss: %.5f, g_loss: %.5f, " + \
                         "category_loss: %.5f, cheat_loss: %.5f, const_loss: %.5f, l1_loss: %.5f"
            print(log_format % (epoch, bid, total_batches, passed,
                                model.d_loss.item(), model.g_loss.item(),
                                category_loss, cheat_loss, const_loss, l1_loss))
        if global_steps % args.checkpoint_steps == 0:
            model.save_networks(global_steps)
        global_steps += 1
    if (epoch + 1) % args.schedule == 0:
        model.update_lr()
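A quick smoke test for CombinedLoss (my own sketch, separate from the training script) that runs it on random logits and checks that the returned total splits into hinge and category parts:

import torch

loss_fn = CombinedLoss(category_num=40, category_weight=1.0, hinge_weight=1.0)
binary_logits = torch.randn(8, 1)
category_logits = torch.randn(8, 40)
labels = torch.randint(0, 40, (8,))

total, category = loss_fn(binary_logits, category_logits, labels, real=True)
hinge = total - category  # the remaining component is the hinge term
print(total.item(), category.item(), hinge.item())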