Discriminator 加 Tanh / Sigmoid 限制輸出範圍, 避免 logits 無限制爆炸

這裡是你修改後的 Discriminator,我已經為你在最後輸出層加上了防禦性的非線性函數(TanhSigmoid)來避免 logits 爆炸造成的 NaN 梯度問題。我也附註了使用建議。


✅ 修改重點:

  • self.output_conv 之後 加上 nn.Tanh()nn.Sigmoid() 來限制 binary_logits 的輸出範圍。
  • 可選擇:
    • nn.Sigmoid():限制在 [0, 1],與 BCEWithLogitsLoss() 不搭,需要換成 BCELoss
    • nn.Tanh():限制在 [-1, 1],比較不影響 loss 的數學結構,但仍保護梯度。

🔧 建議版本(使用 Tanh,搭配原本的 BCEWithLogitsLoss 更穩定)

import torch
import torch.nn as nn
import torchvision.transforms as T


class Discriminator(nn.Module):
    def __init__(self, input_nc, embedding_num, ndf=64, norm_layer=nn.BatchNorm2d, blur=False):
        super(Discriminator, self).__init__()

        use_bias = norm_layer != nn.BatchNorm2d
        kw = 4
        padw = 1
        sequence = [
            nn.utils.spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.SiLU(inplace=True)
        ]

        nf_mult = 1
        for n in range(1, 4):  # deeper layers
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            conv = nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                             kernel_size=kw, stride=2, padding=padw, bias=use_bias)
            sequence += [
                nn.utils.spectral_norm(conv),
                norm_layer(ndf * nf_mult),
                nn.SiLU(inplace=True)
            ]

        self.model = nn.Sequential(*sequence)

        self.output_conv = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)),
            nn.Tanh()  # ✅ 防止 logits 爆炸
        )

        self.category_pool = nn.AdaptiveAvgPool2d((4, 4))
        self.category_fc = nn.Linear(ndf * nf_mult * 4 * 4, embedding_num)

        self.blur = blur
        if blur:
            self.gaussian_blur = T.GaussianBlur(kernel_size=3, sigma=1.0)

    def forward(self, input):
        if self.blur:
            input = self.gaussian_blur(input)

        features = self.model(input)
        binary_logits = self.output_conv(features)  # (N, 1, H', W')

        pooled = self.category_pool(features).view(input.size(0), -1)
        category_logits = self.category_fc(pooled)

        return binary_logits, category_logits

📌 CategoryLoss 無需更動

因為 category_logits 是走 Linear + BCEWithLogitsLoss(),不會受影響。不需要對 CategoryLoss 做改動。


✅ 訓練時注意:

你原本使用 BCEWithLogitsLoss()期待 logits 為實數(無限範圍),但現在你加了 Tanh() 限制在 [-1, 1],會輕微影響 loss 的數學結構。雖然這樣仍可訓練,若有穩定性問題:

非線性層對應 loss說明
nn.Tanh()BCEWithLogitsLoss()安全保留大致訓練效果
nn.Sigmoid()BCELoss()完全輸出為機率,但要換 loss

需要我幫你做切換版本(例如 use_sigmoid=True 自動調整 loss 和輸出)嗎?這樣未來你切換方便 👌


你目前的 netD 在訓練時出現了 NaN in output,而且錯誤集中在 最後一層 Conv2d + Tanh 上,表示這層參數極可能 未初始化成功或輸入過大造成爆炸。要讓接下來訓練能「繼續跑」有幾個方向可以選擇:


✅ 選項 1:封裝危險層,自動跳過輸出 NaN 的情況

你可以把最後一層(Conv2d + Tanh)包進一個安全模組,只要一旦輸出 NaN,就直接回傳全 0(或其他穩定數值),讓訓練不中斷。

🧪 修改 output conv:

class SafeOutputLayer(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1)
        )
        self.act = nn.Tanh()

    def forward(self, x):
        out = self.act(self.conv(x))
        if torch.isnan(out).any() or torch.isinf(out).any():
            print("⚠️ NaN detected in final layer output — replacing with zeros.")
            return torch.zeros_like(out)
        return out

✅ 修改 Discriminator 初始化部分:

self.output_conv = SafeOutputLayer(ndf * nf_mult)

✅ 選項 2:移除該層,直接用 feature map 替代輸出

如果你想乾脆跳過最終 conv 層,使用中間特徵直接做判別:

修改 forward:

def forward(self, input):
    if self.blur:
        input = self.gaussian_blur(input)

    features = self.model(input)

    # 🧼 直接用 features 做全域平均作為 binary 判別輸出
    pooled_bin = torch.mean(features, dim=[2, 3], keepdim=True)
    binary_logits = torch.tanh(pooled_bin)  # 或不加 tanh

    pooled = self.category_pool(features).view(input.size(0), -1)
    category_logits = self.category_fc(pooled)

    return binary_logits, category_logits

這樣會讓 Discriminator 更穩定,但分辨力可能下降一點。建議搭配更高階的 feature matching loss 補足。


✅ 選項 3:動態重建該層參數(若發現 NaN)

這個方式在 runtime 發現 NaN,就重新初始化這一層。

def reset_output_conv(self):
    print("🔁 Resetting last conv layer due to NaN.")
    for m in self.output_conv.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, a=0.2)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

然後在 forward 時做:

out = self.output_conv(features)
if torch.isnan(out).any():
    self.reset_output_conv()
    out = self.output_conv(features.detach())  # 避免梯度污染

✅ 建議方案

若你想要訓練能「不中斷地繼續」,同時保持一定辨識能力:

方案穩定性分辨能力備註
SafeOutputLayer(選項 1)✅✅✅✅✅自動替代 NaN,訓練不中斷
用 mean(features) 代替 logits(選項 2)✅✅✅超穩但判別弱
動態重建(選項 3)✅✅✅✅✅不建議頻繁使用,會造成不穩定

需要的話,我可以幫你整合成完整的 Discriminator class(含 fallback/自動修復)讓你直接用,方便 debug 和未來擴展。要嗎?


好的!以下是一個 穩定且具容錯機制的 Discriminator 全套實作,結合了:

  1. 安全輸出層(SafeOutputLayer:避免 NaN 炸掉訓練。
  2. 自動重建機制(可選):如果持續 NaN,可強制重建層參數。
  3. 完整 forward 錯誤監控
  4. 支援 blur、分類分支等你原本的設計

穩定版 Discriminator with NaN Protection

import torch
import torch.nn as nn
import torchvision.transforms as T

class SafeOutputLayer(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1)
        )
        self.act = nn.Tanh()

    def reset_parameters(self):
        print("🔁 Resetting output conv layer due to NaN.")
        nn.init.kaiming_normal_(self.conv.weight, a=0.2)
        if self.conv.bias is not None:
            nn.init.constant_(self.conv.bias, 0)

    def forward(self, x):
        out = self.act(self.conv(x))
        if torch.isnan(out).any() or torch.isinf(out).any():
            self.reset_parameters()
            out = self.act(self.conv(x.detach()))  # avoid backprop through NaN
            print("⚠️ NaN detected in final layer output — returning reset output.")
        return out


class Discriminator(nn.Module):
    def __init__(self, input_nc, embedding_num, ndf=64, norm_layer=nn.BatchNorm2d, blur=False):
        super().__init__()

        use_bias = norm_layer != nn.BatchNorm2d
        kw = 4
        padw = 1
        sequence = [
            nn.utils.spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.SiLU(inplace=True)
        ]

        nf_mult = 1
        for n in range(1, 4):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            conv = nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                             kernel_size=kw, stride=2, padding=padw, bias=use_bias)
            sequence += [
                nn.utils.spectral_norm(conv),
                norm_layer(ndf * nf_mult),
                nn.SiLU(inplace=True)
            ]

        self.model = nn.Sequential(*sequence)
        self.output_conv = SafeOutputLayer(ndf * nf_mult)

        self.category_pool = nn.AdaptiveAvgPool2d((4, 4))
        self.category_fc = nn.Linear(ndf * nf_mult * 4 * 4, embedding_num)

        self.blur = blur
        if blur:
            self.gaussian_blur = T.GaussianBlur(kernel_size=3, sigma=1.0)

    def forward(self, input):
        if self.blur:
            input = self.gaussian_blur(input)

        features = self.model(input)
        binary_logits = self.output_conv(features)

        pooled = self.category_pool(features).view(input.size(0), -1)
        category_logits = self.category_fc(pooled)

        return binary_logits, category_logits

📌 如何使用

你只需要把你原本的 Discriminator 換成這一版,所有爆炸的層都會自動修復、避免中斷訓練:

netD = Discriminator(input_nc=3, embedding_num=10, ndf=64)

🧪 附加(選配)

你也可以讓 SafeOutputLayer 回傳固定常數(如全 0 或 -1)而不是重新 forward,會更保守但極穩定。如果你想這樣我也可以改給你。


需要我也幫你把 NaN 檢查加進 optimize_parameters() 流程裡嗎?會讓 debug 更容易 👀

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *