Simplifying the UNetGenerator Design for Font Style Transfer

The UNetGenerator forward method returns three values (fake_B, encoded_real_A, encoded_fake_B), mainly because:

  1. fake_B: the final synthesized font image (real_A rendered in the Zen Maru Gothic style).
  2. encoded_real_A: the latent code of real_A after the UNet encoder, used to preserve style consistency and for contrastive learning.
  3. encoded_fake_B: the latent code of fake_B after the same encoder, used to compute Lconst_penalty (i.e. self.mse(self.encoded_real_A, self.encoded_fake_B)) so that the generated result stays consistent with the input content (see the sketch below).
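
For context, here is a minimal sketch of how that constancy term is typically folded into the generator objective in zi2zi-style training code. The weight name Lconst_penalty and the surrounding loss terms are assumptions for illustration, not the exact code from this project:

# Hedged sketch of the constancy loss (zi2zi-style); names are assumptions.
# self.encoded_real_A and self.encoded_fake_B are the flattened encoder outputs.
const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
# The term is then added to the generator objective alongside the other terms, e.g.:
# g_loss = adversarial_loss + l1_loss + const_loss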

How to simplify while keeping the original behavior

We can rewrite the UNetGenerator forward method and separate out the computation of encoded_real_A and encoded_fake_B, so that forward only outputs fake_B. Zi2ZiModel can still obtain encoded_real_A and encoded_fake_B during training, which reduces the internal complexity of the UNet.

import torch
import torch.nn as nn

class UNetGenerator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, embedding_num=40, embedding_dim=128, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UNetGenerator, self).__init__()
        # Encoder: three stride-2 convolutions mapping the image to a feature map with embedding_dim channels.
        self.encoder = nn.Sequential(
            nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1),
            norm_layer(ngf),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1),
            norm_layer(ngf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ngf * 2, embedding_dim, kernel_size=4, stride=2, padding=1),
            norm_layer(embedding_dim),
            nn.LeakyReLU(0.2, True),
        )
        # Decoder: mirrors the encoder. Its first layer takes embedding_dim * 2 channels
        # because the style map is concatenated onto the encoded features in forward.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim * 2, ngf * 2, kernel_size=4, stride=2, padding=1),
            norm_layer(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=4, stride=2, padding=1),
            norm_layer(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, output_nc, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
        self.embedder = nn.Embedding(embedding_num, embedding_dim)

    def encode(self, x):
        # Flatten the encoder output into one latent vector per sample.
        return self.encoder(x).view(x.shape[0], -1)

    def forward(self, x, style_or_label=None):
        # Accept either an integer label (looked up in the embedding table) or a precomputed style tensor.
        if style_or_label is not None and 'LongTensor' in style_or_label.type():
            style = self.embedder(style_or_label).view(style_or_label.shape[0], -1, 1, 1)
        else:
            style = style_or_label

        encoded_x = self.encoder(x)
        if style is None:
            # Keep the channel count the decoder expects even when no style is given.
            style = torch.zeros(x.shape[0], encoded_x.shape[1], 1, 1, device=x.device)
        # Broadcast the 1x1 style map over the spatial dimensions, then concatenate.
        style = style.expand(-1, -1, encoded_x.shape[2], encoded_x.shape[3])
        encoded_x = torch.cat([encoded_x, style], dim=1)

        fake_B = self.decoder(encoded_x)
        return fake_B
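
A quick usage sketch of this simplified generator follows. The batch size, image size, and random inputs are assumptions for illustration; with 256x256 inputs, the encoder output is 128 channels at 32x32:

# Hypothetical smoke test for the simplified generator.
G = UNetGenerator()
real_A = torch.randn(4, 3, 256, 256)
labels = torch.randint(0, 40, (4,))   # style labels for the embedding table
fake_B = G(real_A, labels)            # (4, 3, 256, 256)
encoded_real_A = G.encode(real_A)     # (4, 128 * 32 * 32)
encoded_fake_B = G.encode(fake_B)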

Applying the simplified model

Zi2ZiModelforward 方法中,分離 encoded_real_Aencoded_fake_B 的計算:

def forward(self):
    self.fake_B = self.netG(self.real_A, self.labels)
    # Run both images through the shared encoder for the consistency loss.
    self.encoded_real_A = self.netG.encode(self.real_A)
    self.encoded_fake_B = self.netG.encode(self.fake_B)

This way you:

  1. Simplify UNetGenerator so that forward is only responsible for generating fake_B, while encoding becomes a standalone encode method, making the network structure clearer.
  2. Remove unnecessary return values so that forward keeps a single responsibility, while encoded_real_A and encoded_fake_B remain available for the Lconst_penalty computation.

This design should keep UNetGenerator's original behavior while making it more intuitive, and it should also help training stability! 🚀

The UNetGenerator below still uses UnetSkipConnectionBlock for hierarchical image generation. Since you want to simplify the return values while keeping the core style-transfer functionality, we can adjust the forward method so it returns fake_B plus encoded_real_A, with the encoding handled in the innermost layer of UnetSkipConnectionBlock; encoded_fake_B is then obtained by running fake_B through the generator again. Here is the simplified UNetGenerator:


Updated UNetGenerator

  • Main changes
    • The forward method is adjusted so the return structure stays compact
    • The innermost layer of UnetSkipConnectionBlock produces encoded_real_A
    • encoded_fake_B is obtained by feeding fake_B through netG again

Adjusted UNetGenerator

import torch
import torch.nn as nn

class UNetGenerator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, num_downs=8, ngf=64, embedding_num=40, embedding_dim=128,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, self_attention=False, residual_block=False, blur=False):
        super(UNetGenerator, self).__init__()
        
        # Innermost block: handles the style embedding
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim, 
                                             self_attention=self_attention, residual_block=residual_block, blur=blur, innermost=True)

        # Intermediate blocks
        for index in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, 
                                                 norm_layer=norm_layer, layer=index+2, use_dropout=use_dropout, 
                                                 self_attention=self_attention, residual_block=residual_block, blur=blur)

        # Upper blocks (restore the image resolution)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, 
                                             norm_layer=norm_layer, layer=5, self_attention=self_attention, 
                                             residual_block=residual_block, blur=blur)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, 
                                             norm_layer=norm_layer, layer=6, self_attention=self_attention, 
                                             residual_block=residual_block, blur=blur)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, 
                                             norm_layer=norm_layer, layer=7, self_attention=self_attention, 
                                             residual_block=residual_block, blur=blur)

        # Outermost block: produces the final image
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, 
                                             norm_layer=norm_layer, layer=8, self_attention=self_attention, 
                                             residual_block=residual_block, blur=blur, outermost=True)

        self.embedder = nn.Embedding(embedding_num, embedding_dim)

    def forward(self, x, style_or_label=None):
        """Generate fake_B and obtain encoded_real_A."""
        if style_or_label is not None and 'LongTensor' in style_or_label.type():
            style = self.embedder(style_or_label)
        else:
            style = style_or_label

        # self.model returns fake_B together with encoded_real_A.
        fake_B, encoded_real_A = self.model(x, style)

        return fake_B, encoded_real_A

Adjusted UnetSkipConnectionBlock

  • Main changes
    • The innermost layer returns encoded_real_A
    • The outermost layer produces fake_B
    • Internal layers (neither innermost nor outermost) pass information along the skip connections correctly

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, 
                 norm_layer=nn.BatchNorm2d, layer=0, embedding_dim=128, 
                 use_dropout=False, self_attention=False, residual_block=False, blur=False,
                 outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        
        self.outermost = outermost
        self.innermost = innermost

        if input_nc is None:
            input_nc = outer_nc
        
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.5))

        self.submodule = submodule

    def forward(self, x, style=None):
        """U-Net forward pass with skip connections."""

        if self.innermost:
            encoded = self.down(x)
            if style is not None:
                encoded = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encoded], dim=1)
            decoded = self.up(encoded)
            # Return the skip-connected output plus the flattened encoding (encoded_real_A).
            return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)

        elif self.outermost:
            encoded = self.down(x)
            if self.submodule:
                sub_output, encoded_real_A = self.submodule(encoded, style)
            else:
                sub_output, encoded_real_A = encoded, None
            decoded = self.up(sub_output)
            return decoded, encoded_real_A

        else:
            encoded = self.down(x)
            if self.submodule:
                sub_output, encoded_real_A = self.submodule(encoded, style)
            else:
                sub_output, encoded_real_A = encoded, None
            decoded = self.up(sub_output)
            return torch.cat([x, decoded], 1), encoded_real_A
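
To see how the tuple propagates, here is a hedged usage sketch. The batch size and image size are illustrative assumptions; with num_downs=8 and 256x256 inputs, encoded_real_A flattens to 512 + 128 = 640 dimensions, since the style vector is concatenated before flattening:

# Hypothetical usage of the adjusted generator; shapes assume 256x256 inputs.
G = UNetGenerator(num_downs=8)
real_A = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 40, (2,))
fake_B, encoded_real_A = G(real_A, labels)   # (2, 3, 256, 256), (2, 640)
_, encoded_fake_B = G(fake_B, labels)        # second pass just for the fake encoding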

Optimized return values

  • Simplified return values
    • UNetGenerator.forward(x, style_or_label) now returns only (fake_B, encoded_real_A)
    • fake_B is the generated styled glyph
    • encoded_real_A is the embedded representation of the source glyph
  • Compact while keeping the functionality intact
    • encoded_real_A is still available for the contrastive loss
    • encoded_fake_B can be obtained via netG(fake_B) instead of being returned by forward (see the sketch below)

This simplification should meet your needs while keeping Zi2ZiModel's training behavior intact, and it should make your font style transfer more stable! 🚀


Gemini

To simplify the UNetGenerator design while preserving as much of the original functionality as possible, we can adjust the forward method of UnetSkipConnectionBlock so it returns only the necessary outputs, and concentrate the extraction of the encoded features in UNetGenerator's forward method.

Modified UNetGenerator and UnetSkipConnectionBlock

import torch
import torch.nn as nn
import functools

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, embedding_dim=128, norm_layer=nn.BatchNorm2d,
                 use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                            stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]

        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]

        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                up = up + [nn.Dropout(0.5)]

        self.submodule = submodule
        self.down = nn.Sequential(*down)
        self.up = nn.Sequential(*up)

    def forward(self, x, style=None):
        if self.innermost:
            encode = self.down(x)
            if style is None:
                return encode
            enc = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encode], 1)
            dec = self.up(enc)
            return torch.cat([x, dec], 1)
        elif self.outermost:
            enc = self.down(x)
            if style is None:
                return self.submodule(enc)
            sub = self.submodule(enc, style)
            dec = self.up(sub)
            return dec
        else:
            enc = self.down(x)
            if style is None:
                return self.submodule(enc)
            sub = self.submodule(enc, style)
            dec = self.up(sub)
            return torch.cat([x, dec], 1)

class UNetGenerator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, num_downs=8, ngf=64, embedding_num=40, embedding_dim=128,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UNetGenerator, self).__init__()
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
                                             innermost=True, embedding_dim=embedding_dim)
        for _ in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             outermost=True,
                                             norm_layer=norm_layer)
        self.embedder = nn.Embedding(embedding_num, embedding_dim)
        self.innermost_encoder = None

    def forward(self, x, style_or_label=None):
        if style_or_label is not None and 'LongTensor' in style_or_label.type():
            out = self.model(x, self.embedder(style_or_label))
        else:
            out = self.model(x, style_or_label)

        # Extract the encoder features: calling the model with style=None walks the
        # down-sampling path only and returns the innermost encoding.
        self.innermost_encoder = self.model(x, None).view(x.shape[0], -1)

        # encoded_fake_B: feed the generated image back through the encoder path.
        return out, self.innermost_encoder, self.model(out, None).view(out.shape[0], -1)
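
A quick smoke test of this version (the batch size and image size are assumptions; with num_downs=8 and 256x256 inputs, the innermost encoding flattens to 512 dimensions per sample):

# Hypothetical usage; shapes assume num_downs=8 and 256x256 inputs.
G = UNetGenerator(num_downs=8)
x = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 40, (2,))
fake_B, encoded_real_A, encoded_fake_B = G(x, labels)
# fake_B: (2, 3, 256, 256); encoded_real_A and encoded_fake_B: (2, 512)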

Notes on the changes:

  1. UnetSkipConnectionBlock
    • The forward method is simplified: it returns either the skip-concatenated features or the decoded output.
    • The encoding is no longer returned alongside every output; when style is None, the block short-circuits and returns only the encoding.
  2. UNetGenerator
    • In forward, the innermost encoding of x is extracted by calling self.model(x, None), which runs only the down-sampling path, and is stored in self.innermost_encoder.
    • forward returns (out, self.innermost_encoder, self.model(out, None).view(out.shape[0], -1)), corresponding to (fake_B, encoded_real_A, encoded_fake_B).
    • encoded_fake_B is obtained by feeding the generated result out back into self.model with the style input set to None.

Advantages:

  • Simpler UnetSkipConnectionBlock: the forward method is less complex, which improves readability.
  • Centralized feature extraction: the encoded features are extracted in UNetGenerator, which keeps the code structure clearer.
  • Original functionality preserved: fake_B, encoded_real_A, and encoded_fake_B are still available for the subsequent loss computations.

Caveats:

  • The style=None shortcut in UnetSkipConnectionBlock.forward is what makes the encoder-only passes work; if that convention changes, the extraction code in UNetGenerator.forward must be updated accordingly.
  • This simplification assumes the innermost block is where the encoding is produced; if the network structure changes, the code must be adjusted to match.
  • Because encoded_real_A and encoded_fake_B are obtained by extra passes through the generator's down-sampling path, this adds some additional computation.

This revised version keeps the original functionality while improving the clarity and readability of the code.
