全域平均池化(Global Average Pooling, GAP)解決UNet 結構中在 innermost 層會將 bottleneck 的特徵 encoded 攤平成向量並回傳會佔用大量記憶體

實際測試, 的確產生的 .pth 檔案變小很多, 原本是 215MB, 縮小到變成 157MB, 縮小為 73%.

測試在資料量極少的情況下(11筆訓練資料), 沒有標準用法情況下, 只訓練 4個 Epoch, 不存在的相似部件可以推論出模糊形狀, 但encoded.mean(dim=(2, 3)) 就是一片黑。

也許, 增加訓練資料與訓練次數, 應該就可以更明顯地比較出其差異.


你這段 UNet 結構中在 innermost 層會將 bottleneck 的特徵 encoded 攤平成向量並回傳,這可能會佔用大量記憶體。為了改善這個問題,我們可以改成 全域平均池化(Global Average Pooling, GAP) 的方式來傳遞風格表徵,大幅降低記憶體使用量,同時仍保留風格訊息。

✅ 修改說明:

  • 原本是 encoded.contiguous().view(B, -1),會產生非常長的向量(B × C × H × W)。
  • 改成 encoded.mean(dim=(2, 3)),只保留 B × C 的全域平均向量,能更有效地保留風格資訊並節省資源。

✅ 修改後的 UnetSkipConnectionBlock.forward(只修改 forward 的部分)如下:

def forward(self, x, style=None):
    encoded = self.down(x)
    if self.attn_block:
        encoded = self.attn_block(encoded)

    if self.innermost:
        if hasattr(self, 'transformer_block'):
            encoded = self.transformer_block(encoded)
        if style is not None:
            encoded = self.film(encoded, style)

        decoded = self.up(encoded)
        if decoded.shape[2:] != x.shape[2:]:
            decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
        if self.res_skip:
            decoded = self.res_skip(decoded)

        # 使用全域平均池化來取得 bottleneck 表徵
        style_feat = encoded.mean(dim=(2, 3))  # shape: (B, C)
        return x + decoded, style_feat
    else:
        sub_output, encoded_real_A = self.submodule(encoded, style)
        decoded = self.up(sub_output)
        if decoded.shape[2:] != x.shape[2:]:
            decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
        if self.res_skip:
            decoded = self.res_skip(decoded)
        if self.outermost:
            return decoded, encoded_real_A
        else:
            return x + decoded, encoded_real_A

如果你還需要將這個 style_feat 傳到某個 style embedding 或分類器中(例如給 CategoryLoss 使用),它會變得非常方便處理,且不容易記憶體爆炸。

需要我也幫你改上游接收這個 style feature 的邏輯嗎?或者整合進 Style Encoder?


你這份 UNetGenerator 程式碼我已幫你修改好,使其支援 全域平均池化風格表徵傳遞,並讓上游能從 bottleneck 壓縮表示中推論出 style embedding 和 style 分類輸出,以下是修改重點與完整程式碼:


✅ 修改重點:

  1. UnetSkipConnectionBlock 傳出的 encoded 為全域平均池化後的 bottleneck 向量(前面你已改過)。
  2. 新增 style_classifierstyle_embedder 模組
    • style_classifier: 將 bottleneck 表徵映射為分類輸出。
    • style_embedder: 將 bottleneck 表徵映射為 style embedding。
  3. 更新 forward()encode()
    • forward 回傳 (fake_B, style_cls_pred, style_emb)
    • encode 只回傳 style embedding。

🔧 修改後程式碼:

class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64,
                 embedding_num=40, embedding_dim=128,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False,
                 self_attention=False, blur=False, attention_type='linear',
                 attn_layers=None):
        super(UNetGenerator, self).__init__()

        if attn_layers is None:
            attn_layers = []

        unet_block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, submodule=None, innermost=True,
            norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim,
            use_transformer=True, self_attention=self_attention,
            blur=blur, attention_type=attention_type, attn_layers=attn_layers)

        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(
                ngf * 8, ngf * 8, submodule=unet_block,
                norm_layer=norm_layer, layer=i + 2, use_dropout=use_dropout,
                self_attention=self_attention, blur=blur,
                attention_type=attention_type, attn_layers=attn_layers)

        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block, norm_layer=norm_layer, layer=5,
                                             self_attention=self_attention, blur=blur, attention_type=attention_type, attn_layers=attn_layers)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block, norm_layer=norm_layer, layer=6,
                                             self_attention=self_attention, blur=blur, attention_type=attention_type, attn_layers=attn_layers)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_layer=norm_layer, layer=7,
                                             self_attention=self_attention, blur=blur, attention_type=attention_type, attn_layers=attn_layers)

        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             outermost=True, norm_layer=norm_layer, layer=8,
                                             self_attention=self_attention, blur=blur,
                                             attention_type=attention_type, attn_layers=attn_layers)

        self.embedder = nn.Embedding(embedding_num, embedding_dim)

        # ↓ 新增 bottleneck 表徵的後處理模組
        self.style_classifier = nn.Sequential(
            nn.LayerNorm(ngf * 8),
            nn.Linear(ngf * 8, embedding_num)
        )
        self.style_embedder = nn.Sequential(
            nn.LayerNorm(ngf * 8),
            nn.Linear(ngf * 8, embedding_dim),
            nn.SiLU()
        )

    def _prepare_style(self, style_or_label):
        return self.embedder(style_or_label) if style_or_label is not None and 'LongTensor' in style_or_label.type() else style_or_label

    def forward(self, x, style_or_label=None):
        style = self._prepare_style(style_or_label)
        fake_B, style_feat = self.model(x, style)  # style_feat: (B, C)

        style_cls_pred = self.style_classifier(style_feat)   # 分類
        style_emb = self.style_embedder(style_feat)          # 嵌入向量

        return fake_B, style_cls_pred, style_emb

    def encode(self, x, style_or_label=None):
        style = self._prepare_style(style_or_label)
        _, style_feat = self.model(x, style)
        return self.style_embedder(style_feat)

這樣設計不僅大幅降低 memory 使用,也讓 style 特徵可重複使用於分類或 embedding 提取。
有需要的話也可以幫你補上訓練 loss 或正規化建議!需要嗎?


你這份 Zi2ZiModelZi2ZiLoss 的架構整體邏輯已經非常完整,接下來我會協助你根據「新的架構需求:UNet 中使用 FiLM 與 LoRA Attention」來進行必要調整,包含:

  1. 確認 Generator(UNet)是否支援 FiLM 與 LoRA Attention。
  2. 修正 Generator 的 encode 方法邏輯。
  3. Loss 模組部分保持不變,只要 Generator 的輸出與 encode 方式正確,損失計算即可套用現有模組。

你可以這樣調整 UNetGenerator

這部分是架構改動的關鍵 —— FiLMLoRA Attention 要體現在 UNet 裡,我這邊提供一個基本架構方向供你融入現有 UNet 中(你可能已經部分實作,只需插入以下邏輯):

class FiLMBlock(nn.Module):
    def __init__(self, in_channels, embedding_dim):
        super().__init__()
        self.gamma = nn.Linear(embedding_dim, in_channels)
        self.beta = nn.Linear(embedding_dim, in_channels)

    def forward(self, x, style_embedding):
        gamma = self.gamma(style_embedding).unsqueeze(-1).unsqueeze(-1)
        beta = self.beta(style_embedding).unsqueeze(-1).unsqueeze(-1)
        return x * gamma + beta

然後在 UNetGenerator 的 forward 內,每一層你想調整的地方加上:

x = self.down1(x)
x = self.film1(x, style_embedding)  # 假設你定義了 self.film1 = FiLMBlock(...)

確認 Generator forward()encode() 方法:

目前 self.netG(self.real_A, self.labels) 回傳 (fake_B, encoded_real_A)
self.netG.encode() 方法你需要確認是從哪一層抽特徵進行編碼。

一種實作是這樣:

class UNetGenerator(nn.Module):
    ...
    def forward(self, x, label_idx):
        style_emb = self.embedding(label_idx)  # label_idx shape: (B,)
        out, features = self.unet_forward(x, style_emb)
        return out, features

    def encode(self, x, label_idx):
        style_emb = self.embedding(label_idx)
        _, features = self.unet_forward(x, style_emb)
        return features

Zi2ZiLoss 損失函數模組目前可以維持不變

只要 self.encoded_real_Aself.encoded_fake_B 是你要進行 content consistency 的中間層特徵,就能直接套用:

const_loss = self.const(encoded_real_A, encoded_fake_B) * self.lambda_const

你也可以考慮在 UNet 中傳出多層特徵,然後計算多層一致性損失(multi-scale const loss)——我可以幫你加這段也沒問題。


✅ 建議:forward() 改寫為帶入 label 後可共用 encode

def forward(self, input, label):
    style_embedding = self.embedding(label)
    out, features = self.unet_forward(input, style_embedding)
    return out, features

def encode(self, input, label):
    _, features = self.forward(input, label)
    return features

🔧 如果你需要,我可以協助你:

  • 完整改寫 UNetGenerator,讓它支援 FiLM + LoRA Attention。
  • encode 做成多層 feature return。
  • 套入 LoRA Attention(比如 Linear Attention + LoRA 融合)。

你目前的損失架構其實非常成熟,只要 Generator 支援新的 style modulation 機制,整體應該會 smooth 運作。

需要我幫你一起重構 UNetGenerator 嗎?你目前的 FiLM 與 Attention 層數設定有想法嗎?例如你希望:

  • 在哪些層插入 FiLM?
  • 在哪些層插入 Self-Attention / LoRA Attention?

這樣我可以直接幫你整合出來。


你現在的 Zi2ZiModel.forward() 會回傳 self.style_pred,這代表你已經在 Generator 裡實作了 style classification 預測(例如針對目標字體 label 做分類)。如果你想 擴充 style_pred 的用途,讓它能用在 Loss 函數中,最常見的應用有:


擴充 style_pred 的用途建議

我們可以讓 style_pred 用來做以下幾件事:

  1. Style Classification Loss(Style Supervision)
    • 讓 Generator 學會輸出對應的 label 分類。
    • 加入分類損失 CrossEntropyLoss(style_pred, labels) 作為一種對 style embedding 的監督。
  2. 用於風格一致性(Style Consistency)
    • 可考慮在不同樣本下讓 style_pred 一致,例如 triplet loss / contrastive loss,但這屬於擴展應用。

在你現有架構下,最自然的是加上第 1 點,即 style classification 損失


🔧 實作步驟:在 Loss 加入 style classification loss

✅ 1. 在 Zi2ZiLoss 初始化中加入:

self.style_cls_loss = nn.CrossEntropyLoss().to(device)
self.lambda_style_cls = 1  # 可以調整這個係數

你也可以讓這個係數作為建構參數傳入。

✅ 2. 在 backward_G() 中加入:

修改函數簽名:

def backward_G(self, real_A, real_B, fake_B, encoded_real_A, encoded_fake_B, labels, style_pred):

然後新增 style classification loss 計算:

style_cls_loss = self.style_cls_loss(style_pred, labels) * self.lambda_style_cls

加總進 total_G_loss

total_G_loss = g_loss_adv + const_loss + l1_loss + cat_loss + fm_loss + perc_loss + style_cls_loss

更新回傳內容:

return total_G_loss, {
    'const_loss': const_loss,
    'l1_loss': l1_loss,
    'g_adv': g_loss_adv,
    'cat_loss': cat_loss,
    'fm_loss': fm_loss,
    'perceptual_loss': perc_loss,
    'style_cls_loss': style_cls_loss,
}

✅ 3. 修改 Zi2ZiModel.optimize_parameters() 呼叫方式

g_loss, losses = self.loss_module.backward_G(
    self.real_A, self.real_B, self.fake_B,
    self.encoded_real_A, self.encoded_fake_B,
    self.labels, self.style_pred  # 多傳入 style_pred
)

✅ 4. 補充:UNet 中 style_pred 如何取得?

如果你還沒實作,可以這樣在 UNetGeneratorforward 結尾加上分類頭:

self.style_classifier = nn.Sequential(
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(ngf * 8, embedding_num)  # ngf * 8 是 bottleneck 的 channel 數
)

...

def forward(self, x, label_idx):
    style_embedding = self.embedding(label_idx)
    out, features = self.unet_forward(x, style_embedding)

    # 分類頭 (bottleneck feature)
    style_pred = self.style_classifier(features)

    return out, style_pred, features

✅ 最後總結:

你擴充 style_pred 用法的最佳方式是:

  • 在 Generator 中讓它輸出對應 style 類別預測。
  • 在 Loss 中加入 CrossEntropyLoss 作為 style_cls_loss
  • 這樣不只幫助風格對齊,還可作為 style embedding 的監督信號,能強化 FiLM modulation 的穩定性。

需要我幫你補上 UNet 的 style_pred 分類頭部分嗎?還是你已經實作了,只是想擴充在 Loss 上?


發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *