torch.isnan(d_loss) == True
代表你的 判別器損失中出現了 NaN(非數),這通常跟以下幾個狀況有關,尤其是你「部分載入權重」後出現這情況,代表初始化或 forward pass 某處有問題。
✅ 常見原因與解法:
1. 未載入的權重未被初始化(或初始化不當)
如果 netD
的某些層未載入,但也沒經過正確初始化,會導致輸出異常值 → d_loss
出現 NaN。
✅ 解法:
確保你有呼叫這段,並初始化了所有未載入的層:
self._initialize_unmatched_weights(self.netD, state_dict_D)
並且裡面使用的是 nn.init.kaiming_normal_()
而非 param.data.zero_()
等可能不穩定的初始化方式。
2. 學習率太高
若使用者變動了模型結構,之前的 learning rate 可能不再適用 → 直接爆梯度。
✅ 解法:
- 嘗試將
optimizer_D
的 learning rate 調低,例如從0.0002
→0.00005
試試。 - 搭配
gradient clipping
是好的,但不是萬靈丹。
3. 未初始化的 BatchNorm / Running Stats
如果你用了 BN 層,而這層未正確初始化 running_mean
和 running_var
,forward 就會出現 NaN。
✅ 解法: 在 _initialize_unmatched_weights()
中記得處理 named_buffers()
,像這樣:
for name, buffer in model.named_buffers():
if name not in loaded_state_dict:
print(f"🔄 Re-initializing buffer: {name}")
if "running_var" in name:
buffer.fill_(1)
elif "running_mean" in name:
buffer.zero_()
else:
buffer.zero_()
4. fake_B 未正確生成
如果你的 self.fake_B
是 generator 產出的,但 generator 尚未穩定或包含未初始化層,會導致它輸出 inf / nan → 傳遞給 netD
出錯。
✅ 解法:
- 加入 debug 檢查
self.fake_B
內部數值:
if torch.isnan(self.fake_B).any():
print("⚠️ fake_B contains NaN")
if torch.isinf(self.fake_B).any():
print("⚠️ fake_B contains Inf")
5. 損失函數內部出錯
部分損失(如 BCEWithLogitsLoss)若輸入極端值(如超大負數)會直接噴 NaN。
✅ 解法:
- 檢查
backward_D()
裡使用的 loss 函數輸入是否存在爆炸性值。 - 將其中 logits 印出檢查,例如:
print("D(real):", d_real_out.min().item(), d_real_out.max().item())
print("D(fake):", d_fake_out.min().item(), d_fake_out.max().item())
✅ 建議加強的 debug 區段(示範):
if torch.isnan(d_loss):
print("❌ 判別器損失為 NaN")
print(" - fake_B NaN:", torch.isnan(self.fake_B).any().item())
print(" - fake_B max:", self.fake_B.max().item(), "min:", self.fake_B.min().item())
for name, param in self.netD.named_parameters():
if torch.isnan(param).any():
print(f" - ❌ NaN in netD param: {name}")
return
🧠 總結建議:
- ✅ 確保
_initialize_unmatched_weights()
初始化所有未載入的權重與 buffer。 - ✅ 對
fake_B
做 NaN / Inf 偵測。 - ✅ 降低 learning rate 或加強 gradient clipping。
- ✅ 增加中間層輸出的值範圍檢查(尤其是
D(real)
、D(fake)
的輸出)。 - ✅ 確認
backward_D
的 loss 沒誤用未初始化的輸入。
如果你想,我可以幫你加入自動 NaN 偵測與警示系統,讓訓練中一旦爆掉就知道是哪一層出的問題。需要嗎?
太好了,那我來幫你加上一個「NaN / Inf 偵測 + 層級追蹤系統」,讓你在模型爆掉時可以:
- 知道哪個 tensor 出現 NaN/Inf(例如
fake_B
, 判別器輸出等) - 哪個 layer 的權重出了問題(
netD
ornetG
) - 自動列出損失值範圍、警示爆梯風險
✅ 加強版:NaN / Inf 偵測整合建議(放在 optimize_parameters()
裡)
def optimize_parameters(self, use_autocast=False):
self.forward()
# --- 判斷 fake_B 是否出問題 ---
if torch.isnan(self.fake_B).any():
print("❌ fake_B contains NaN")
if torch.isinf(self.fake_B).any():
print("❌ fake_B contains Inf")
# --- Discriminator ---
self.set_requires_grad(self.netD, True)
self.optimizer_D.zero_grad()
try:
if use_autocast:
with torch.amp.autocast(device_type='cuda'):
d_loss, cat_loss_d = self.loss_module.backward_D(self.real_A, self.real_B, self.fake_B, self.labels)
self._check_tensor("d_loss", d_loss)
self.scaler_D.scale(d_loss).backward()
self.scaler_D.unscale_(self.optimizer_D)
torch.nn.utils.clip_grad_norm_(self.netD.parameters(), self.gradient_clip)
self.scaler_D.step(self.optimizer_D)
self.scaler_D.update()
else:
d_loss, cat_loss_d = self.loss_module.backward_D(self.real_A, self.real_B, self.fake_B, self.labels)
self._check_tensor("d_loss", d_loss)
d_loss.backward()
torch.nn.utils.clip_grad_norm_(self.netD.parameters(), self.gradient_clip)
self.optimizer_D.step()
except RuntimeError as e:
print(f"❌ RuntimeError during backward: {e}")
self._debug_nan_in_model(self.netD)
return
if torch.isnan(d_loss):
print("❌ 判別器損失為 NaN,停止訓練。")
self._debug_nan_in_model(self.netD)
return
🔍 輔助函式 1:檢查 Tensor 是否 NaN / Inf
def _check_tensor(self, name, tensor):
if torch.isnan(tensor).any():
print(f"❌ Tensor {name} contains NaN.")
if torch.isinf(tensor).any():
print(f"❌ Tensor {name} contains Inf.")
else:
print(f"✅ Tensor {name} OK: min={tensor.min().item():.4f}, max={tensor.max().item():.4f}")
🧠 輔助函式 2:偵測模型內參數出問題的層
def _debug_nan_in_model(self, model):
print("🔍 檢查模型內部是否有 NaN 或 Inf 的參數...")
for name, param in model.named_parameters():
if torch.isnan(param).any():
print(f"❌ NaN in parameter: {name}")
elif torch.isinf(param).any():
print(f"❌ Inf in parameter: {name}")
🧪 可選:在 self.forward()
後檢查 netG
輸出
def _check_generator_output(self):
if torch.isnan(self.fake_B).any():
print("❌ fake_B has NaN after Generator")
elif torch.isinf(self.fake_B).any():
print("❌ fake_B has Inf after Generator")
else:
print("✅ Generator output OK")
你可以在 forward()
後調用它來早期偵錯:
self.forward()
self._check_generator_output()
✅ 整合效果:
這些會讓你一旦遇到爆梯、NaN、未初始化層爆炸,都能快速看到是哪一層哪個 tensor 出問題,並且搭配你的 _initialize_unmatched_weights()
就能快速修復。
需要我幫你再把這些整合成一個完整 class 的 patch 嗎?還是你會手動加進去?
很好,你的 debug 輸出證明:
- ✅
Generator
沒有 NaN/Inf,輸出值在正常範圍內(max ≈ 0.94 / min ≈ -0.84) - ❌ 但
判別器損失
爆掉了(d_loss = NaN
)
🔍 下一步:精確找出 netD
內哪個層導致爆炸
你現在需要做的是進一步在 Discriminator
forward 過程中,逐層檢查輸出值是否出現 NaN。
我建議你 hook 所有中間層輸出值,並即時報告 NaN / Inf,這樣就能知道是哪一層出問題。
✅ 加入 Forward Hook 來偵測 NaN 層(只需加在一次訓練開頭)
def _register_nan_hooks(self, model, model_name="Model"):
def hook_fn(module, input, output):
if isinstance(output, torch.Tensor):
if torch.isnan(output).any():
print(f"❌ NaN in output of {model_name} layer: {module.__class__.__name__}")
elif torch.isinf(output).any():
print(f"❌ Inf in output of {model_name} layer: {module.__class__.__name__}")
elif isinstance(output, (tuple, list)):
for i, out in enumerate(output):
if isinstance(out, torch.Tensor):
if torch.isnan(out).any():
print(f"❌ NaN in output {i} of {model_name} layer: {module.__class__.__name__}")
elif torch.isinf(out).any():
print(f"❌ Inf in output {i} of {model_name} layer: {module.__class__.__name__}")
for name, module in model.named_modules():
# 避免 hook 整個模型或 Sequential,針對 Conv/Norm/Activation 這類層即可
if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList):
module.register_forward_hook(hook_fn)
使用方式(在你的 model 初始化後加一次即可):
self._register_nan_hooks(self.netD, model_name="netD")
🧠 Hook 的好處:
- 當你執行
self.loss_module.backward_D(...)
時,只要某層的輸出出現 NaN/Inf,會直接印出是哪一層。 - 不需改任何 loss 或 forward function 的內部。
❗進一步建議:
如果你發現問題總是在某個 layer,例如:
❌ NaN in output of netD layer: BatchNorm2d
那 99% 是:
- 該層沒正確初始化(e.g., running var 為 0)
- 或其 weight/bias 為 NaN
- 或其輸入 tensor 極端不穩定(如全 0 / 很大)
你可以進一步印出該層參數:
for name, param in netD.named_parameters():
if "batchnorm" in name.lower() and torch.isnan(param).any():
print(f"❌ NaN in BN param: {name}")
要我幫你加一個 hook-ready 版本的 netD
包裝器來 debug 嗎?還是你會自己插進去?
當然可以!下面我幫你整合好一套「NaN / Inf 偵測 Hook 系統」,你只要在 netD
初始化後調用一次,就能在訓練過程中 自動追蹤是哪一層輸出爆掉。
✅ 1. Hook 工具:register_nan_hooks()
放入你的訓練類別中
def register_nan_hooks(self, model, model_name="Model"):
def hook_fn(module, input, output):
def check_tensor(t, label):
if not isinstance(t, torch.Tensor):
return
if torch.isnan(t).any():
print(f"❌ NaN in {label} of {model_name} layer: {module.__class__.__name__} ({module})")
elif torch.isinf(t).any():
print(f"❌ Inf in {label} of {model_name} layer: {module.__class__.__name__} ({module})")
if isinstance(output, torch.Tensor):
check_tensor(output, "output")
elif isinstance(output, (tuple, list)):
for i, out in enumerate(output):
check_tensor(out, f"output[{i}]")
for name, module in model.named_modules():
if not isinstance(module, (nn.Sequential, nn.ModuleList)) and len(list(module.children())) == 0:
module.register_forward_hook(hook_fn)
print(f"✅ Hook registered on: {name} ({module.__class__.__name__})")
✅ 2. 使用方式:在 netD
初始化後呼叫一次
這可以放在你的 __init__()
或 setup_model()
裡:
self.register_nan_hooks(self.netD, model_name="netD")
✅ 提示:只需要呼叫一次,之後每次
netD
forward 都會自動檢查!
✅ 3. 建議補強初始化 debug:檢查參數是否含 NaN
放在你 _initialize_unmatched_weights()
後可用:
def check_model_parameters_for_nan(self, model, model_name="Model"):
for name, param in model.named_parameters():
if torch.isnan(param).any():
print(f"❌ NaN in param: {model_name}.{name}")
elif torch.isinf(param).any():
print(f"❌ Inf in param: {model_name}.{name}")
✅ 使用總結:
你在模型載入後這樣寫一次即可:
# after loading model and reinitializing unmatched weights
self.register_nan_hooks(self.netD, model_name="netD")
self.check_model_parameters_for_nan(self.netD, model_name="netD")
這樣接下來你執行訓練,如果 netD
的某層一 forward 就出現 NaN,它會立即報:
❌ NaN in output of netD layer: BatchNorm2d (BatchNorm2d(64, eps=1e-05, ...))
你就可以直接鎖定是哪個 layer 出的問題。
需要我幫你包成一個 DebugMixin 或 HookHelper 類別整合進你的 trainer 類別嗎?可以讓你之後所有模型都能自動開啟這種 NaN 檢查。
✅ 完全抓到了!現在可以確定:
❌ 問題在於
netD
的最後一層Conv2d(512, 1)
,它的輸出出現 NaN
這通常發生在:
- 該層的權重或 bias 含 NaN
- 該層的輸入含極端值(如 ±∞、或梯度爆炸)
- 該層初始化沒成功,尤其是部分載入模型時未初始化這層
✅ 建議修正步驟如下:
✅ 步驟 1:確認這層是否有成功初始化
我們要確認這層(名稱可能像 model.classifier.3
或類似)是不是在載入權重時沒被正確初始化。
你可以插入這段來檢查這層的 weight:
for name, param in self.netD.named_parameters():
if "512" in name or "weight" in name or "bias" in name:
if torch.isnan(param).any():
print(f"❌ NaN in netD param: {name}")
elif torch.isinf(param).any():
print(f"❌ Inf in netD param: {name}")
✅ 步驟 2:明確初始化這層(修補 initialize_unmatched_weights
)
你的 _initialize_unmatched_weights()
可能沒成功處理 buffer 或 detached weight,可以加入:
def _initialize_unmatched_weights(self, model, loaded_state_dict):
for name, param in model.named_parameters():
if name not in loaded_state_dict or torch.isnan(param).any():
print(f"🔄 Re-initializing param: {name}")
if "weight" in name:
nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
elif "bias" in name:
nn.init.constant_(param.data, 0)
for name, buffer in model.named_buffers():
if name not in loaded_state_dict or torch.isnan(buffer).any():
print(f"🔄 Re-initializing buffer: {name}")
buffer.data.zero_()
注意 .data
是為了強制覆蓋即使有 gradient 的 tensor。
✅ 步驟 3:避免爆梯(加梯度限制)
你已經有:
torch.nn.utils.clip_grad_norm_(self.netD.parameters(), self.gradient_clip)
確認 self.gradient_clip
值是否太大(推薦 0.1 ~ 5.0),爆梯常導致 NaN!
✅ 額外防禦:最後一層加 Tanh
/ Sigmoid
限制輸出範圍?
如果你的判別器最後沒有 activation(也就是 raw logits),你可以考慮:
self.final_layer = nn.Sequential(
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
nn.Sigmoid() # 或 Tanh() 視你的 GAN 損失決定
)
避免 logits 無限制爆炸導致梯度異常。
✅ 總結修復順序建議:
- 確認最後一層的 weight / bias 是否為 NaN(打印出來)
- 強化
_initialize_unmatched_weights()
,補上 NaN 檢查與.data
重設 - 加入 final layer 限制(如
Sigmoid
) - 確認你在
load_state_dict(strict=False)
之後,有呼叫初始化未命中層的邏輯
需要我幫你修改完整 _initialize_unmatched_weights()
+ load_networks()
的安全版本嗎?可以避免之後類似問題再次發生。
太好了,這裡是幫你整合後的 安全載入 + NaN 檢查 + 未命中層初始化 的完整版本:
✅ load_networks()
:允許部份權重載入並自動補初始化
def load_networks(self, step):
loaded = False
target_filepath_G = os.path.join(self.save_dir, f"{step}_net_G.pth")
target_filepath_D = os.path.join(self.save_dir, f"{step}_net_D.pth")
# --- Generator ---
if os.path.exists(target_filepath_G):
loaded = True
try:
state_dict_G = torch.load(target_filepath_G, map_location=self.device)
self.netG.load_state_dict(state_dict_G, strict=False)
except Exception as e:
print(f"❌ Error loading Generator: {e}")
else:
print(f"⚠️ Generator checkpoint not found: {target_filepath_G}")
# --- Discriminator ---
if os.path.exists(target_filepath_D):
try:
state_dict_D = torch.load(target_filepath_D, map_location=self.device)
self.netD.load_state_dict(state_dict_D, strict=False)
self._initialize_unmatched_weights(self.netD, state_dict_D, model_name="netD")
self._check_nan_in_parameters(self.netD, model_name="netD")
except Exception as e:
print(f"❌ Error loading Discriminator: {e}")
else:
print(f"⚠️ Discriminator checkpoint not found: {target_filepath_D}")
if loaded:
print(f"✅ Model {step} loaded successfully")
return loaded
✅ _initialize_unmatched_weights()
:重設未命中的參數與 buffer,並補強 NaN 保護
def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
for name, param in model.named_parameters():
if name not in loaded_state_dict or torch.isnan(param).any() or torch.isinf(param).any():
print(f"🔄 Re-initializing param: {model_name}.{name}")
if "weight" in name:
nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
elif "bias" in name:
nn.init.constant_(param.data, 0)
for name, buffer in model.named_buffers():
if name not in loaded_state_dict or torch.isnan(buffer).any() or torch.isinf(buffer).any():
print(f"🔄 Re-initializing buffer: {model_name}.{name}")
buffer.data.zero_()
✅ _check_nan_in_parameters()
:載入後檢查模型內是否含 NaN/Inf
def _check_nan_in_parameters(self, model, model_name="Model"):
for name, param in model.named_parameters():
if torch.isnan(param).any():
print(f"❌ NaN in param: {model_name}.{name}")
elif torch.isinf(param).any():
print(f"❌ Inf in param: {model_name}.{name}")
for name, buffer in model.named_buffers():
if torch.isnan(buffer).any():
print(f"❌ NaN in buffer: {model_name}.{name}")
elif torch.isinf(buffer).any():
print(f"❌ Inf in buffer: {model_name}.{name}")
✅ 使用方式:
你只要在模型初始化後調用:
self.load_networks(step)
就會:
- ✅ 載入可用的部份權重
- ✅ 自動初始化未載入的層
- ✅ 自動檢查是否有 NaN/Inf
- ✅ 列出問題參數,幫助定位錯誤
若你還沒加上 forward hook,那一併使用這段也不錯來找出 forward 中 NaN:
self.register_nan_hooks(self.netD, model_name="netD")
需要我幫你把這些工具封裝成一個 ModelLoaderHelper
類別嗎?這樣你所有模型都可以共用這個載入流程~
太棒了!這裡是完整封裝好的 ModelLoaderHelper
類別,你只要在你的 Trainer 裡加上這個類別,之後任何模型都可以安全載入、初始化遺漏權重、自動 NaN 偵測 ✅
✅ model_loader_helper.py
import os
import torch
import torch.nn as nn
class ModelLoaderHelper:
def __init__(self, device="cuda"):
self.device = device
def load_model(self, model, ckpt_path, model_name="Model"):
if not os.path.exists(ckpt_path):
print(f"⚠️ {model_name} checkpoint not found: {ckpt_path}")
return False
try:
state_dict = torch.load(ckpt_path, map_location=self.device)
model.load_state_dict(state_dict, strict=False)
self._initialize_unmatched_weights(model, state_dict, model_name)
self._check_nan_in_parameters(model, model_name)
self.register_nan_hooks(model, model_name)
print(f"✅ {model_name} loaded from {ckpt_path}")
return True
except Exception as e:
print(f"❌ Error loading {model_name}: {e}")
return False
def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
for name, param in model.named_parameters():
if name not in loaded_state_dict or torch.isnan(param).any() or torch.isinf(param).any():
print(f"🔄 Re-initializing param: {model_name}.{name}")
if "weight" in name:
nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
elif "bias" in name:
nn.init.constant_(param.data, 0)
for name, buffer in model.named_buffers():
if name not in loaded_state_dict or torch.isnan(buffer).any() or torch.isinf(buffer).any():
print(f"🔄 Re-initializing buffer: {model_name}.{name}")
buffer.data.zero_()
def _check_nan_in_parameters(self, model, model_name="Model"):
for name, param in model.named_parameters():
if torch.isnan(param).any():
print(f"❌ NaN in param: {model_name}.{name}")
elif torch.isinf(param).any():
print(f"❌ Inf in param: {model_name}.{name}")
for name, buffer in model.named_buffers():
if torch.isnan(buffer).any():
print(f"❌ NaN in buffer: {model_name}.{name}")
elif torch.isinf(buffer).any():
print(f"❌ Inf in buffer: {model_name}.{name}")
def register_nan_hooks(self, model, model_name="Model"):
def hook_fn(module, input, output):
def check_tensor(t, label):
if not isinstance(t, torch.Tensor):
return
if torch.isnan(t).any():
print(f"❌ NaN in {label} of {model_name} layer: {module.__class__.__name__} ({module})")
elif torch.isinf(t).any():
print(f"❌ Inf in {label} of {model_name} layer: {module.__class__.__name__} ({module})")
if isinstance(output, torch.Tensor):
check_tensor(output, "output")
elif isinstance(output, (tuple, list)):
for i, out in enumerate(output):
check_tensor(out, f"output[{i}]")
for name, module in model.named_modules():
if not isinstance(module, (nn.Sequential, nn.ModuleList)) and len(list(module.children())) == 0:
module.register_forward_hook(hook_fn)
print(f"✅ Hook registered on: {name} ({module.__class__.__name__})")
✅ 用法:
在你的 Trainer 或主程式中這樣使用:
from model_loader_helper import ModelLoaderHelper
# 初始化 loader
self.model_loader = ModelLoaderHelper(device=self.device)
# 載入模型
G_loaded = self.model_loader.load_model(self.netG, f"{step}_net_G.pth", model_name="netG")
D_loaded = self.model_loader.load_model(self.netD, f"{step}_net_D.pth", model_name="netD")
🧠 附加建議:你可以用一行封裝訓練初始化
def setup_models(self, step):
self.model_loader = ModelLoaderHelper(device=self.device)
self.model_loader.load_model(self.netG, f"{self.save_dir}/{step}_net_G.pth", model_name="netG")
self.model_loader.load_model(self.netD, f"{self.save_dir}/{step}_net_D.pth", model_name="netD")
如果你之後有更多模型(如 VGG、encoder、style encoder),都可以共用這個 class ✅
需要我幫你加進 tensorboard 記錄 or log 檔案輸出也可以說一聲!