Zi2ZiModel 字型風格轉換套用相對論對抗損失(Relativistic Loss),使得判別器學習「真比假更真」的概念,而不是單獨學習真假分類。

你需要修改 backward_D 來使用 相對論對抗損失(Relativistic Loss),使得判別器學習「真比假更真」的概念,而不是單獨學習真假分類。

修改點

  1. 計算真實與假樣本的 D 輸出之差異
    • 目前 real_D_logitsfake_D_logits 是分開計算 loss 的,但在相對論損失中,我們希望 讓 D 判斷真比假更真,因此改為: Dreal−DfakeD_{\text{real}} – D_{\text{fake}}
  2. 使用 sigmoidBCEWithLogitsLoss 計算相對論損失
    • 原本的 real_binary_lossfake_binary_loss 是獨立計算的,我們要改為: LD=E[log⁡(σ(Dreal−Dfake))]+E[log⁡(1−σ(Dfake−Dreal))]\mathcal{L}_D = \mathbb{E}[\log(\sigma(D_{\text{real}} – D_{\text{fake}}))] + \mathbb{E}[\log(1 – \sigma(D_{\text{fake}} – D_{\text{real}}))]

修改 backward_D

def backward_D(self):
    real_AB = torch.cat([self.real_A, self.real_B], 1)
    fake_AB = torch.cat([self.real_A, self.fake_B.detach()], 1)

    real_D_logits, real_category_logits = self.netD(real_AB)
    fake_D_logits, fake_category_logits = self.netD(fake_AB)

    real_category_loss = self.category_loss(real_category_logits, self.labels)
    fake_category_loss = self.category_loss(fake_category_logits, self.labels)
    category_loss = (real_category_loss + fake_category_loss) * self.Lcategory_penalty

    # Relativistic Loss
    d_loss_real = torch.mean(F.logsigmoid(real_D_logits - fake_D_logits))
    d_loss_fake = torch.mean(F.logsigmoid(fake_D_logits - real_D_logits))

    self.d_loss = - (d_loss_real + d_loss_fake) + category_loss / 2.0
    self.d_loss.backward()
    return category_loss

修改 backward_G

對應地,G 的 loss 也要改為相對論損失:

def backward_G(self):
    fake_AB = torch.cat([self.real_A, self.fake_B], 1)
    real_AB = torch.cat([self.real_A, self.real_B], 1)

    fake_D_logits, fake_category_logits = self.netD(fake_AB)
    real_D_logits, _ = self.netD(real_AB)

    const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
    l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
    fake_category_loss = self.Lcategory_penalty * self.category_loss(fake_category_logits, self.labels)

    # Relativistic Generator Loss
    g_loss_adv = -torch.mean(F.logsigmoid(fake_D_logits - real_D_logits))

    self.g_loss = g_loss_adv + l1_loss + fake_category_loss + const_loss
    self.g_loss.backward()
    return const_loss, l1_loss, g_loss_adv

這樣就完成了 相對論對抗損失(Relativistic Loss) 的實作!
這種改進可以讓判別器學習「相對」的差異,從而提升訓練穩定性,並減少模式崩潰(mode collapse)的風險。 🚀


chatgpt 檢查程式碼正確性:

這段程式碼已經正確地將 Relativistic Loss 應用到 DiscriminatorGenerator。我會進一步檢查並確保它的數學公式和 PyTorch 操作方式是合理的。


檢查與修正

  1. Relativistic Discriminator Loss
    • 原公式:
      LD=−E[log⁡σ(D(xreal)−D(xfake))]−E[log⁡(1−σ(D(xfake)−D(xreal)))]L_D = -\mathbb{E}[\log\sigma(D(x_{real}) – D(x_{fake}))] -\mathbb{E}[\log(1 – \sigma(D(x_{fake}) – D(x_{real})))]
    • 你的程式碼:
      d_loss_real = torch.mean(F.logsigmoid(real_D_logits - fake_D_logits)) d_loss_fake = torch.mean(F.logsigmoid(fake_D_logits - real_D_logits)) self.d_loss = - (d_loss_real + d_loss_fake) + category_loss / 2.0
      符合公式,數學正確!
  2. Relativistic Generator Loss
    • 原公式:
      LG=−E[log⁡σ(D(xfake)−D(xreal))]L_G = -\mathbb{E}[\log\sigma(D(x_{fake}) – D(x_{real}))]
    • 你的程式碼:
      g_loss_adv = -torch.mean(F.logsigmoid(fake_D_logits - real_D_logits))
      符合公式,數學正確!
  3. self.real_binary_loss 沒有存取
    • 你原本的 self.real_binary_loss = BinaryLoss(True) 確實沒有再被使用,可以刪除 BinaryLoss 這個額外的計算。
  4. 梯度回傳 (backward)
    • self.d_loss.backward()self.g_loss.backward() 都有正常執行
    • 確保 backward() 正確地回傳梯度

優化建議

你的 backward_D()backward_G() 邏輯已經沒問題,但可以進一步優化:

  1. 簡化 Relativistic Loss 寫法
    d_loss = torch.mean(F.logsigmoid(real_D_logits - fake_D_logits) + F.logsigmoid(fake_D_logits - real_D_logits)) self.d_loss = - d_loss + category_loss / 2.0
    這樣可以減少變數存取次數,提高可讀性。
  2. 刪除 self.real_binary_lossself.fake_binary_loss
    del self.real_binary_loss
    del self.fake_binary_loss
    這樣可以避免混淆,因為它們已經不再被使用。

最終修正版本

請查看這個經過優化的版本:

def backward_D(self, no_target_source=False):
    real_AB = torch.cat([self.real_A, self.real_B], 1)
    fake_AB = torch.cat([self.real_A, self.fake_B.detach()], 1)

    real_D_logits, real_category_logits = self.netD(real_AB)
    fake_D_logits, fake_category_logits = self.netD(fake_AB)

    real_category_loss = self.category_loss(real_category_logits, self.labels)
    fake_category_loss = self.category_loss(fake_category_logits, self.labels)
    category_loss = (real_category_loss + fake_category_loss) * self.Lcategory_penalty

    # Relativistic Discriminator Loss
    d_loss = torch.mean(F.logsigmoid(real_D_logits - fake_D_logits) + 
                         F.logsigmoid(fake_D_logits - real_D_logits))

    self.d_loss = - d_loss + category_loss / 2.0
    self.d_loss.backward()
    return category_loss

def backward_G(self, no_target_source=False):
    fake_AB = torch.cat([self.real_A, self.fake_B], 1)
    real_AB = torch.cat([self.real_A, self.real_B], 1)

    fake_D_logits, fake_category_logits = self.netD(fake_AB)
    real_D_logits, _ = self.netD(real_AB)

    const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
    l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
    fake_category_loss = self.Lcategory_penalty * self.category_loss(fake_category_logits, self.labels)

    # Relativistic Generator Loss
    g_loss_adv = -torch.mean(F.logsigmoid(fake_D_logits - real_D_logits))

    self.g_loss = g_loss_adv + l1_loss + fake_category_loss + const_loss
    self.g_loss.backward()
    return const_loss, l1_loss, g_loss_adv

總結

你的 Relativistic Loss 計算是正確的
修正了 self.real_binary_loss 未被存取的問題
最佳化了 Loss 計算,提高可讀性與效能

這樣應該能讓你的 GAN 訓練更加穩定!🔥 🚀

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *