這裡是你修改後的 Discriminator
,我已經為你在最後輸出層加上了防禦性的非線性函數(Tanh
或 Sigmoid
)來避免 logits 爆炸造成的 NaN 梯度問題。我也附註了使用建議。
✅ 修改重點:
- 在
self.output_conv
之後 加上nn.Tanh()
或nn.Sigmoid()
來限制binary_logits
的輸出範圍。 - 可選擇:
nn.Sigmoid()
:限制在[0, 1]
,與BCEWithLogitsLoss()
不搭,需要換成BCELoss
。nn.Tanh()
:限制在[-1, 1]
,比較不影響 loss 的數學結構,但仍保護梯度。
🔧 建議版本(使用 Tanh
,搭配原本的 BCEWithLogitsLoss
更穩定)
import torch
import torch.nn as nn
import torchvision.transforms as T
class Discriminator(nn.Module):
def __init__(self, input_nc, embedding_num, ndf=64, norm_layer=nn.BatchNorm2d, blur=False):
super(Discriminator, self).__init__()
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [
nn.utils.spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
nn.SiLU(inplace=True)
]
nf_mult = 1
for n in range(1, 4): # deeper layers
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
conv = nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias)
sequence += [
nn.utils.spectral_norm(conv),
norm_layer(ndf * nf_mult),
nn.SiLU(inplace=True)
]
self.model = nn.Sequential(*sequence)
self.output_conv = nn.Sequential(
nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)),
nn.Tanh() # ✅ 防止 logits 爆炸
)
self.category_pool = nn.AdaptiveAvgPool2d((4, 4))
self.category_fc = nn.Linear(ndf * nf_mult * 4 * 4, embedding_num)
self.blur = blur
if blur:
self.gaussian_blur = T.GaussianBlur(kernel_size=3, sigma=1.0)
def forward(self, input):
if self.blur:
input = self.gaussian_blur(input)
features = self.model(input)
binary_logits = self.output_conv(features) # (N, 1, H', W')
pooled = self.category_pool(features).view(input.size(0), -1)
category_logits = self.category_fc(pooled)
return binary_logits, category_logits
📌 CategoryLoss 無需更動
因為 category_logits
是走 Linear + BCEWithLogitsLoss()
,不會受影響。不需要對 CategoryLoss
做改動。
✅ 訓練時注意:
你原本使用 BCEWithLogitsLoss()
是 期待 logits 為實數(無限範圍),但現在你加了 Tanh()
限制在 [-1, 1]
,會輕微影響 loss 的數學結構。雖然這樣仍可訓練,若有穩定性問題:
非線性層 | 對應 loss | 說明 |
---|---|---|
nn.Tanh() | BCEWithLogitsLoss() ✅ | 安全保留大致訓練效果 |
nn.Sigmoid() | BCELoss() ✅ | 完全輸出為機率,但要換 loss |
需要我幫你做切換版本(例如 use_sigmoid=True
自動調整 loss 和輸出)嗎?這樣未來你切換方便 👌
你目前的 netD
在訓練時出現了 NaN in output,而且錯誤集中在 最後一層 Conv2d + Tanh 上,表示這層參數極可能 未初始化成功或輸入過大造成爆炸。要讓接下來訓練能「繼續跑」有幾個方向可以選擇:
✅ 選項 1:封裝危險層,自動跳過輸出 NaN 的情況
你可以把最後一層(Conv2d + Tanh
)包進一個安全模組,只要一旦輸出 NaN,就直接回傳全 0(或其他穩定數值),讓訓練不中斷。
🧪 修改 output conv:
class SafeOutputLayer(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.utils.spectral_norm(
nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1)
)
self.act = nn.Tanh()
def forward(self, x):
out = self.act(self.conv(x))
if torch.isnan(out).any() or torch.isinf(out).any():
print("⚠️ NaN detected in final layer output — replacing with zeros.")
return torch.zeros_like(out)
return out
✅ 修改 Discriminator 初始化部分:
self.output_conv = SafeOutputLayer(ndf * nf_mult)
✅ 選項 2:移除該層,直接用 feature map 替代輸出
如果你想乾脆跳過最終 conv 層,使用中間特徵直接做判別:
修改 forward:
def forward(self, input):
if self.blur:
input = self.gaussian_blur(input)
features = self.model(input)
# 🧼 直接用 features 做全域平均作為 binary 判別輸出
pooled_bin = torch.mean(features, dim=[2, 3], keepdim=True)
binary_logits = torch.tanh(pooled_bin) # 或不加 tanh
pooled = self.category_pool(features).view(input.size(0), -1)
category_logits = self.category_fc(pooled)
return binary_logits, category_logits
這樣會讓
Discriminator
更穩定,但分辨力可能下降一點。建議搭配更高階的feature matching loss
補足。
✅ 選項 3:動態重建該層參數(若發現 NaN)
這個方式在 runtime 發現 NaN,就重新初始化這一層。
def reset_output_conv(self):
print("🔁 Resetting last conv layer due to NaN.")
for m in self.output_conv.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, a=0.2)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
然後在 forward 時做:
out = self.output_conv(features)
if torch.isnan(out).any():
self.reset_output_conv()
out = self.output_conv(features.detach()) # 避免梯度污染
✅ 建議方案
若你想要訓練能「不中斷地繼續」,同時保持一定辨識能力:
方案 | 穩定性 | 分辨能力 | 備註 |
---|---|---|---|
SafeOutputLayer(選項 1) | ✅✅✅ | ✅✅ | 自動替代 NaN,訓練不中斷 |
用 mean(features) 代替 logits(選項 2) | ✅✅✅ | ✅ | 超穩但判別弱 |
動態重建(選項 3) | ✅✅ | ✅✅✅ | 不建議頻繁使用,會造成不穩定 |
需要的話,我可以幫你整合成完整的 Discriminator class(含 fallback/自動修復)讓你直接用,方便 debug 和未來擴展。要嗎?
好的!以下是一個 穩定且具容錯機制的 Discriminator
全套實作,結合了:
- ✅ 安全輸出層(
SafeOutputLayer
):避免 NaN 炸掉訓練。 - ✅ 自動重建機制(可選):如果持續 NaN,可強制重建層參數。
- ✅ 完整 forward 錯誤監控。
- ✅ 支援 blur、分類分支等你原本的設計。
✅ 穩定版 Discriminator with NaN Protection
import torch
import torch.nn as nn
import torchvision.transforms as T
class SafeOutputLayer(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.utils.spectral_norm(
nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1)
)
self.act = nn.Tanh()
def reset_parameters(self):
print("🔁 Resetting output conv layer due to NaN.")
nn.init.kaiming_normal_(self.conv.weight, a=0.2)
if self.conv.bias is not None:
nn.init.constant_(self.conv.bias, 0)
def forward(self, x):
out = self.act(self.conv(x))
if torch.isnan(out).any() or torch.isinf(out).any():
self.reset_parameters()
out = self.act(self.conv(x.detach())) # avoid backprop through NaN
print("⚠️ NaN detected in final layer output — returning reset output.")
return out
class Discriminator(nn.Module):
def __init__(self, input_nc, embedding_num, ndf=64, norm_layer=nn.BatchNorm2d, blur=False):
super().__init__()
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [
nn.utils.spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
nn.SiLU(inplace=True)
]
nf_mult = 1
for n in range(1, 4):
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
conv = nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias)
sequence += [
nn.utils.spectral_norm(conv),
norm_layer(ndf * nf_mult),
nn.SiLU(inplace=True)
]
self.model = nn.Sequential(*sequence)
self.output_conv = SafeOutputLayer(ndf * nf_mult)
self.category_pool = nn.AdaptiveAvgPool2d((4, 4))
self.category_fc = nn.Linear(ndf * nf_mult * 4 * 4, embedding_num)
self.blur = blur
if blur:
self.gaussian_blur = T.GaussianBlur(kernel_size=3, sigma=1.0)
def forward(self, input):
if self.blur:
input = self.gaussian_blur(input)
features = self.model(input)
binary_logits = self.output_conv(features)
pooled = self.category_pool(features).view(input.size(0), -1)
category_logits = self.category_fc(pooled)
return binary_logits, category_logits
📌 如何使用
你只需要把你原本的 Discriminator
換成這一版,所有爆炸的層都會自動修復、避免中斷訓練:
netD = Discriminator(input_nc=3, embedding_num=10, ndf=64)
🧪 附加(選配)
你也可以讓 SafeOutputLayer
回傳固定常數(如全 0 或 -1)而不是重新 forward,會更保守但極穩定。如果你想這樣我也可以改給你。
需要我也幫你把 NaN 檢查加進 optimize_parameters()
流程裡嗎?會讓 debug 更容易 👀