好的,我們將 Feature Matching Loss 的程式碼加入到 Zi2ZiModel
類別中,並在 backward_G
和 optimize_parameters
方法中加入對應的計算和優化步驟。
首先,我們需要在 Zi2ZiModel
的 __init__
方法中初始化 feature_matching_loss
:
Python
class Zi2ZiModel:
def __init__(self, input_nc=1, embedding_num=40, embedding_dim=128, ngf=64, ndf=64,
Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100,
schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
image_size=256, self_attention=False, residual_block=False,
weight_decay = 1e-5, final_channels=1, beta1=0.5, g_blur=False, d_blur=False, epoch=40):
# ... (其他初始化程式碼)
self.feature_matching_loss = nn.L1Loss() # 初始化 Feature Matching Loss
然後,我們需要修改 compute_feature_matching_loss
函數,確保它返回的是 Discriminator 中間層的特徵,而非最終的 logits:
Python
class Zi2ZiModel:
# ... (其他程式碼)
def compute_feature_matching_loss(self, real_AB, fake_AB):
real_features = self.netD.model[:-1](real_AB) # 取得 Discriminator 中間層的特徵
fake_features = self.netD.model[:-1](fake_AB) # 取得 Discriminator 中間層的特徵
return self.feature_matching_loss(real_features, fake_features)
注意: 這裡我們使用 self.netD.model[:-1]
來獲取 Discriminator 的中間層特徵,而不是最終的 logits。
接下來,我們在 backward_G
方法中加入 Feature Matching Loss 的計算:
Python
class Zi2ZiModel:
# ... (其他程式碼)
def backward_G(self, no_target_source=False):
fake_AB = torch.cat([self.real_A, self.fake_B], 1)
real_AB = torch.cat([self.real_A, self.real_B], 1)
fake_D_logits, fake_category_logits = self.netD(fake_AB)
real_D_logits, _ = self.netD(real_AB)
const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
fake_category_loss = self.Lcategory_penalty * self.category_loss(fake_category_logits, self.labels)
g_loss_adv = -torch.mean(F.logsigmoid(fake_D_logits - real_D_logits))
# 計算 Feature Matching Loss
fm_loss = self.compute_feature_matching_loss(real_AB, fake_AB)
self.g_loss = g_loss_adv + l1_loss + fake_category_loss + const_loss + fm_loss # 加入 fm_loss
return const_loss, l1_loss, g_loss_adv, fm_loss # 返回 fm_loss
最後,我們在 optimize_parameters
方法中加入 Feature Matching Loss 的處理:
Python
class Zi2ZiModel:
# ... (其他程式碼)
def optimize_parameters(self, use_autocast=False):
# ... (其他程式碼)
self.set_requires_grad(self.netD, False)
self.optimizer_G.zero_grad()
const_loss, l1_loss, cheat_loss, fm_loss = 0, 0, 0, 0 # 初始化 fm_loss
if use_autocast:
with autocast():
const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
scaled_g_loss = self.scaler_G.scale(self.g_loss)
scaled_g_loss.backward()
self.scaler_G.step(self.optimizer_G)
self.scaler_G.update()
else:
const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
self.g_loss.backward()
self.optimizer_G.step()
self.forward()
self.optimizer_G.zero_grad()
if use_autocast:
with autocast():
const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
scaled_g_loss = self.scaler_G.scale(self.g_loss)
scaled_g_loss.backward()
self.scaler_G.step(self.optimizer_G)
self.scaler_G.update()
else:
const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
self.g_loss.backward()
self.optimizer_G.step()
return const_loss, l1_loss, cheat_loss, fm_loss # 返回 fm_loss
這樣,我們就成功地將 Feature Matching Loss 加入到了 Zi2ZiModel
中,並在訓練過程中進行了計算和優化。記得在訓練過程中監控 fm_loss
的變化,以確保其正常收斂。