實際測試, 的確產生的 .pth 檔案變小很多, 原本是 215MB, 縮小到變成 157MB, 縮小為 73%.
測試在資料量極少的情況下(11筆訓練資料), 沒有標準用法情況下, 只訓練 4個 Epoch, 不存在的相似部件可以推論出模糊形狀, 但encoded.mean(dim=(2, 3)) 就是一片黑。
也許, 增加訓練資料與訓練次數, 應該就可以更明顯地比較出其差異.
你這段 UNet 結構中在 innermost
層會將 bottleneck 的特徵 encoded
攤平成向量並回傳,這可能會佔用大量記憶體。為了改善這個問題,我們可以改成 全域平均池化(Global Average Pooling, GAP) 的方式來傳遞風格表徵,大幅降低記憶體使用量,同時仍保留風格訊息。
✅ 修改說明:
- 原本是
encoded.contiguous().view(B, -1)
,會產生非常長的向量(B × C × H × W)。 - 改成
encoded.mean(dim=(2, 3))
,只保留 B × C 的全域平均向量,能更有效地保留風格資訊並節省資源。
✅ 修改後的 UnetSkipConnectionBlock.forward
(只修改 forward
的部分)如下:
def forward(self, x, style=None):
encoded = self.down(x)
if self.attn_block:
encoded = self.attn_block(encoded)
if self.innermost:
if hasattr(self, 'transformer_block'):
encoded = self.transformer_block(encoded)
if style is not None:
encoded = self.film(encoded, style)
decoded = self.up(encoded)
if decoded.shape[2:] != x.shape[2:]:
decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
if self.res_skip:
decoded = self.res_skip(decoded)
# 使用全域平均池化來取得 bottleneck 表徵
style_feat = encoded.mean(dim=(2, 3)) # shape: (B, C)
return x + decoded, style_feat
else:
sub_output, encoded_real_A = self.submodule(encoded, style)
decoded = self.up(sub_output)
if decoded.shape[2:] != x.shape[2:]:
decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
if self.res_skip:
decoded = self.res_skip(decoded)
if self.outermost:
return decoded, encoded_real_A
else:
return x + decoded, encoded_real_A
如果你還需要將這個 style_feat
傳到某個 style embedding 或分類器中(例如給 CategoryLoss 使用),它會變得非常方便處理,且不容易記憶體爆炸。
需要我也幫你改上游接收這個 style feature 的邏輯嗎?或者整合進 Style Encoder?
你這份 UNetGenerator
程式碼我已幫你修改好,使其支援 全域平均池化風格表徵傳遞,並讓上游能從 bottleneck 壓縮表示中推論出 style embedding 和 style 分類輸出,以下是修改重點與完整程式碼:
✅ 修改重點:
- 從
UnetSkipConnectionBlock
傳出的encoded
為全域平均池化後的 bottleneck 向量(前面你已改過)。 - 新增
style_classifier
及style_embedder
模組:style_classifier
: 將 bottleneck 表徵映射為分類輸出。style_embedder
: 將 bottleneck 表徵映射為 style embedding。
- 更新
forward()
與encode()
:forward
回傳(fake_B, style_cls_pred, style_emb)
。encode
只回傳 style embedding。
🔧 修改後程式碼:
class UNetGenerator(nn.Module):
def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64,
embedding_num=40, embedding_dim=128,
norm_layer=nn.InstanceNorm2d, use_dropout=False,
self_attention=False, blur=False, attention_type='linear',
attn_layers=None):
super(UNetGenerator, self).__init__()
if attn_layers is None:
attn_layers = []
unet_block = UnetSkipConnectionBlock(
ngf * 8, ngf * 8, submodule=None, innermost=True,
norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim,
use_transformer=True, self_attention=self_attention,
blur=blur, attention_type=attention_type, attn_layers=attn_layers)
for i in range(num_downs - 5):
unet_block = UnetSkipConnectionBlock(
ngf * 8, ngf * 8, submodule=unet_block,
norm_layer=norm_layer, layer=i + 2, use_dropout=use_dropout,
self_attention=self_attention, blur=blur,
attention_type=attention_type, attn_layers=attn_layers)
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block, norm_layer=norm_layer, layer=5,
self_attention=self_attention, blur=blur, attention_type=attention_type, attn_layers=attn_layers)
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block, norm_layer=norm_layer, layer=6,
self_attention=self_attention, blur=blur, attention_type=attention_type, attn_layers=attn_layers)
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_layer=norm_layer, layer=7,
self_attention=self_attention, blur=blur, attention_type=attention_type, attn_layers=attn_layers)
self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
outermost=True, norm_layer=norm_layer, layer=8,
self_attention=self_attention, blur=blur,
attention_type=attention_type, attn_layers=attn_layers)
self.embedder = nn.Embedding(embedding_num, embedding_dim)
# ↓ 新增 bottleneck 表徵的後處理模組
self.style_classifier = nn.Sequential(
nn.LayerNorm(ngf * 8),
nn.Linear(ngf * 8, embedding_num)
)
self.style_embedder = nn.Sequential(
nn.LayerNorm(ngf * 8),
nn.Linear(ngf * 8, embedding_dim),
nn.SiLU()
)
def _prepare_style(self, style_or_label):
return self.embedder(style_or_label) if style_or_label is not None and 'LongTensor' in style_or_label.type() else style_or_label
def forward(self, x, style_or_label=None):
style = self._prepare_style(style_or_label)
fake_B, style_feat = self.model(x, style) # style_feat: (B, C)
style_cls_pred = self.style_classifier(style_feat) # 分類
style_emb = self.style_embedder(style_feat) # 嵌入向量
return fake_B, style_cls_pred, style_emb
def encode(self, x, style_or_label=None):
style = self._prepare_style(style_or_label)
_, style_feat = self.model(x, style)
return self.style_embedder(style_feat)
這樣設計不僅大幅降低 memory 使用,也讓 style 特徵可重複使用於分類或 embedding 提取。
有需要的話也可以幫你補上訓練 loss 或正規化建議!需要嗎?
你這份 Zi2ZiModel
和 Zi2ZiLoss
的架構整體邏輯已經非常完整,接下來我會協助你根據「新的架構需求:UNet 中使用 FiLM 與 LoRA Attention」來進行必要調整,包含:
- 確認 Generator(UNet)是否支援 FiLM 與 LoRA Attention。
- 修正 Generator 的
encode
方法邏輯。 - Loss 模組部分保持不變,只要 Generator 的輸出與
encode
方式正確,損失計算即可套用現有模組。
✅ 你可以這樣調整 UNetGenerator
:
這部分是架構改動的關鍵 —— FiLM
與 LoRA Attention
要體現在 UNet 裡,我這邊提供一個基本架構方向供你融入現有 UNet 中(你可能已經部分實作,只需插入以下邏輯):
class FiLMBlock(nn.Module):
def __init__(self, in_channels, embedding_dim):
super().__init__()
self.gamma = nn.Linear(embedding_dim, in_channels)
self.beta = nn.Linear(embedding_dim, in_channels)
def forward(self, x, style_embedding):
gamma = self.gamma(style_embedding).unsqueeze(-1).unsqueeze(-1)
beta = self.beta(style_embedding).unsqueeze(-1).unsqueeze(-1)
return x * gamma + beta
然後在 UNetGenerator
的 forward 內,每一層你想調整的地方加上:
x = self.down1(x)
x = self.film1(x, style_embedding) # 假設你定義了 self.film1 = FiLMBlock(...)
✅ 確認 Generator forward()
和 encode()
方法:
目前 self.netG(self.real_A, self.labels)
回傳 (fake_B, encoded_real_A)
但 self.netG.encode()
方法你需要確認是從哪一層抽特徵進行編碼。
一種實作是這樣:
class UNetGenerator(nn.Module):
...
def forward(self, x, label_idx):
style_emb = self.embedding(label_idx) # label_idx shape: (B,)
out, features = self.unet_forward(x, style_emb)
return out, features
def encode(self, x, label_idx):
style_emb = self.embedding(label_idx)
_, features = self.unet_forward(x, style_emb)
return features
✅ Zi2ZiLoss
損失函數模組目前可以維持不變
只要 self.encoded_real_A
和 self.encoded_fake_B
是你要進行 content consistency 的中間層特徵,就能直接套用:
const_loss = self.const(encoded_real_A, encoded_fake_B) * self.lambda_const
你也可以考慮在 UNet
中傳出多層特徵,然後計算多層一致性損失(multi-scale const loss)——我可以幫你加這段也沒問題。
✅ 建議:forward()
改寫為帶入 label 後可共用 encode
def forward(self, input, label):
style_embedding = self.embedding(label)
out, features = self.unet_forward(input, style_embedding)
return out, features
def encode(self, input, label):
_, features = self.forward(input, label)
return features
🔧 如果你需要,我可以協助你:
- 完整改寫
UNetGenerator
,讓它支援 FiLM + LoRA Attention。 - 把
encode
做成多層 feature return。 - 套入 LoRA Attention(比如 Linear Attention + LoRA 融合)。
你目前的損失架構其實非常成熟,只要 Generator 支援新的 style modulation 機制,整體應該會 smooth 運作。
需要我幫你一起重構 UNetGenerator
嗎?你目前的 FiLM 與 Attention 層數設定有想法嗎?例如你希望:
- 在哪些層插入 FiLM?
- 在哪些層插入 Self-Attention / LoRA Attention?
這樣我可以直接幫你整合出來。
你現在的 Zi2ZiModel.forward()
會回傳 self.style_pred
,這代表你已經在 Generator 裡實作了 style classification 預測(例如針對目標字體 label 做分類)。如果你想 擴充 style_pred
的用途,讓它能用在 Loss 函數中,最常見的應用有:
✅ 擴充 style_pred 的用途建議
我們可以讓 style_pred
用來做以下幾件事:
- Style Classification Loss(Style Supervision)
- 讓 Generator 學會輸出對應的 label 分類。
- 加入分類損失
CrossEntropyLoss(style_pred, labels)
作為一種對 style embedding 的監督。
- 用於風格一致性(Style Consistency)
- 可考慮在不同樣本下讓
style_pred
一致,例如 triplet loss / contrastive loss,但這屬於擴展應用。
- 可考慮在不同樣本下讓
在你現有架構下,最自然的是加上第 1 點,即 style classification 損失。
🔧 實作步驟:在 Loss 加入 style classification loss
✅ 1. 在 Zi2ZiLoss
初始化中加入:
self.style_cls_loss = nn.CrossEntropyLoss().to(device)
self.lambda_style_cls = 1 # 可以調整這個係數
你也可以讓這個係數作為建構參數傳入。
✅ 2. 在 backward_G()
中加入:
修改函數簽名:
def backward_G(self, real_A, real_B, fake_B, encoded_real_A, encoded_fake_B, labels, style_pred):
然後新增 style classification loss 計算:
style_cls_loss = self.style_cls_loss(style_pred, labels) * self.lambda_style_cls
加總進 total_G_loss
:
total_G_loss = g_loss_adv + const_loss + l1_loss + cat_loss + fm_loss + perc_loss + style_cls_loss
更新回傳內容:
return total_G_loss, {
'const_loss': const_loss,
'l1_loss': l1_loss,
'g_adv': g_loss_adv,
'cat_loss': cat_loss,
'fm_loss': fm_loss,
'perceptual_loss': perc_loss,
'style_cls_loss': style_cls_loss,
}
✅ 3. 修改 Zi2ZiModel.optimize_parameters()
呼叫方式
g_loss, losses = self.loss_module.backward_G(
self.real_A, self.real_B, self.fake_B,
self.encoded_real_A, self.encoded_fake_B,
self.labels, self.style_pred # 多傳入 style_pred
)
✅ 4. 補充:UNet 中 style_pred
如何取得?
如果你還沒實作,可以這樣在 UNetGenerator
的 forward
結尾加上分類頭:
self.style_classifier = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
nn.Linear(ngf * 8, embedding_num) # ngf * 8 是 bottleneck 的 channel 數
)
...
def forward(self, x, label_idx):
style_embedding = self.embedding(label_idx)
out, features = self.unet_forward(x, style_embedding)
# 分類頭 (bottleneck feature)
style_pred = self.style_classifier(features)
return out, style_pred, features
✅ 最後總結:
你擴充 style_pred
用法的最佳方式是:
- 在 Generator 中讓它輸出對應 style 類別預測。
- 在 Loss 中加入 CrossEntropyLoss 作為
style_cls_loss
。 - 這樣不只幫助風格對齊,還可作為 style embedding 的監督信號,能強化 FiLM modulation 的穩定性。
需要我幫你補上 UNet 的 style_pred
分類頭部分嗎?還是你已經實作了,只是想擴充在 Loss 上?