A Low-Memory Version of the UNet Font Style Transfer Model

To let your UNet font style transfer model run even in a low-GPU-memory environment, I have put together a low-memory version for you that still keeps the detail-sensitive design points (such as stroke edges). The concrete strategies are as follows:


✅ Overview of the Adjustment Strategies

  1. Reduce the channel counts of the intermediate layers
    • Lower ngf (the base channel count) from 64 to 32 or 16 (see the parameter-count sketch right after this list).
    • Cap the channel expansion of each layer at ngf * 4 so memory does not blow up.
  2. Keep the small (3,3) kernels for fine detail
    • The downsampling layers still use Conv2d(kernel_size=3, stride=2, padding=1).
    • Compared with a (4,4) kernel this saves parameters while working just as well.
  3. Use stride=1 upsampling in the later decoder stages (to avoid resolution loss):
    • In particular, switch the last 1-2 layers before the output to ConvTranspose2d(kernel_size=3, stride=1, padding=1).
  4. Keep ResSkip, Self-Attention, and Transformer blocks only in the innermost layer (optionally enabled)
    • This saves a large amount of GPU memory while preserving a semantically rich representation.
  5. Do not use dropout (unless you actually run into overfitting)
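
As a rough cross-check of points 1 and 2 (an illustrative sketch only, not part of the model), the snippet below counts the parameters of a single downsampling convolution at different ngf values and kernel sizes; halving ngf roughly quarters the per-layer parameter count, and a (3,3) kernel needs about 9/16 of the weights of a (4,4) kernel:

import torch.nn as nn

# Illustrative only: one downsampling conv, counted at different widths/kernels.
for ngf in (64, 32, 16):
    for k in (4, 3):
        conv = nn.Conv2d(ngf, ngf * 2, kernel_size=k, stride=2, padding=1)
        n_params = sum(p.numel() for p in conv.parameters())
        print(f"ngf={ngf:2d}, kernel=({k},{k}): {n_params:7d} params")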

✅ Sketch of the Slimmed-Down UNet Structure (core snippet)

# Reduced channel count
ngf = 32

# Down: smaller channel count + (3,3) conv
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3, stride=2, padding=1)

# Up: use a stride=1 ConvTranspose2d in the final decoder layers to avoid over-upsampling
if is_final_decoder_layer:
    upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1)
else:
    upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=3, stride=2, padding=1)
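
To make the resolution behavior concrete, here is a small shape check (illustrative sizes, not part of the model): a stride=1 ConvTranspose2d keeps the spatial size, and the full implementation further below restores the encoder resolution with F.interpolate rather than a stride=2 deconvolution:

import torch
import torch.nn as nn
import torch.nn.functional as F

feat = torch.randn(1, 64, 16, 16)          # hypothetical decoder feature map
up = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1)
out = up(feat)                             # stride=1: still 16x16
print(out.shape)                           # torch.Size([1, 32, 16, 16])
out = F.interpolate(out, size=(32, 32), mode='bilinear', align_corners=False)
print(out.shape)                           # torch.Size([1, 32, 32, 32])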

✅ Recommendation: Concrete Layer Design

| Layer       | in_channels → out_channels | stride | conv kernel | notes                                   |
|-------------|----------------------------|--------|-------------|-----------------------------------------|
| outermost   | 1 → 32                     | 2      | (3,3)       |                                         |
| encoder 1~2 | 32 → 64 → 128              | 2      | (3,3)       | downsampling                            |
| bottleneck  | 128 → 128                  | -      | (3,3) x2    | optional attention / ResSkip            |
| decoder 1   | 128 → 64                   | 2      | (3,3)       | ConvTranspose2d                         |
| decoder 2   | 64 → 32                    | 1      | (3,3)       | ConvTranspose2d, avoids over-upsampling |
| output      | 32 → 1                     | -      | (3,3)       | followed by Tanh                        |

✅ Extra Tips for Additional Savings

  • Remove the norm_layer in the intermediate layers (or replace it with GroupNorm(4)): saves a little memory.
  • Reduce the style_fc dimension from 128 to 64.
  • At inference time, wrap the forward pass in with torch.no_grad(): and call model.eval(); this reduces memory significantly (see the sketch right below).
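
A minimal sketch of the norm and inference tips, assuming the UNetGenerator class from the full listing further below; the GroupNorm swap works because norm_layer is only ever called as norm_layer(channels):

import torch
import torch.nn as nn

def group_norm(channels):
    # drop-in replacement for InstanceNorm2d; GroupNorm(4, C) per the tip above
    return nn.GroupNorm(4, channels)

# Assumes the UNetGenerator class defined in the full listing below.
model = UNetGenerator(input_nc=1, output_nc=1, ngf=32, norm_layer=group_norm)

# eval() switches off training-only behavior (e.g., dropout); no_grad() stops
# activations from being stored for backprop, which is the main memory saving.
model.eval()
with torch.no_grad():
    glyph = torch.randn(1, 1, 128, 128)    # hypothetical input glyph
    styled = model(glyph)
print(styled.shape)                        # torch.Size([1, 1, 128, 128])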

🔧 Do you want me to generate a complete, runnable version of the slimmed-down code for you? Just tell me which features you want to keep (e.g., attention, style embedding…), and I can put it together right away. The full simplified implementation is listed below:

import torch
import torch.nn as nn
import torch.nn.functional as F

# ResSkip: lightweight residual refinement block (two 3x3 convs plus an
# identity skip); in_channels must equal out_channels for the addition.
class ResSkip(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResSkip, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.SiLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        return x + self.conv(x)

# One U-Net level: downsample, recurse into the submodule, upsample with a
# stride=1 transposed conv, resize back to the input resolution, then
# concatenate with the input as the skip connection.
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()

        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d

        if input_nc is None:
            input_nc = outer_nc

        if outermost:
            # Outermost: no norm, Tanh on the output.
            down = [nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)]
            up = [nn.SiLU(inplace=True),
                  nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1),
                  nn.Tanh()]
        elif innermost:
            # Innermost (bottleneck): no submodule below, so the up path takes inner_nc channels.
            down = [nn.SiLU(inplace=True),
                    nn.Conv2d(input_nc, inner_nc, kernel_size=3, stride=2, padding=1, bias=use_bias)]
            up = [nn.SiLU(inplace=True),
                  nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
                  norm_layer(outer_nc)]
        else:
            # Intermediate blocks: the submodule returns inner_nc * 2 channels (skip concat).
            down = [nn.SiLU(inplace=True),
                    nn.Conv2d(input_nc, inner_nc, kernel_size=3, stride=2, padding=1, bias=use_bias),
                    norm_layer(inner_nc)]
            up = [nn.SiLU(inplace=True),
                  nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
                  norm_layer(outer_nc)]

        self.down = nn.Sequential(*down)
        self.submodule = submodule
        self.up = nn.Sequential(*up)
        # Optional residual refinement on the up path, enabled only for the outer (shallow) blocks.
        self.res_skip = ResSkip(outer_nc, outer_nc) if not outermost and layer >= 4 else None

    def forward(self, x):
        if self.innermost:
            x1 = self.down(x)
            x2 = self.up(x1)
            # The stride=1 up path keeps the downsampled resolution, so resize
            # back to the input resolution before concatenating the skip.
            x2 = F.interpolate(x2, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip is not None:
                x2 = self.res_skip(x2)
            return torch.cat([x, x2], 1)
        else:
            x1 = self.down(x)
            x2 = self.submodule(x1)
            x3 = self.up(x2)
            x3 = F.interpolate(x3, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip is not None:
                x3 = self.res_skip(x3)
            if self.outermost:
                return x3
            else:
                return torch.cat([x, x3], 1)

class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=6, ngf=32, norm_layer=nn.InstanceNorm2d):
        super(UNetGenerator, self).__init__()

        # innermost
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True, layer=1)
        # mid layers
        for i in range(num_downs - 4):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=unet_block, norm_layer=norm_layer, layer=i+2)
        # outer blocks with progressively fewer channels (each still downsamples and upsamples internally)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block, norm_layer=norm_layer, layer=5)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block, norm_layer=norm_layer, layer=6)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_layer=norm_layer, layer=7)

        # outermost
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             norm_layer=norm_layer, outermost=True, layer=8)

    def forward(self, x):
        return self.model(x)

if __name__ == "__main__":
    model = UNetGenerator(input_nc=1, output_nc=1)
    dummy = torch.randn(1, 1, 128, 128)
    out = model(dummy)
    print(out.shape)  # (1, 1, 128, 128)
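
As an optional check on the memory savings (a sketch, assuming a CUDA device is available), the snippet below reports peak GPU memory for one forward/backward pass of the UNetGenerator defined above:

import torch

if torch.cuda.is_available():
    model = UNetGenerator(input_nc=1, output_nc=1, ngf=32).cuda()
    x = torch.randn(4, 1, 128, 128, device="cuda")   # batch of 4 glyphs
    torch.cuda.reset_peak_memory_stats()
    model(x).mean().backward()
    print(f"peak GPU memory: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")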
