增加梯度裁剪,限制梯度的最大值,從而防止梯度爆炸

好的,為了在您的字型風格轉換模型中加入梯度裁剪,我們需要在 optimize_parameters 方法中,在反向傳播之後,對生成器(G)和判別器(D)的梯度進行裁剪。

以下是修改後的 optimize_parameters 方法:

Python

class Zi2ZiModel:
    # ... (其他程式碼)

    def optimize_parameters(self, use_autocast=False):
        self.forward()
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        if use_autocast:
            with torch.amp.autocast(device_type='cuda'):
                category_loss = self.backward_D()
                scaled_d_loss = self.scaler_D.scale(self.d_loss)
                scaled_d_loss.backward()
                self.scaler_D.step(self.optimizer_D)
                self.scaler_D.update()
        else:
            category_loss = self.backward_D()
            self.d_loss.backward()
            self.optimizer_D.step()

        # 梯度裁剪 (判別器)
        torch.nn.utils.clip_grad_norm_(self.netD.parameters(), max_norm=1.0) # 梯度裁剪 max_norm=1.0

        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        const_loss, l1_loss, cheat_loss, fm_loss = 0, 0, 0, 0

        if use_autocast:
            with torch.amp.autocast(device_type='cuda'):
                const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
                scaled_g_loss = self.scaler_G.scale(self.g_loss)
                scaled_g_loss.backward()
                self.scaler_G.step(self.optimizer_G)
                self.scaler_G.update()
        else:
            const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
            self.g_loss.backward()
            self.optimizer_G.step()

        # 梯度裁剪 (生成器)
        torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=1.0) # 梯度裁剪 max_norm=1.0

        self.forward()
        self.optimizer_G.zero_grad()

        if use_autocast:
            with torch.amp.autocast(device_type='cuda'):
                const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
                scaled_g_loss = self.scaler_G.scale(self.g_loss)
                scaled_g_loss.backward()
                self.scaler_G.step(self.optimizer_G)
                self.scaler_G.update()
        else:
            const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
            self.g_loss.backward()
            self.optimizer_G.step()

        # 梯度裁剪 (生成器)
        torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=1.0) # 梯度裁剪 max_norm=1.0

        return const_loss, l1_loss, cheat_loss, fm_loss

修改重點:

  1. 添加梯度裁剪:
    • self.optimizer_D.step() 之後,添加 torch.nn.utils.clip_grad_norm_(self.netD.parameters(), max_norm=1.0)
    • self.optimizer_G.step() 之後,添加 torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=1.0)
  2. max_norm 參數:
    • max_norm=1.0 指定梯度的最大範數。您可以根據需要調整這個值。
    • 較小的 max_norm 值會更積極地裁剪梯度,但可能會導致訓練速度變慢。

注意事項:

  • 梯度裁剪應該在 optimizer.step() 之後進行,以確保梯度已經計算出來。
  • 您可以根據您的具體任務和模型,調整 max_norm 的值。
  • 梯度裁剪是一種常用的技術,用於防止梯度爆炸,但它也可能導致梯度消失。

通過這些修改,您的模型將會使用梯度裁剪來限制梯度的最大值,從而防止梯度爆炸。

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *