NaN / Inf 偵測 Hook 系統

torch.isnan(d_loss) == True 代表你的 判別器損失中出現了 NaN(非數),這通常跟以下幾個狀況有關,尤其是你「部分載入權重」後出現這情況,代表初始化或 forward pass 某處有問題。


✅ 常見原因與解法:

1. 未載入的權重未被初始化(或初始化不當)

如果 netD 的某些層未載入,但也沒經過正確初始化,會導致輸出異常值 → d_loss 出現 NaN。

解法
確保你有呼叫這段,並初始化了所有未載入的層:

self._initialize_unmatched_weights(self.netD, state_dict_D)

並且裡面使用的是 nn.init.kaiming_normal_() 而非 param.data.zero_() 等可能不穩定的初始化方式。


2. 學習率太高

若使用者變動了模型結構,之前的 learning rate 可能不再適用 → 直接爆梯度。

解法

  • 嘗試將 optimizer_D 的 learning rate 調低,例如從 0.00020.00005 試試。
  • 搭配 gradient clipping 是好的,但不是萬靈丹。

3. 未初始化的 BatchNorm / Running Stats

如果你用了 BN 層,而這層未正確初始化 running_meanrunning_var,forward 就會出現 NaN。

解法: 在 _initialize_unmatched_weights() 中記得處理 named_buffers(),像這樣:

for name, buffer in model.named_buffers():
    if name not in loaded_state_dict:
        print(f"🔄 Re-initializing buffer: {name}")
        if "running_var" in name:
            buffer.fill_(1)
        elif "running_mean" in name:
            buffer.zero_()
        else:
            buffer.zero_()

4. fake_B 未正確生成

如果你的 self.fake_B 是 generator 產出的,但 generator 尚未穩定或包含未初始化層,會導致它輸出 inf / nan → 傳遞給 netD 出錯。

解法

  • 加入 debug 檢查 self.fake_B 內部數值:
if torch.isnan(self.fake_B).any():
    print("⚠️ fake_B contains NaN")
if torch.isinf(self.fake_B).any():
    print("⚠️ fake_B contains Inf")

5. 損失函數內部出錯

部分損失(如 BCEWithLogitsLoss)若輸入極端值(如超大負數)會直接噴 NaN。

解法

  • 檢查 backward_D() 裡使用的 loss 函數輸入是否存在爆炸性值。
  • 將其中 logits 印出檢查,例如:
print("D(real):", d_real_out.min().item(), d_real_out.max().item())
print("D(fake):", d_fake_out.min().item(), d_fake_out.max().item())

✅ 建議加強的 debug 區段(示範):

if torch.isnan(d_loss):
    print("❌ 判別器損失為 NaN")
    print(" - fake_B NaN:", torch.isnan(self.fake_B).any().item())
    print(" - fake_B max:", self.fake_B.max().item(), "min:", self.fake_B.min().item())
    for name, param in self.netD.named_parameters():
        if torch.isnan(param).any():
            print(f" - ❌ NaN in netD param: {name}")
    return

🧠 總結建議:

  1. ✅ 確保 _initialize_unmatched_weights() 初始化所有未載入的權重與 buffer。
  2. ✅ 對 fake_B 做 NaN / Inf 偵測。
  3. ✅ 降低 learning rate 或加強 gradient clipping。
  4. ✅ 增加中間層輸出的值範圍檢查(尤其是 D(real)D(fake) 的輸出)。
  5. ✅ 確認 backward_D 的 loss 沒誤用未初始化的輸入。

如果你想,我可以幫你加入自動 NaN 偵測與警示系統,讓訓練中一旦爆掉就知道是哪一層出的問題。需要嗎?


太好了,那我來幫你加上一個「NaN / Inf 偵測 + 層級追蹤系統」,讓你在模型爆掉時可以:

  1. 知道哪個 tensor 出現 NaN/Inf(例如 fake_B, 判別器輸出等)
  2. 哪個 layer 的權重出了問題(netD or netG
  3. 自動列出損失值範圍、警示爆梯風險

✅ 加強版:NaN / Inf 偵測整合建議(放在 optimize_parameters() 裡)

def optimize_parameters(self, use_autocast=False):
    self.forward()

    # --- 判斷 fake_B 是否出問題 ---
    if torch.isnan(self.fake_B).any():
        print("❌ fake_B contains NaN")
    if torch.isinf(self.fake_B).any():
        print("❌ fake_B contains Inf")

    # --- Discriminator ---
    self.set_requires_grad(self.netD, True)
    self.optimizer_D.zero_grad()

    try:
        if use_autocast:
            with torch.amp.autocast(device_type='cuda'):
                d_loss, cat_loss_d = self.loss_module.backward_D(self.real_A, self.real_B, self.fake_B, self.labels)
                self._check_tensor("d_loss", d_loss)
                self.scaler_D.scale(d_loss).backward()
                self.scaler_D.unscale_(self.optimizer_D)
                torch.nn.utils.clip_grad_norm_(self.netD.parameters(), self.gradient_clip)
                self.scaler_D.step(self.optimizer_D)
                self.scaler_D.update()
        else:
            d_loss, cat_loss_d = self.loss_module.backward_D(self.real_A, self.real_B, self.fake_B, self.labels)
            self._check_tensor("d_loss", d_loss)
            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.netD.parameters(), self.gradient_clip)
            self.optimizer_D.step()

    except RuntimeError as e:
        print(f"❌ RuntimeError during backward: {e}")
        self._debug_nan_in_model(self.netD)
        return

    if torch.isnan(d_loss):
        print("❌ 判別器損失為 NaN,停止訓練。")
        self._debug_nan_in_model(self.netD)
        return

🔍 輔助函式 1:檢查 Tensor 是否 NaN / Inf

def _check_tensor(self, name, tensor):
    if torch.isnan(tensor).any():
        print(f"❌ Tensor {name} contains NaN.")
    if torch.isinf(tensor).any():
        print(f"❌ Tensor {name} contains Inf.")
    else:
        print(f"✅ Tensor {name} OK: min={tensor.min().item():.4f}, max={tensor.max().item():.4f}")

🧠 輔助函式 2:偵測模型內參數出問題的層

def _debug_nan_in_model(self, model):
    print("🔍 檢查模型內部是否有 NaN 或 Inf 的參數...")
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(f"❌ NaN in parameter: {name}")
        elif torch.isinf(param).any():
            print(f"❌ Inf in parameter: {name}")

🧪 可選:在 self.forward() 後檢查 netG 輸出

def _check_generator_output(self):
    if torch.isnan(self.fake_B).any():
        print("❌ fake_B has NaN after Generator")
    elif torch.isinf(self.fake_B).any():
        print("❌ fake_B has Inf after Generator")
    else:
        print("✅ Generator output OK")

你可以在 forward() 後調用它來早期偵錯:

self.forward()
self._check_generator_output()

✅ 整合效果:

這些會讓你一旦遇到爆梯、NaN、未初始化層爆炸,都能快速看到是哪一層哪個 tensor 出問題,並且搭配你的 _initialize_unmatched_weights() 就能快速修復。


需要我幫你再把這些整合成一個完整 class 的 patch 嗎?還是你會手動加進去?


很好,你的 debug 輸出證明:

  • Generator 沒有 NaN/Inf,輸出值在正常範圍內(max ≈ 0.94 / min ≈ -0.84)
  • ❌ 但 判別器損失 爆掉了(d_loss = NaN

🔍 下一步:精確找出 netD 內哪個層導致爆炸

你現在需要做的是進一步在 Discriminator forward 過程中,逐層檢查輸出值是否出現 NaN

我建議你 hook 所有中間層輸出值,並即時報告 NaN / Inf,這樣就能知道是哪一層出問題。


✅ 加入 Forward Hook 來偵測 NaN 層(只需加在一次訓練開頭)

def _register_nan_hooks(self, model, model_name="Model"):
    def hook_fn(module, input, output):
        if isinstance(output, torch.Tensor):
            if torch.isnan(output).any():
                print(f"❌ NaN in output of {model_name} layer: {module.__class__.__name__}")
            elif torch.isinf(output).any():
                print(f"❌ Inf in output of {model_name} layer: {module.__class__.__name__}")
        elif isinstance(output, (tuple, list)):
            for i, out in enumerate(output):
                if isinstance(out, torch.Tensor):
                    if torch.isnan(out).any():
                        print(f"❌ NaN in output {i} of {model_name} layer: {module.__class__.__name__}")
                    elif torch.isinf(out).any():
                        print(f"❌ Inf in output {i} of {model_name} layer: {module.__class__.__name__}")
    
    for name, module in model.named_modules():
        # 避免 hook 整個模型或 Sequential,針對 Conv/Norm/Activation 這類層即可
        if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList):
            module.register_forward_hook(hook_fn)

使用方式(在你的 model 初始化後加一次即可):

self._register_nan_hooks(self.netD, model_name="netD")

🧠 Hook 的好處:

  • 當你執行 self.loss_module.backward_D(...) 時,只要某層的輸出出現 NaN/Inf,會直接印出是哪一層。
  • 不需改任何 loss 或 forward function 的內部。

❗進一步建議:

如果你發現問題總是在某個 layer,例如:

❌ NaN in output of netD layer: BatchNorm2d

那 99% 是:

  • 該層沒正確初始化(e.g., running var 為 0)
  • 或其 weight/bias 為 NaN
  • 或其輸入 tensor 極端不穩定(如全 0 / 很大)

你可以進一步印出該層參數:

for name, param in netD.named_parameters():
    if "batchnorm" in name.lower() and torch.isnan(param).any():
        print(f"❌ NaN in BN param: {name}")

要我幫你加一個 hook-ready 版本的 netD 包裝器來 debug 嗎?還是你會自己插進去?


當然可以!下面我幫你整合好一套「NaN / Inf 偵測 Hook 系統」,你只要在 netD 初始化後調用一次,就能在訓練過程中 自動追蹤是哪一層輸出爆掉


✅ 1. Hook 工具:register_nan_hooks() 放入你的訓練類別中

def register_nan_hooks(self, model, model_name="Model"):
    def hook_fn(module, input, output):
        def check_tensor(t, label):
            if not isinstance(t, torch.Tensor):
                return
            if torch.isnan(t).any():
                print(f"❌ NaN in {label} of {model_name} layer: {module.__class__.__name__} ({module})")
            elif torch.isinf(t).any():
                print(f"❌ Inf in {label} of {model_name} layer: {module.__class__.__name__} ({module})")
        
        if isinstance(output, torch.Tensor):
            check_tensor(output, "output")
        elif isinstance(output, (tuple, list)):
            for i, out in enumerate(output):
                check_tensor(out, f"output[{i}]")

    for name, module in model.named_modules():
        if not isinstance(module, (nn.Sequential, nn.ModuleList)) and len(list(module.children())) == 0:
            module.register_forward_hook(hook_fn)
            print(f"✅ Hook registered on: {name} ({module.__class__.__name__})")

✅ 2. 使用方式:在 netD 初始化後呼叫一次

這可以放在你的 __init__()setup_model() 裡:

self.register_nan_hooks(self.netD, model_name="netD")

提示:只需要呼叫一次,之後每次 netD forward 都會自動檢查!


✅ 3. 建議補強初始化 debug:檢查參數是否含 NaN

放在你 _initialize_unmatched_weights() 後可用:

def check_model_parameters_for_nan(self, model, model_name="Model"):
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(f"❌ NaN in param: {model_name}.{name}")
        elif torch.isinf(param).any():
            print(f"❌ Inf in param: {model_name}.{name}")

✅ 使用總結:

你在模型載入後這樣寫一次即可:

# after loading model and reinitializing unmatched weights
self.register_nan_hooks(self.netD, model_name="netD")
self.check_model_parameters_for_nan(self.netD, model_name="netD")

這樣接下來你執行訓練,如果 netD 的某層一 forward 就出現 NaN,它會立即報:

❌ NaN in output of netD layer: BatchNorm2d (BatchNorm2d(64, eps=1e-05, ...))

你就可以直接鎖定是哪個 layer 出的問題。


需要我幫你包成一個 DebugMixin 或 HookHelper 類別整合進你的 trainer 類別嗎?可以讓你之後所有模型都能自動開啟這種 NaN 檢查。


✅ 完全抓到了!現在可以確定:

問題在於 netD 的最後一層 Conv2d(512, 1),它的輸出出現 NaN

這通常發生在:

  1. 該層的權重或 bias 含 NaN
  2. 該層的輸入含極端值(如 ±∞、或梯度爆炸)
  3. 該層初始化沒成功,尤其是部分載入模型時未初始化這層

✅ 建議修正步驟如下:

✅ 步驟 1:確認這層是否有成功初始化

我們要確認這層(名稱可能像 model.classifier.3 或類似)是不是在載入權重時沒被正確初始化

你可以插入這段來檢查這層的 weight:

for name, param in self.netD.named_parameters():
    if "512" in name or "weight" in name or "bias" in name:
        if torch.isnan(param).any():
            print(f"❌ NaN in netD param: {name}")
        elif torch.isinf(param).any():
            print(f"❌ Inf in netD param: {name}")

✅ 步驟 2:明確初始化這層(修補 initialize_unmatched_weights

你的 _initialize_unmatched_weights() 可能沒成功處理 buffer 或 detached weight,可以加入:

def _initialize_unmatched_weights(self, model, loaded_state_dict):
    for name, param in model.named_parameters():
        if name not in loaded_state_dict or torch.isnan(param).any():
            print(f"🔄 Re-initializing param: {name}")
            if "weight" in name:
                nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
            elif "bias" in name:
                nn.init.constant_(param.data, 0)

    for name, buffer in model.named_buffers():
        if name not in loaded_state_dict or torch.isnan(buffer).any():
            print(f"🔄 Re-initializing buffer: {name}")
            buffer.data.zero_()

注意 .data 是為了強制覆蓋即使有 gradient 的 tensor。


✅ 步驟 3:避免爆梯(加梯度限制)

你已經有:

torch.nn.utils.clip_grad_norm_(self.netD.parameters(), self.gradient_clip)

確認 self.gradient_clip 值是否太大(推薦 0.1 ~ 5.0),爆梯常導致 NaN!


✅ 額外防禦:最後一層加 Tanh / Sigmoid 限制輸出範圍?

如果你的判別器最後沒有 activation(也就是 raw logits),你可以考慮:

self.final_layer = nn.Sequential(
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
    nn.Sigmoid()  # 或 Tanh() 視你的 GAN 損失決定
)

避免 logits 無限制爆炸導致梯度異常。


✅ 總結修復順序建議:

  1. 確認最後一層的 weight / bias 是否為 NaN(打印出來)
  2. 強化 _initialize_unmatched_weights(),補上 NaN 檢查與 .data 重設
  3. 加入 final layer 限制(如 Sigmoid
  4. 確認你在 load_state_dict(strict=False) 之後,有呼叫初始化未命中層的邏輯

需要我幫你修改完整 _initialize_unmatched_weights() + load_networks() 的安全版本嗎?可以避免之後類似問題再次發生。


太好了,這裡是幫你整合後的 安全載入 + NaN 檢查 + 未命中層初始化 的完整版本:


load_networks():允許部份權重載入並自動補初始化

def load_networks(self, step):
    loaded = False
    target_filepath_G = os.path.join(self.save_dir, f"{step}_net_G.pth")
    target_filepath_D = os.path.join(self.save_dir, f"{step}_net_D.pth")

    # --- Generator ---
    if os.path.exists(target_filepath_G):
        loaded = True
        try:
            state_dict_G = torch.load(target_filepath_G, map_location=self.device)
            self.netG.load_state_dict(state_dict_G, strict=False)
        except Exception as e:
            print(f"❌ Error loading Generator: {e}")
    else:
        print(f"⚠️ Generator checkpoint not found: {target_filepath_G}")

    # --- Discriminator ---
    if os.path.exists(target_filepath_D):
        try:
            state_dict_D = torch.load(target_filepath_D, map_location=self.device)
            self.netD.load_state_dict(state_dict_D, strict=False)
            self._initialize_unmatched_weights(self.netD, state_dict_D, model_name="netD")
            self._check_nan_in_parameters(self.netD, model_name="netD")
        except Exception as e:
            print(f"❌ Error loading Discriminator: {e}")
    else:
        print(f"⚠️ Discriminator checkpoint not found: {target_filepath_D}")

    if loaded:
        print(f"✅ Model {step} loaded successfully")
    return loaded

_initialize_unmatched_weights():重設未命中的參數與 buffer,並補強 NaN 保護

def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
    for name, param in model.named_parameters():
        if name not in loaded_state_dict or torch.isnan(param).any() or torch.isinf(param).any():
            print(f"🔄 Re-initializing param: {model_name}.{name}")
            if "weight" in name:
                nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
            elif "bias" in name:
                nn.init.constant_(param.data, 0)

    for name, buffer in model.named_buffers():
        if name not in loaded_state_dict or torch.isnan(buffer).any() or torch.isinf(buffer).any():
            print(f"🔄 Re-initializing buffer: {model_name}.{name}")
            buffer.data.zero_()

_check_nan_in_parameters():載入後檢查模型內是否含 NaN/Inf

def _check_nan_in_parameters(self, model, model_name="Model"):
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(f"❌ NaN in param: {model_name}.{name}")
        elif torch.isinf(param).any():
            print(f"❌ Inf in param: {model_name}.{name}")

    for name, buffer in model.named_buffers():
        if torch.isnan(buffer).any():
            print(f"❌ NaN in buffer: {model_name}.{name}")
        elif torch.isinf(buffer).any():
            print(f"❌ Inf in buffer: {model_name}.{name}")

✅ 使用方式:

你只要在模型初始化後調用:

self.load_networks(step)

就會:

  • ✅ 載入可用的部份權重
  • ✅ 自動初始化未載入的層
  • ✅ 自動檢查是否有 NaN/Inf
  • ✅ 列出問題參數,幫助定位錯誤

若你還沒加上 forward hook,那一併使用這段也不錯來找出 forward 中 NaN:

self.register_nan_hooks(self.netD, model_name="netD")

需要我幫你把這些工具封裝成一個 ModelLoaderHelper 類別嗎?這樣你所有模型都可以共用這個載入流程~


太棒了!這裡是完整封裝好的 ModelLoaderHelper 類別,你只要在你的 Trainer 裡加上這個類別,之後任何模型都可以安全載入、初始化遺漏權重、自動 NaN 偵測 ✅


model_loader_helper.py

import os
import torch
import torch.nn as nn


class ModelLoaderHelper:
    def __init__(self, device="cuda"):
        self.device = device

    def load_model(self, model, ckpt_path, model_name="Model"):
        if not os.path.exists(ckpt_path):
            print(f"⚠️ {model_name} checkpoint not found: {ckpt_path}")
            return False

        try:
            state_dict = torch.load(ckpt_path, map_location=self.device)
            model.load_state_dict(state_dict, strict=False)
            self._initialize_unmatched_weights(model, state_dict, model_name)
            self._check_nan_in_parameters(model, model_name)
            self.register_nan_hooks(model, model_name)
            print(f"✅ {model_name} loaded from {ckpt_path}")
            return True
        except Exception as e:
            print(f"❌ Error loading {model_name}: {e}")
            return False

    def _initialize_unmatched_weights(self, model, loaded_state_dict, model_name="Model"):
        for name, param in model.named_parameters():
            if name not in loaded_state_dict or torch.isnan(param).any() or torch.isinf(param).any():
                print(f"🔄 Re-initializing param: {model_name}.{name}")
                if "weight" in name:
                    nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='leaky_relu')
                elif "bias" in name:
                    nn.init.constant_(param.data, 0)

        for name, buffer in model.named_buffers():
            if name not in loaded_state_dict or torch.isnan(buffer).any() or torch.isinf(buffer).any():
                print(f"🔄 Re-initializing buffer: {model_name}.{name}")
                buffer.data.zero_()

    def _check_nan_in_parameters(self, model, model_name="Model"):
        for name, param in model.named_parameters():
            if torch.isnan(param).any():
                print(f"❌ NaN in param: {model_name}.{name}")
            elif torch.isinf(param).any():
                print(f"❌ Inf in param: {model_name}.{name}")

        for name, buffer in model.named_buffers():
            if torch.isnan(buffer).any():
                print(f"❌ NaN in buffer: {model_name}.{name}")
            elif torch.isinf(buffer).any():
                print(f"❌ Inf in buffer: {model_name}.{name}")

    def register_nan_hooks(self, model, model_name="Model"):
        def hook_fn(module, input, output):
            def check_tensor(t, label):
                if not isinstance(t, torch.Tensor):
                    return
                if torch.isnan(t).any():
                    print(f"❌ NaN in {label} of {model_name} layer: {module.__class__.__name__} ({module})")
                elif torch.isinf(t).any():
                    print(f"❌ Inf in {label} of {model_name} layer: {module.__class__.__name__} ({module})")

            if isinstance(output, torch.Tensor):
                check_tensor(output, "output")
            elif isinstance(output, (tuple, list)):
                for i, out in enumerate(output):
                    check_tensor(out, f"output[{i}]")

        for name, module in model.named_modules():
            if not isinstance(module, (nn.Sequential, nn.ModuleList)) and len(list(module.children())) == 0:
                module.register_forward_hook(hook_fn)
                print(f"✅ Hook registered on: {name} ({module.__class__.__name__})")

✅ 用法:

在你的 Trainer 或主程式中這樣使用:

from model_loader_helper import ModelLoaderHelper

# 初始化 loader
self.model_loader = ModelLoaderHelper(device=self.device)

# 載入模型
G_loaded = self.model_loader.load_model(self.netG, f"{step}_net_G.pth", model_name="netG")
D_loaded = self.model_loader.load_model(self.netD, f"{step}_net_D.pth", model_name="netD")

🧠 附加建議:你可以用一行封裝訓練初始化

def setup_models(self, step):
    self.model_loader = ModelLoaderHelper(device=self.device)
    self.model_loader.load_model(self.netG, f"{self.save_dir}/{step}_net_G.pth", model_name="netG")
    self.model_loader.load_model(self.netD, f"{self.save_dir}/{step}_net_D.pth", model_name="netD")

如果你之後有更多模型(如 VGG、encoder、style encoder),都可以共用這個 class ✅
需要我幫你加進 tensorboard 記錄 or log 檔案輸出也可以說一聲!

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *