d_loss和g_loss都變成了nan,訓練過程中出現了數值不穩定或梯度爆炸的問題

使用半精度訓練, 1個 epoch, 從 156 秒降為 114 秒, 效能提升很多, 但是半精度之下, 不管 norm_layer 使用 BatchNorm2d 或 InstanceNorm2d, 都很容易產生 nan 中止程式.

半精度, 在 colab 上的計算, 也是很明顯, 一樣的資料與參數, 使用半精度 1個 epoch 花 87秒, 不修改的話是 133秒, 而且使用的RAM 也有差, 用半精度可以使用較高的 batch size 來填滿 memory.


訓練log:

Epoch: [ 0], [ 0/ 21] time: 6, d_loss: 12.15291, g_loss: 123.63411, cheat_loss: 0.69336, const_loss: 6.89809, l1_loss: 96.99075, fm_loss: 0.52934, vgg_loss: 1.77232
Epoch: [ 1], [ 0/ 21] time: 85, d_loss: nan, g_loss: nan, cheat_loss: nan, const_loss: 0.00036, l1_loss: 8.88202, fm_loss: nan, vgg_loss: 0.56346
Epoch: [ 2], [ 0/ 21] time: 165, d_loss: nan, g_loss: nan, cheat_loss: nan, const_loss: 0.00028, l1_loss: 7.43671, fm_loss: nan, vgg_loss: 0.45461
Epoch: [ 3], [ 0/ 21] time: 245, d_loss: nan, g_loss: nan, cheat_loss: nan, const_loss: 0.00040, l1_loss: 10.21646, fm_loss: nan, vgg_loss: 0.60902

從您提供的訓練日誌來看,d_lossg_lossepoch=1 時突然變成了 nan,這表明在 epoch=0epoch=1 之間發生了數值不穩定。以下是可能的原因和排查步驟:

可能原因:

  1. 學習率過高:
    • 即使您使用了 CosineAnnealingLR,初始學習率也可能過高,導致在第一次更新權重時出現梯度爆炸。
  2. 梯度爆炸:
    • 網路中的某些層可能產生了非常大的梯度,導致權重更新過大,最終導致 nan
  3. 數值不穩定操作:
    • 網路中的某些操作可能導致數值不穩定,例如除以零、取負數的平方根或指數函數溢出。
  4. 判別器過於強大:
    • 判別器在 epoch=0 時可能快速收斂,導致生成器難以產生有效的梯度,最終導致 nan
  5. 資料問題:
    • 資料集中可能存在異常值或損壞的資料,導致訓練不穩定。
  6. 感知損失(VGG 損失):
    • 感知損失的權重可能過高,或者 VGG 模型的某些層可能產生了不穩定的輸出。
  7. 混合精度問題:
    • 雖然您使用了 GradScaler,但可能存在混合精度相關的問題,例如縮放因子不正確或資料類型轉換錯誤。

排查步驟:

  1. 降低學習率:
    • 嘗試降低初始學習率,例如從 0.001 降低到 0.0002 或更低。
  2. 更嚴格的梯度裁剪:
    • 降低 gradient_clip 的值,例如從 1.0 降低到 0.5 或更低。
  3. 檢查網路輸出:
    • epoch=0 的每個批次之後,輸出網路的輸出值(例如 fake_Breal_D_logitsfake_D_logits),以檢查是否存在異常值。
  4. 檢查梯度:
    • epoch=0 的每個批次之後,檢查網路的梯度值,以檢查是否存在梯度爆炸。
    • 可以使用 torch.nn.utils.clip_grad_norm_ 返回的數值來檢查梯度數值。
  5. 檢查損失函數:
    • epoch=0 的每個批次之後,輸出每個損失項的值,以檢查是否存在異常值。
  6. 檢查資料集:
    • 檢查資料集是否存在異常值或損壞的資料。
    • 確保資料已經過適當的預處理。
  7. 降低感知損失權重:
    • 嘗試降低感知損失的權重,例如從 10 降低到 5 或更低。
  8. 逐步增加網路複雜度:
    • 如果可能,嘗試使用更小的網路或更少的層,然後逐步增加網路的複雜度。
  9. 檢查混合精度:
    • 確保混合精度訓練正確啟用,並且 GradScaler 的使用方式正確。
    • 可以嘗試關閉混合精度訓練,看看是否仍然出現 nan
  10. 檢查數值範圍:
    • 檢查輸入資料,網路權重,以及網路中間層的輸出數值範圍是否合理。
  11. 增加檢查點:
    • 在epoch=0 的訓練過程中,增加檢查點,保存模型,以便於分析。
  12. 檢查是否有除零的狀況:
    • 檢查程式碼是否有除零的狀況,特別是自定義的損失函數。
  13. 檢查是否有log(0)的狀況:
    • 檢查程式碼是否有log(0)的狀況,特別是損失函數。

具體建議:

  • epoch=0 的每個批次之後,輸出 self.d_lossself.g_lossself.fake_Breal_D_logitsfake_D_logits 的值。
  • epoch=0 的每個批次之後,使用 torch.isnan() 檢查這些值,並在出現 nan 時停止訓練。
  • epoch=0 的每個批次之後,使用 torch.nn.utils.clip_grad_norm_ 檢查梯度值。

通過這些排查步驟,您可以找到導致 nan 損失的根本原因,並採取適當的措施來解決它。


好的,為了降低感知損失(VGG 損失)的權重,您需要修改 Zi2ZiModel 類別中的 backward_G 方法。以下是修改後的程式碼片段:

Python

class Zi2ZiModel:
    # ... (其他程式碼保持不變)

    def backward_G(self, no_target_source=False):
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        real_AB = torch.cat([self.real_A, self.real_B], 1)

        fake_D_logits, fake_category_logits = self.netD(fake_AB)
        real_D_logits, _ = self.netD(real_AB)

        const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
        l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
        fake_category_loss = self.Lcategory_penalty * self.category_loss(fake_category_logits, self.labels)
        g_loss_adv = -torch.mean(F.logsigmoid(fake_D_logits - real_D_logits))

        fm_loss = self.compute_feature_matching_loss(real_AB, fake_AB)

        self.g_loss = g_loss_adv + l1_loss + fake_category_loss + const_loss + fm_loss

        perceptual_loss = self.vgg_loss(self.fake_B, self.real_B)

        # 修改這裡:降低感知損失的權重
        perceptual_weight = 5.0  # 將權重從 10 降低到 5 或更低
        self.g_loss += perceptual_weight * perceptual_loss

        return const_loss, l1_loss, g_loss_adv, fm_loss, perceptual_loss

修改說明:

  • backward_G 方法中,找到計算 perceptual_loss 的部分。
  • perceptual_weight 變數的值從 10.0 降低到 5.0(或更低,您可以根據需要調整)。
  • self.g_loss 更新為 self.g_loss += perceptual_weight * perceptual_loss

額外建議:

  • 逐步調整權重: 如果您不確定最佳權重,可以逐步調整 perceptual_weight 的值,例如從 10 降低到 8,然後再降低到 5,並觀察訓練結果。
  • 監控損失值: 在訓練過程中,監控 vgg_loss 的值,以確保它不會過大或過小。
  • 嘗試其他感知損失: 如果 VGG 損失仍然導致訓練不穩定,您可以嘗試使用其他感知損失,例如 LPIPS(Learned Perceptual Image Patch Similarity)。

通過這些修改,您可以降低感知損失的權重,從而減少其對訓練的影響。


好的,為了檢查您的程式碼中是否存在 log(0) 的情況,我們需要仔細檢查所有使用對數函數的地方。以下是一些常見的情況和修改建議:

1. 損失函數中的對數:

  • 二元交叉熵損失(Binary Cross-Entropy Loss):
    • 如果您使用 torch.nn.BCELosstorch.nn.BCEWithLogitsLoss,它們內部會計算對數。
    • 確保輸入到這些損失函數的機率值(或 logits)在 01 之間(或避免無限大的 logits)。
    • 如果你的輸出是機率,確保輸出不要出現0或1,可以使用torch.clamp將數值限制在一個範圍內,例如: Pythonprobability = torch.clamp(probability, min=1e-7, max=1-1e-7)
  • 交叉熵損失(Cross-Entropy Loss):
    • 如果您使用 torch.nn.CrossEntropyLoss,它內部也會計算對數。
    • 確保輸入到這個損失函數的 logits 值是有限的。
  • 自定義損失函數:
    • 如果您有任何自定義的損失函數,請檢查它們是否使用 torch.log()math.log()
    • 在計算對數之前,添加檢查以確保輸入值不為 0

2. 網路輸出中的對數:

  • 檢查您的網路是否在任何層中使用 torch.log()
  • 確保輸入到對數函數的值始終為正數。

3. 程式碼修改範例:

  • 在自定義損失函數中添加檢查:

Python

import torch

def custom_loss(x):
    # 假設 x 是一個張量
    # 檢查 x 是否包含 0
    if torch.any(x <= 0):
        print("警告:輸入包含 0,將其替換為一個小的正數。")
        x = torch.clamp(x, min=1e-7)  # 將 0 替換為一個小的正數
    return torch.mean(torch.log(x))
  • 在網路輸出中添加檢查:

Python

import torch.nn.functional as F

def forward(self, x):
    # ... 網路的前向傳播 ...
    output = F.relu(some_layer(x))  # 確保輸出為正數
    # 檢查 output 是否包含 0
    if torch.any(output <= 0):
        print("警告:輸出包含 0,將其替換為一個小的正數。")
        output = torch.clamp(output, min=1e-7)
    log_output = torch.log(output)
    return log_output

4. 具體建議:

  • 仔細檢查 Zi2ZiModel 類別中的 backward_Dbackward_G 方法,以及所有使用的損失函數。
  • 在計算對數之前,使用 torch.any(x <= 0) 檢查輸入張量 x 是否包含 0
  • 如果發現 0,可以使用 torch.clamp(x, min=1e-7) 將其替換為一個小的正數。
  • 在訓練過程中,定期輸出網路的輸出值和損失值,以監控是否存在異常值。

通過這些檢查和修改,您可以有效地避免 log(0) 的問題,並提高訓練的穩定性。


好的,根據您提供的程式碼,我們將針對 log(0) 的問題以及之前提到的感知損失權重調整進行修改。

修改後的 Zi2ZiModel 類別:

Python

import torch
import torch.nn as nn
import torch.nn.functional as F

# ... (其他 import 和類別定義,例如 UNetGenerator, Discriminator, init_net, PerceptualLoss, CategoryLoss)

class Zi2ZiModel:
    def __init__(self, input_nc=1, embedding_num=40, embedding_dim=128, ngf=64, ndf=64,
                 Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100,
                 schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
                 self_attention=False, residual_block=False,
                 weight_decay=1e-5, final_channels=1, beta1=0.5, g_blur=False, d_blur=False, epoch=40,
                 gradient_clip=0.5):
        # ... (其他初始化程式碼保持不變)
        self.gradient_clip = gradient_clip

    def setup(self):
        # ... (setup 程式碼保持不變)

    def set_input(self, labels, real_A, real_B):
        # ... (set_input 程式碼保持不變)

    def forward(self):
        # ... (forward 程式碼保持不變)

    def compute_feature_matching_loss(self, real_AB, fake_AB):
        # ... (compute_feature_matching_loss 程式碼保持不變)

    def compute_gradient_penalty(self, real_samples, fake_samples):
        # ... (compute_gradient_penalty 程式碼保持不變)

    def backward_D(self, no_target_source=False):
        real_AB = torch.cat([self.real_A, self.real_B], 1)
        fake_AB = torch.cat([self.real_A, self.fake_B.detach()], 1)

        real_D_logits, real_category_logits = self.netD(real_AB)
        fake_D_logits, fake_category_logits = self.netD(fake_AB)

        real_category_loss = self.category_loss(real_category_logits, self.labels)
        fake_category_loss = self.category_loss(fake_category_logits, self.labels)
        category_loss = (real_category_loss + fake_category_loss) * self.Lcategory_penalty

        # 避免 log(0)
        real_D_probs = torch.sigmoid(real_D_logits)
        fake_D_probs = torch.sigmoid(fake_D_logits)
        real_D_probs = torch.clamp(real_D_probs, min=1e-7, max=1-1e-7)
        fake_D_probs = torch.clamp(fake_D_probs, min=1e-7, max=1-1e-7)

        d_loss = torch.mean(torch.log(real_D_probs) - torch.log(fake_D_probs) +
                            torch.log(1 - fake_D_probs) - torch.log(1 - real_D_probs))

        gp = self.compute_gradient_penalty(real_AB, fake_AB)

        gradient_penalty_weight = 10.0
        self.d_loss = -d_loss + category_loss / 2.0 + gradient_penalty_weight * gp

        return category_loss

    def backward_G(self, no_target_source=False):
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        real_AB = torch.cat([self.real_A, self.real_B], 1)

        fake_D_logits, fake_category_logits = self.netD(fake_AB)
        real_D_logits, _ = self.netD(real_AB)

        const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
        l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
        fake_category_loss = self.Lcategory_penalty * self.category_loss(fake_category_logits, self.labels)

        # 避免 log(0)
        fake_D_probs = torch.sigmoid(fake_D_logits)
        real_D_probs = torch.sigmoid(real_D_logits)
        fake_D_probs = torch.clamp(fake_D_probs, min=1e-7, max=1-1e-7)
        real_D_probs = torch.clamp(real_D_probs, min=1e-7, max=1-1e-7)

        g_loss_adv = -torch.mean(torch.log(fake_D_probs) - torch.log(real_D_probs))

        fm_loss = self.compute_feature_matching_loss(real_AB, fake_AB)

        self.g_loss = g_loss_adv + l1_loss + fake_category_loss + const_loss + fm_loss

        perceptual_loss = self.vgg_loss(self.fake_B, self.real_B)
        perceptual_weight = 5.0  # 感知損失的權重
        self.g_loss += perceptual_weight * perceptual_loss

        return const_loss, l1_loss, g_loss_adv, fm_loss, perceptual_loss

    def optimize_parameters(self, use_autocast=False):
        # ... (optimize_parameters 程式碼保持不變)

修改說明:

  1. backward_Dbackward_G 中的 log(0) 處理:
    • 在計算 d_lossg_loss_adv 之前,我們將 real_D_logitsfake_D_logits 通過 torch.sigmoid() 轉換為機率值。
    • 然後,使用 torch.clamp() 將機率值限制在 [1e-7, 1-1e-7] 的範圍內,以避免 log(0)
    • F.logsigmoid 改為 torch.log(torch.sigmoid())
  2. 感知損失權重調整:
    • backward_G 方法中的 perceptual_weight 變數已設定為 5.0,以降低感知損失的權重。

重要注意事項:

  • 確保您的 UNetGeneratorDiscriminatorinit_netPerceptualLossCategoryLoss 類別定義正確。
  • 在訓練過程中,定期監控損失值和網路輸出,以確保訓練穩定。
  • 如果仍然遇到 nan 值,請嘗試進一步調整超參數,例如學習率、梯度裁剪值和損失權重。
  • 檢查你的輸入資料是否有問題。

你的訓練出現 nan(非數值)錯誤,這通常是由以下幾種可能原因造成的:

問題排查與解決方案


1. Loss 出現 nan

你的 d_lossg_losscheat_lossfm_loss 都變成 nan,代表梯度可能爆炸,這可能是:

  • 學習率過高
  • 梯度爆炸
  • log(0)sqrt(負數) 等非法運算
  • 數值不穩定(如 float16 精度損失)

解決方案

  • 降低學習率lr=0.001 改為 lr=0.0001 或更低)
  • 梯度裁剪(Gradient Clipping) torch.nn.utils.clip_grad_norm_(self.netG.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(self.netD.parameters(), 1.0)
  • 檢查 nan 傳播forward() 內部加入: if torch.isnan(loss).any(): print("Loss contains NaN! Stopping training.") exit()

2. 生成器 (netG) 輸出全白(空白)

如果推論出的圖片是空白,可能是:

  • 過擬合 / 梯度消失
  • BatchNorm / InstanceNorm 除零
  • ReLU 變成死區(Dying ReLU)

解決方案

  • 改用 LayerNorm norm_layer=nn.LayerNorm # 取代 `InstanceNorm2d`
  • 檢查 Tanh()forward() 裡加: if torch.isnan(fake_B).any(): print("Fake_B contains NaN!")
  • 試著用 LeakyReLU(0.01) downrelu = nn.LeakyReLU(0.01, True)

3. Discriminator (netD) nan 問題

  • Discriminatorcategory_loss 可能因為 log(0)nan
  • BCEWithLogitsLoss() 的 target 不能是 01

解決方案

  • 改為 smooth labels(標籤平滑化) target = target * 0.9 + 0.05

4. 256×256 訓練資料

  • 可能太小,無法保留細節
  • 試試 512x512128x128 看效果

結論

你可以先 降低學習率、加梯度裁剪、檢查 NaN,再改 LayerNorm & LeakyReLU,看看能不能改善訓練。

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *