使用半精度訓練, 1個 epoch, 從 156 秒降為 114 秒, 效能提升很多, 但是半精度之下, 不管 norm_layer 使用 BatchNorm2d 或 InstanceNorm2d, 都很容易產生 nan 中止程式.
半精度, 在 colab 上的計算, 也是很明顯, 一樣的資料與參數, 使用半精度 1個 epoch 花 87秒, 不修改的話是 133秒, 而且使用的RAM 也有差, 用半精度可以使用較高的 batch size 來填滿 memory.
訓練log:
Epoch: [ 0], [ 0/ 21] time: 6, d_loss: 12.15291, g_loss: 123.63411, cheat_loss: 0.69336, const_loss: 6.89809, l1_loss: 96.99075, fm_loss: 0.52934, vgg_loss: 1.77232
Epoch: [ 1], [ 0/ 21] time: 85, d_loss: nan, g_loss: nan, cheat_loss: nan, const_loss: 0.00036, l1_loss: 8.88202, fm_loss: nan, vgg_loss: 0.56346
Epoch: [ 2], [ 0/ 21] time: 165, d_loss: nan, g_loss: nan, cheat_loss: nan, const_loss: 0.00028, l1_loss: 7.43671, fm_loss: nan, vgg_loss: 0.45461
Epoch: [ 3], [ 0/ 21] time: 245, d_loss: nan, g_loss: nan, cheat_loss: nan, const_loss: 0.00040, l1_loss: 10.21646, fm_loss: nan, vgg_loss: 0.60902
從您提供的訓練日誌來看,d_loss
和 g_loss
在 epoch=1
時突然變成了 nan
,這表明在 epoch=0
到 epoch=1
之間發生了數值不穩定。以下是可能的原因和排查步驟:
可能原因:
- 學習率過高:
- 即使您使用了
CosineAnnealingLR
,初始學習率也可能過高,導致在第一次更新權重時出現梯度爆炸。
- 即使您使用了
- 梯度爆炸:
- 網路中的某些層可能產生了非常大的梯度,導致權重更新過大,最終導致
nan
。
- 網路中的某些層可能產生了非常大的梯度,導致權重更新過大,最終導致
- 數值不穩定操作:
- 網路中的某些操作可能導致數值不穩定,例如除以零、取負數的平方根或指數函數溢出。
- 判別器過於強大:
- 判別器在
epoch=0
時可能快速收斂,導致生成器難以產生有效的梯度,最終導致nan
。
- 判別器在
- 資料問題:
- 資料集中可能存在異常值或損壞的資料,導致訓練不穩定。
- 感知損失(VGG 損失):
- 感知損失的權重可能過高,或者 VGG 模型的某些層可能產生了不穩定的輸出。
- 混合精度問題:
- 雖然您使用了
GradScaler
,但可能存在混合精度相關的問題,例如縮放因子不正確或資料類型轉換錯誤。
- 雖然您使用了
排查步驟:
- 降低學習率:
- 嘗試降低初始學習率,例如從 0.001 降低到 0.0002 或更低。
- 更嚴格的梯度裁剪:
- 降低
gradient_clip
的值,例如從 1.0 降低到 0.5 或更低。
- 降低
- 檢查網路輸出:
- 在
epoch=0
的每個批次之後,輸出網路的輸出值(例如fake_B
、real_D_logits
、fake_D_logits
),以檢查是否存在異常值。
- 在
- 檢查梯度:
- 在
epoch=0
的每個批次之後,檢查網路的梯度值,以檢查是否存在梯度爆炸。 - 可以使用
torch.nn.utils.clip_grad_norm_
返回的數值來檢查梯度數值。
- 在
- 檢查損失函數:
- 在
epoch=0
的每個批次之後,輸出每個損失項的值,以檢查是否存在異常值。
- 在
- 檢查資料集:
- 檢查資料集是否存在異常值或損壞的資料。
- 確保資料已經過適當的預處理。
- 降低感知損失權重:
- 嘗試降低感知損失的權重,例如從 10 降低到 5 或更低。
- 逐步增加網路複雜度:
- 如果可能,嘗試使用更小的網路或更少的層,然後逐步增加網路的複雜度。
- 檢查混合精度:
- 確保混合精度訓練正確啟用,並且
GradScaler
的使用方式正確。 - 可以嘗試關閉混合精度訓練,看看是否仍然出現
nan
。
- 確保混合精度訓練正確啟用,並且
- 檢查數值範圍:
- 檢查輸入資料,網路權重,以及網路中間層的輸出數值範圍是否合理。
- 增加檢查點:
- 在epoch=0 的訓練過程中,增加檢查點,保存模型,以便於分析。
- 檢查是否有除零的狀況:
- 檢查程式碼是否有除零的狀況,特別是自定義的損失函數。
- 檢查是否有log(0)的狀況:
- 檢查程式碼是否有log(0)的狀況,特別是損失函數。
具體建議:
- 在
epoch=0
的每個批次之後,輸出self.d_loss
、self.g_loss
、self.fake_B
、real_D_logits
、fake_D_logits
的值。 - 在
epoch=0
的每個批次之後,使用torch.isnan()
檢查這些值,並在出現nan
時停止訓練。 - 在
epoch=0
的每個批次之後,使用torch.nn.utils.clip_grad_norm_
檢查梯度值。
通過這些排查步驟,您可以找到導致 nan
損失的根本原因,並採取適當的措施來解決它。
好的,為了降低感知損失(VGG 損失)的權重,您需要修改 Zi2ZiModel
類別中的 backward_G
方法。以下是修改後的程式碼片段:
Python
class Zi2ZiModel:
# ... (其他程式碼保持不變)
def backward_G(self, no_target_source=False):
fake_AB = torch.cat([self.real_A, self.fake_B], 1)
real_AB = torch.cat([self.real_A, self.real_B], 1)
fake_D_logits, fake_category_logits = self.netD(fake_AB)
real_D_logits, _ = self.netD(real_AB)
const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
fake_category_loss = self.Lcategory_penalty * self.category_loss(fake_category_logits, self.labels)
g_loss_adv = -torch.mean(F.logsigmoid(fake_D_logits - real_D_logits))
fm_loss = self.compute_feature_matching_loss(real_AB, fake_AB)
self.g_loss = g_loss_adv + l1_loss + fake_category_loss + const_loss + fm_loss
perceptual_loss = self.vgg_loss(self.fake_B, self.real_B)
# 修改這裡:降低感知損失的權重
perceptual_weight = 5.0 # 將權重從 10 降低到 5 或更低
self.g_loss += perceptual_weight * perceptual_loss
return const_loss, l1_loss, g_loss_adv, fm_loss, perceptual_loss
修改說明:
- 在
backward_G
方法中,找到計算perceptual_loss
的部分。 - 將
perceptual_weight
變數的值從10.0
降低到5.0
(或更低,您可以根據需要調整)。 - 將
self.g_loss
更新為self.g_loss += perceptual_weight * perceptual_loss
。
額外建議:
- 逐步調整權重: 如果您不確定最佳權重,可以逐步調整
perceptual_weight
的值,例如從 10 降低到 8,然後再降低到 5,並觀察訓練結果。 - 監控損失值: 在訓練過程中,監控
vgg_loss
的值,以確保它不會過大或過小。 - 嘗試其他感知損失: 如果 VGG 損失仍然導致訓練不穩定,您可以嘗試使用其他感知損失,例如 LPIPS(Learned Perceptual Image Patch Similarity)。
通過這些修改,您可以降低感知損失的權重,從而減少其對訓練的影響。
好的,為了檢查您的程式碼中是否存在 log(0)
的情況,我們需要仔細檢查所有使用對數函數的地方。以下是一些常見的情況和修改建議:
1. 損失函數中的對數:
- 二元交叉熵損失(Binary Cross-Entropy Loss):
- 如果您使用
torch.nn.BCELoss
或torch.nn.BCEWithLogitsLoss
,它們內部會計算對數。 - 確保輸入到這些損失函數的機率值(或 logits)在
0
和1
之間(或避免無限大的 logits)。 - 如果你的輸出是機率,確保輸出不要出現0或1,可以使用torch.clamp將數值限制在一個範圍內,例如: Python
probability = torch.clamp(probability, min=1e-7, max=1-1e-7)
- 如果您使用
- 交叉熵損失(Cross-Entropy Loss):
- 如果您使用
torch.nn.CrossEntropyLoss
,它內部也會計算對數。 - 確保輸入到這個損失函數的 logits 值是有限的。
- 如果您使用
- 自定義損失函數:
- 如果您有任何自定義的損失函數,請檢查它們是否使用
torch.log()
或math.log()
。 - 在計算對數之前,添加檢查以確保輸入值不為
0
。
- 如果您有任何自定義的損失函數,請檢查它們是否使用
2. 網路輸出中的對數:
- 檢查您的網路是否在任何層中使用
torch.log()
。 - 確保輸入到對數函數的值始終為正數。
3. 程式碼修改範例:
- 在自定義損失函數中添加檢查:
Python
import torch
def custom_loss(x):
# 假設 x 是一個張量
# 檢查 x 是否包含 0
if torch.any(x <= 0):
print("警告:輸入包含 0,將其替換為一個小的正數。")
x = torch.clamp(x, min=1e-7) # 將 0 替換為一個小的正數
return torch.mean(torch.log(x))
- 在網路輸出中添加檢查:
Python
import torch.nn.functional as F
def forward(self, x):
# ... 網路的前向傳播 ...
output = F.relu(some_layer(x)) # 確保輸出為正數
# 檢查 output 是否包含 0
if torch.any(output <= 0):
print("警告:輸出包含 0,將其替換為一個小的正數。")
output = torch.clamp(output, min=1e-7)
log_output = torch.log(output)
return log_output
4. 具體建議:
- 仔細檢查
Zi2ZiModel
類別中的backward_D
和backward_G
方法,以及所有使用的損失函數。 - 在計算對數之前,使用
torch.any(x <= 0)
檢查輸入張量x
是否包含0
。 - 如果發現
0
,可以使用torch.clamp(x, min=1e-7)
將其替換為一個小的正數。 - 在訓練過程中,定期輸出網路的輸出值和損失值,以監控是否存在異常值。
通過這些檢查和修改,您可以有效地避免 log(0)
的問題,並提高訓練的穩定性。
好的,根據您提供的程式碼,我們將針對 log(0)
的問題以及之前提到的感知損失權重調整進行修改。
修改後的 Zi2ZiModel
類別:
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
# ... (其他 import 和類別定義,例如 UNetGenerator, Discriminator, init_net, PerceptualLoss, CategoryLoss)
class Zi2ZiModel:
def __init__(self, input_nc=1, embedding_num=40, embedding_dim=128, ngf=64, ndf=64,
Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100,
schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
self_attention=False, residual_block=False,
weight_decay=1e-5, final_channels=1, beta1=0.5, g_blur=False, d_blur=False, epoch=40,
gradient_clip=0.5):
# ... (其他初始化程式碼保持不變)
self.gradient_clip = gradient_clip
def setup(self):
# ... (setup 程式碼保持不變)
def set_input(self, labels, real_A, real_B):
# ... (set_input 程式碼保持不變)
def forward(self):
# ... (forward 程式碼保持不變)
def compute_feature_matching_loss(self, real_AB, fake_AB):
# ... (compute_feature_matching_loss 程式碼保持不變)
def compute_gradient_penalty(self, real_samples, fake_samples):
# ... (compute_gradient_penalty 程式碼保持不變)
def backward_D(self, no_target_source=False):
real_AB = torch.cat([self.real_A, self.real_B], 1)
fake_AB = torch.cat([self.real_A, self.fake_B.detach()], 1)
real_D_logits, real_category_logits = self.netD(real_AB)
fake_D_logits, fake_category_logits = self.netD(fake_AB)
real_category_loss = self.category_loss(real_category_logits, self.labels)
fake_category_loss = self.category_loss(fake_category_logits, self.labels)
category_loss = (real_category_loss + fake_category_loss) * self.Lcategory_penalty
# 避免 log(0)
real_D_probs = torch.sigmoid(real_D_logits)
fake_D_probs = torch.sigmoid(fake_D_logits)
real_D_probs = torch.clamp(real_D_probs, min=1e-7, max=1-1e-7)
fake_D_probs = torch.clamp(fake_D_probs, min=1e-7, max=1-1e-7)
d_loss = torch.mean(torch.log(real_D_probs) - torch.log(fake_D_probs) +
torch.log(1 - fake_D_probs) - torch.log(1 - real_D_probs))
gp = self.compute_gradient_penalty(real_AB, fake_AB)
gradient_penalty_weight = 10.0
self.d_loss = -d_loss + category_loss / 2.0 + gradient_penalty_weight * gp
return category_loss
def backward_G(self, no_target_source=False):
fake_AB = torch.cat([self.real_A, self.fake_B], 1)
real_AB = torch.cat([self.real_A, self.real_B], 1)
fake_D_logits, fake_category_logits = self.netD(fake_AB)
real_D_logits, _ = self.netD(real_AB)
const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
fake_category_loss = self.Lcategory_penalty * self.category_loss(fake_category_logits, self.labels)
# 避免 log(0)
fake_D_probs = torch.sigmoid(fake_D_logits)
real_D_probs = torch.sigmoid(real_D_logits)
fake_D_probs = torch.clamp(fake_D_probs, min=1e-7, max=1-1e-7)
real_D_probs = torch.clamp(real_D_probs, min=1e-7, max=1-1e-7)
g_loss_adv = -torch.mean(torch.log(fake_D_probs) - torch.log(real_D_probs))
fm_loss = self.compute_feature_matching_loss(real_AB, fake_AB)
self.g_loss = g_loss_adv + l1_loss + fake_category_loss + const_loss + fm_loss
perceptual_loss = self.vgg_loss(self.fake_B, self.real_B)
perceptual_weight = 5.0 # 感知損失的權重
self.g_loss += perceptual_weight * perceptual_loss
return const_loss, l1_loss, g_loss_adv, fm_loss, perceptual_loss
def optimize_parameters(self, use_autocast=False):
# ... (optimize_parameters 程式碼保持不變)
修改說明:
backward_D
和backward_G
中的log(0)
處理:- 在計算
d_loss
和g_loss_adv
之前,我們將real_D_logits
和fake_D_logits
通過torch.sigmoid()
轉換為機率值。 - 然後,使用
torch.clamp()
將機率值限制在[1e-7, 1-1e-7]
的範圍內,以避免log(0)
。 - 將
F.logsigmoid
改為torch.log(torch.sigmoid())
。
- 在計算
- 感知損失權重調整:
backward_G
方法中的perceptual_weight
變數已設定為5.0
,以降低感知損失的權重。
重要注意事項:
- 確保您的
UNetGenerator
、Discriminator
、init_net
、PerceptualLoss
和CategoryLoss
類別定義正確。 - 在訓練過程中,定期監控損失值和網路輸出,以確保訓練穩定。
- 如果仍然遇到
nan
值,請嘗試進一步調整超參數,例如學習率、梯度裁剪值和損失權重。 - 檢查你的輸入資料是否有問題。
你的訓練出現 nan
(非數值)錯誤,這通常是由以下幾種可能原因造成的:
問題排查與解決方案
1. Loss 出現 nan
你的 d_loss
、g_loss
、cheat_loss
和 fm_loss
都變成 nan
,代表梯度可能爆炸,這可能是:
- 學習率過高
- 梯度爆炸
log(0)
或sqrt(負數)
等非法運算- 數值不穩定(如
float16
精度損失)
解決方案
- 降低學習率(
lr=0.001
改為lr=0.0001
或更低) - 梯度裁剪(Gradient Clipping)
torch.nn.utils.clip_grad_norm_(self.netG.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(self.netD.parameters(), 1.0)
- 檢查
nan
傳播 在forward()
內部加入:if torch.isnan(loss).any(): print("Loss contains NaN! Stopping training.") exit()
2. 生成器 (netG
) 輸出全白(空白)
如果推論出的圖片是空白,可能是:
- 過擬合 / 梯度消失
- BatchNorm / InstanceNorm 除零
- ReLU 變成死區(Dying ReLU)
解決方案
- 改用
LayerNorm
norm_layer=nn.LayerNorm # 取代 `InstanceNorm2d`
- 檢查
Tanh()
在forward()
裡加:if torch.isnan(fake_B).any(): print("Fake_B contains NaN!")
- 試著用
LeakyReLU(0.01)
downrelu = nn.LeakyReLU(0.01, True)
3. Discriminator (netD
) nan
問題
Discriminator
的category_loss
可能因為log(0)
變nan
BCEWithLogitsLoss()
的 target 不能是0
或1
解決方案
- 改為
smooth labels
(標籤平滑化)target = target * 0.9 + 0.05
4. 256×256 訓練資料
- 可能太小,無法保留細節
- 試試
512x512
或128x128
看效果
結論
你可以先 降低學習率、加梯度裁剪、檢查 NaN
,再改 LayerNorm
& LeakyReLU
,看看能不能改善訓練。