To let your UNet font style transfer model run in a low GPU memory environment, here is a low-memory variant that still keeps the detail-sensitive design points (such as stroke edges). The strategy is as follows:
✅ Adjustment strategy overview
- Reduce the channel counts of the intermediate layers:
  - Lower `ngf` (the base channel count) from 64 to 32 or 16.
  - Cap each layer's channel expansion at `ngf * 4` so memory does not blow up (see the sketch after this list).
- Keep small (3,3) kernels for fine detail:
  - Each layer still uses `Conv2d(kernel_size=3, stride=2, padding=1)`.
  - Compared with (4,4) kernels this uses fewer parameters and works just as well.
- Use stride=1 upsampling in the late decoder stages (to avoid resolution loss):
  - In particular, switch the last 1-2 layers before the output to `ConvTranspose2d(kernel_size=3, stride=1, padding=1)`.
- Keep ResSkip / Self-Attention / Transformer only at the innermost layer (optionally enabled):
  - This saves a lot of GPU memory while retaining a semantically rich representation.
- Skip dropout (unless you run into overfitting).
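As a rough sketch of the channel cap, the idea is simply to clamp the usual per-stage doubling at `ngf * 4`; the helper name `capped_channels` is illustrative and not part of your existing code:

# Clamp the usual per-stage channel doubling at ngf * 4
def capped_channels(ngf, num_downs, cap_mult=4):
    # e.g. ngf=32, num_downs=6 -> [32, 64, 128, 128, 128, 128]
    return [min(ngf * (2 ** i), ngf * cap_mult) for i in range(num_downs)]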
✅ Slimmed-down UNet structure sketch (core excerpt)
# Reduced channel count
ngf = 32

# Down: smaller channel counts + (3,3) conv
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3, stride=2, padding=1)

# Up: use stride=1 ConvTranspose2d near the output to avoid over-upsampling
if is_final_decoder_layer:
    upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1)
else:
    upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=3, stride=2, padding=1)
✅ Suggested per-layer design
Layer | in_channels → out_channels | stride | conv_kernel | notes |
---|---|---|---|---|
outermost | 1 → 32 | 2 | (3,3) | |
encoder 1~2 | 32 → 64 → 128 | 2 | (3,3) | downsampling |
bottleneck | 128 | – | (3,3)x2 | optional attention / ResSkip |
decoder 1 | 128 → 64 | 2 | (3,3) | ConvTranspose2d |
decoder 2 | 64 → 32 | 1 | (3,3) | ConvTranspose2d, stride=1 to avoid over-upsampling |
output | 32 → 1 | – | (3,3) | followed by Tanh |
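If it helps to read the table as code, here is a rough encoder-only sketch following the channel progression above; the function name and the norm/activation choices are illustrative assumptions, not taken from your existing code:

import torch.nn as nn

# Encoder side of the table: 1 -> 32 -> 64 -> 128, all (3,3) stride-2 convs
def build_encoder(ngf=32, in_nc=1):
    chans = [in_nc, ngf, ngf * 2, ngf * 4]  # 1, 32, 64, 128
    layers = []
    for c_in, c_out in zip(chans[:-1], chans[1:]):
        layers += [nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                   nn.InstanceNorm2d(c_out),
                   nn.SiLU(inplace=True)]
    return nn.Sequential(*layers)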
✅ Extra tips
- Remove the mid-layer norm_layer (or swap it for `GroupNorm(4)`): saves a bit of memory.
- Reduce the `style_fc` dimension from 128 to 64.
- At inference time, wrap the forward pass in `with torch.no_grad():` and call `model.eval()`; this cuts memory usage significantly (see the sketch below).
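A minimal sketch of that last tip, assuming the `UNetGenerator` defined further down in this message and a CUDA device (the 128×128 input size is just an example):

# Inference-only pass: eval() freezes norm/dropout behaviour, and no_grad()
# skips storing activations for backprop, which is where most memory goes.
model = UNetGenerator(input_nc=1, output_nc=1, ngf=32).cuda().eval()
with torch.no_grad():
    out = model(torch.randn(1, 1, 128, 128, device="cuda"))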
🔧 Would you like me to generate a complete, runnable slimmed-down version of the code for you? Just tell me which features you want to keep (e.g. attention, style embedding, ...) and I can put it together right away.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResSkip(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResSkip, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.SiLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        return x + self.conv(x)
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc

        if outermost:
            down = [nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)]
            up = [nn.SiLU(inplace=True),
                  nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1),
                  nn.Tanh()]
        elif innermost:
            down = [nn.SiLU(inplace=True),
                    nn.Conv2d(input_nc, inner_nc, kernel_size=3, stride=2, padding=1)]
            up = [nn.SiLU(inplace=True),
                  nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1),
                  norm_layer(outer_nc)]
        else:
            down = [nn.SiLU(inplace=True),
                    nn.Conv2d(input_nc, inner_nc, kernel_size=3, stride=2, padding=1),
                    norm_layer(inner_nc)]
            up = [nn.SiLU(inplace=True),
                  nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1),
                  norm_layer(outer_nc)]

        self.down = nn.Sequential(*down)
        self.submodule = submodule
        self.up = nn.Sequential(*up)
        self.res_skip = ResSkip(outer_nc, outer_nc) if not outermost and layer >= 4 else None
    def forward(self, x):
        if self.innermost:
            x1 = self.down(x)
            x2 = self.up(x1)
            # The stride=1 up-conv keeps the downsampled resolution, so resize
            # back to the input size before concatenating with the skip input.
            x2 = F.interpolate(x2, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip:
                x2 = self.res_skip(x2)
            return torch.cat([x, x2], 1)
        else:
            x1 = self.down(x)
            x2 = self.submodule(x1)
            x3 = self.up(x2)
            x3 = F.interpolate(x3, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip:
                x3 = self.res_skip(x3)
            if self.outermost:
                return x3
            else:
                return torch.cat([x, x3], 1)
class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=6, ngf=32, norm_layer=nn.InstanceNorm2d):
        super(UNetGenerator, self).__init__()
        # innermost block
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True, layer=1)
        # middle blocks at the bottleneck width
        for i in range(num_downs - 4):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=unet_block, norm_layer=norm_layer, layer=i + 2)
        # progressively shallower blocks back toward ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block, norm_layer=norm_layer, layer=5)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block, norm_layer=norm_layer, layer=6)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_layer=norm_layer, layer=7)
        # outermost block
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             norm_layer=norm_layer, outermost=True, layer=8)

    def forward(self, x):
        return self.model(x)
if __name__ == "__main__":
    model = UNetGenerator(input_nc=1, output_nc=1)
    dummy = torch.randn(1, 1, 128, 128)
    out = model(dummy)
    print(out.shape)  # (1, 1, 128, 128)
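To sanity-check the savings, you can compare parameter counts between the default width (ngf=64) and the slimmed version (ngf=32); exact numbers depend on num_downs and any optional modules you keep:

# Compare parameter counts of the slim and default-width generators
for width in (32, 64):
    g = UNetGenerator(input_nc=1, output_nc=1, ngf=width)
    n_params = sum(p.numel() for p in g.parameters())
    print(f"ngf={width}: {n_params / 1e6:.2f}M parameters")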