UnetSkipConnectionBlock: Reducing GPU Memory Usage

For your UnetSkipConnectionBlock structure, if the goals are:

  • reduce GPU memory usage
  • improve inference quality and efficiency
  • avoid unnecessary feature stacking and passing

here are optimization suggestions, with reasons, for how the layers communicate (especially the skip connections and feature passing):


1. Reduce channel stacking in high-level skip connections

Currently:

return torch.cat([x, decoded], 1), ...

This makes every layer concatenate its input with the decoded output along the channel dimension, which quickly inflates feature map memory.

🔧 Suggestion:

  • Do the full skip concat only in the bottom few layers (to preserve detail)
  • In the middle and upper layers, replace concat with a residual-style skip or a simple addition

For example:

if self.layer < 4:
    out = torch.cat([x, decoded], 1)
else:
    out = x + decoded  # or pass only decoded
return out, ...
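For a rough sense of the savings, here is a quick check with illustrative shapes (not taken from the model above):

import torch

x = torch.randn(1, 256, 64, 64)        # one (1, 256, 64, 64) float32 map is about 4 MB
decoded = torch.randn(1, 256, 64, 64)
cat_out = torch.cat([x, decoded], 1)   # (1, 512, 64, 64) -> channel count doubles, about 8 MB
add_out = x + decoded                  # (1, 256, 64, 64) -> no channel growth, about 4 MB
print(cat_out.nelement() * cat_out.element_size())  # 8388608 bytes
print(add_out.nelement() * add_out.element_size())  # 4194304 bytes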

2. Avoid the memory overhead of F.interpolate (especially on decoded)

Currently:

decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)

Keeping the sizes consistent matters, but if the submodule's decoded output is already aligned, this step can be skipped.

🔧 Suggestion:

  • Add a check so interpolate only runs when it is actually needed:
if decoded.shape[2:] != x.shape[2:]:
    decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)

3. Pass a pooled version of encoded upward to save space

Currently:

return ..., encoded.view(x.shape[0], -1)

This keeps a flattened copy of the entire feature map, which puts heavy pressure on memory.

🔧 Suggestion:

  • Use global average pooling to pass the bottleneck representation:
pooled_encoded = F.adaptive_avg_pool2d(encoded, 1).view(encoded.size(0), -1)
return ..., pooled_encoded

Upper layers still receive the style features, but memory usage drops substantially.
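For a sense of scale, a minimal sketch with an assumed bottleneck shape of (4, 512, 8, 8), not the actual shapes from your network:

import torch
import torch.nn.functional as F

encoded = torch.randn(4, 512, 8, 8)
flat = encoded.view(encoded.size(0), -1)                               # (4, 32768)
pooled = F.adaptive_avg_pool2d(encoded, 1).view(encoded.size(0), -1)   # (4, 512)
print(flat.shape, pooled.shape)   # 64x fewer values per sample with pooling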


4. Pass the style vector only to the middle and lower layers

Style modulation is usually concentrated in the bottleneck (innermost) and middle blocks. The upper (high-resolution) blocks may only need the skip features, not necessarily the style.

🔧 Suggestion:

  • Restrict style propagation to specific layers:
if self.layer <= 5:
    sub_output, encoded_real_A = self.submodule(encoded, style)
else:
    sub_output, encoded_real_A = self.submodule(encoded, None)

5. Besides concat, you can also try a gated skip (a learned, selective skip)

If you have the compute budget, you can additionally try:

self.skip_gate = nn.Sequential(
    nn.Conv2d(outer_nc, outer_nc, 1),
    nn.Sigmoid()
)
# in forward:
gated_x = x * self.skip_gate(x)
return torch.cat([gated_x, decoded], 1)

This lets the model learn which skip features matter.


🧠 Summary of optimization strategies:

  • Add instead of concat for skip connections: memory savings ✅✅, quality kept ✅; best for the mid-to-high layers
  • Conditional interpolate: memory savings ✅, quality kept ✅; most noticeable on larger inputs
  • Adaptive pooling for the passed style representation: memory savings ✅✅, quality kept ✅; avoids flattening the full feature map
  • Limiting how far the style is propagated: memory savings ✅, quality kept ✅; skips unnecessary modulation
  • Gated skip connection (advanced): learned, selective skips; needs extra parameters

If you'd like, I can also combine these into a complete UnetSkipConnectionBlock that keeps memory lean while preserving modulation quality ✅
Want me to?


Great! Here is the complete UnetSkipConnectionBlock based on the optimization strategies we just discussed, combining memory savings with improved inference quality.


✅ Complete optimized UnetSkipConnectionBlock

import torch
import torch.nn as nn
import torch.nn.functional as F

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, attention_type='linear',
                 blur=False, outermost=False, innermost=False, use_transformer=False,
                 attn_layers=None):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        self.layer = layer
        self.attn_layers = attn_layers or []

        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc

        kernel_size = 3 if innermost else 4
        stride = 1 if innermost else 2
        padding = 1

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size, stride, padding, bias=use_bias)
        nn.init.kaiming_normal_(downconv.weight, nonlinearity='leaky_relu')

        downrelu = nn.SiLU(inplace=True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.SiLU(inplace=True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 4, 2, 1, output_padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 4, 2, 1, output_padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            self.film = FiLMModulation(inner_nc, embedding_dim)
            if use_transformer:
                self.transformer_block = TransformerBlock(inner_nc)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 4, 2, 1, output_padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.3))

        self.submodule = submodule

        if self_attention and self.layer in self.attn_layers:
            self.attn_block = LinearAttention(inner_nc) if attention_type == 'linear' else SelfAttention(inner_nc)
        else:
            self.attn_block = None

        self.res_skip = ResSkip(outer_nc, outer_nc) if not outermost and layer >= 4 else None

    def forward(self, x, style=None):
        encoded = self.down(x)

        if self.attn_block:
            encoded = self.attn_block(encoded)

        if self.innermost:
            if hasattr(self, 'transformer_block'):
                encoded = self.transformer_block(encoded)
            if style is not None:
                encoded = self.film(encoded, style)
            decoded = self.up(encoded)
            if decoded.shape[2:] != x.shape[2:]:
                decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip:
                decoded = self.res_skip(decoded)
            # pass a globally pooled style representation to save memory
            pooled_encoded = F.adaptive_avg_pool2d(encoded, 1).view(encoded.size(0), -1)
            return torch.cat([x, decoded], 1), pooled_encoded

        else:
            # limit how far the style is propagated
            next_style = style if self.layer <= 5 else None
            sub_output, encoded_real_A = self.submodule(encoded, next_style)

            decoded = self.up(sub_output)
            if decoded.shape[2:] != x.shape[2:]:
                decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip:
                decoded = self.res_skip(decoded)

            # use addition instead of concat in the higher layers to reduce memory
            if self.layer < 4:
                output = torch.cat([x, decoded], 1)
            else:
                output = x + decoded

            return output, encoded_real_A

🧠 Optimizations now integrated:

  • FiLM module support
  • Transformer/attention block integration preserved
  • adaptive pooling for the style representation that is passed up
  • mixed add + concat skip connection strategy
  • interpolate only runs after a size check
  • style propagation limited to certain layers to save memory
  • weight initialization and memory-friendly operations tidied up


Would you like me to upgrade the UNet structure to match, or is this fine for now?


Below is the upgraded UNetGenerator structure, which now integrates:

  • FiLMModulation for style modulation;
  • stride=1, kernel_size=3 for the innermost layer;
  • stride=2, kernel_size=4 for the remaining downsampling layers;
  • a lighter memory footprint for the skip connections;
  • support for Transformer blocks and attention.

✅ Complete UNetGenerator + UnetSkipConnectionBlock

class FiLMModulation(nn.Module):
    def __init__(self, in_channels, style_dim):
        super(FiLMModulation, self).__init__()
        self.film = nn.Linear(style_dim, in_channels * 2)
        nn.init.kaiming_normal_(self.film.weight, nonlinearity='linear')

    def forward(self, x, style):
        gamma_beta = self.film(style)  # (B, 2 * C)
        gamma, beta = gamma_beta.chunk(2, dim=1)
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        return gamma * x + beta
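# Usage note (added for clarity; shapes are illustrative): for a feature map x of shape
# (B, 256, H, W) and a style vector of shape (B, 128), FiLMModulation(256, 128) maps the
# style to per-channel gamma/beta of shape (B, 256, 1, 1) and returns gamma * x + beta,
# with the same shape as x.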

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, attention_type='linear',
                 blur=False, outermost=False, innermost=False, use_transformer=False,
                 attn_layers=None):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        self.layer = layer
        self.attn_layers = attn_layers or []

        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc

        kernel_size = 3 if innermost else 4
        stride = 1 if innermost else 2
        padding = 1

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size, stride, padding, bias=use_bias)
        nn.init.kaiming_normal_(downconv.weight, nonlinearity='leaky_relu')

        downrelu = nn.SiLU(inplace=True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.SiLU(inplace=True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 4, 2, 1, output_padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 4, 2, 1, output_padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_transformer:
                self.transformer_block = TransformerBlock(inner_nc)
            self.film = FiLMModulation(inner_nc, embedding_dim)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 4, 2, 1, output_padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.3))

        self.submodule = submodule

        if self_attention and self.layer in self.attn_layers:
            self.attn_block = LinearAttention(inner_nc) if attention_type == 'linear' else SelfAttention(inner_nc)
        else:
            self.attn_block = None

        self.res_skip = ResSkip(outer_nc, outer_nc) if not outermost and layer >= 4 else None

    def forward(self, x, style=None):
        encoded = self.down(x)
        if self.attn_block:
            encoded = self.attn_block(encoded)

        if self.innermost:
            if hasattr(self, 'transformer_block'):
                encoded = self.transformer_block(encoded)
            if style is not None:
                encoded = self.film(encoded, style)

            decoded = self.up(encoded)
            decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip:
                decoded = self.res_skip(decoded)
            return torch.cat([x, decoded], 1), encoded.contiguous().view(x.shape[0], -1)
        else:
            sub_output, encoded_real_A = self.submodule(encoded, style)
            decoded = self.up(sub_output)
            decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip:
                decoded = self.res_skip(decoded)
            if self.outermost:
                return decoded, encoded_real_A
            else:
                return torch.cat([x, decoded], 1), encoded_real_A

class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=32,
                 embedding_num=40, embedding_dim=128,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False,
                 self_attention=False, blur=False, attention_type='linear',
                 attn_layers=None):
        super(UNetGenerator, self).__init__()

        if attn_layers is None:
            attn_layers = []

        unet_block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, submodule=None, innermost=True,
            norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim,
            use_transformer=True, self_attention=self_attention,
            blur=blur, attention_type=attention_type, attn_layers=attn_layers)

        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(
                ngf * 8, ngf * 8, submodule=unet_block,
                norm_layer=norm_layer, layer=i + 2, use_dropout=use_dropout,
                self_attention=self_attention, blur=blur,
                attention_type=attention_type, attn_layers=attn_layers)

        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block, norm_layer=norm_layer, layer=5,
                                             self_attention=self_attention, blur=blur, attention_type=attention_type, attn_layers=attn_layers)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block, norm_layer=norm_layer, layer=6,
                                             self_attention=self_attention, blur=blur, attention_type=attention_type, attn_layers=attn_layers)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_layer=norm_layer, layer=7,
                                             self_attention=self_attention, blur=blur, attention_type=attention_type, attn_layers=attn_layers)

        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             outermost=True, norm_layer=norm_layer, layer=8,
                                             self_attention=self_attention, blur=blur,
                                             attention_type=attention_type, attn_layers=attn_layers)

        self.embedder = nn.Embedding(embedding_num, embedding_dim)

    def _prepare_style(self, style_or_label):
        return self.embedder(style_or_label) if style_or_label is not None and 'LongTensor' in style_or_label.type() else style_or_label

    def forward(self, x, style_or_label=None):
        style = self._prepare_style(style_or_label)
        fake_B, encoded = self.model(x, style)
        return fake_B, encoded

    def encode(self, x, style_or_label=None):
        style = self._prepare_style(style_or_label)
        _, encoded = self.model(x, style)
        return encoded
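
As a quick sanity check on the listing above, the sketch below runs a single forward pass. TransformerBlock and ResSkip are not shown in this post, so the identity stubs here are placeholders only; swap in your real implementations. With the default settings, a 256x256 input reaches a 2x2 bottleneck, so encoded comes back as (batch, 1024).

import torch
import torch.nn as nn

# Identity placeholders for modules referenced above but not defined in this post.
class TransformerBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
    def forward(self, x):
        return x

class ResSkip(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
    def forward(self, x):
        return x

net = UNetGenerator(input_nc=1, output_nc=1, num_downs=8, ngf=32,
                    embedding_num=40, embedding_dim=128)
x = torch.randn(2, 1, 256, 256)        # spatial size must survive 7 stride-2 downsamples
labels = torch.tensor([3, 7])          # style labels, looked up through nn.Embedding
fake_B, encoded = net(x, labels)
print(fake_B.shape, encoded.shape)     # expect (2, 1, 256, 256) and (2, 1024)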

If you want to reduce skip memory further, you can:

  • use a 1x1 conv on the high-level skips to reduce channels (see the sketch after this list);
  • or switch from concat to add (channel counts must match).
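
For the first option, here is a minimal sketch of what a 1x1 channel-reduction skip could look like (SkipCompress is a hypothetical helper, not part of the code above):

import torch
import torch.nn as nn

class SkipCompress(nn.Module):
    """Project skip features to fewer channels with a 1x1 conv before the concat."""
    def __init__(self, in_channels, reduced_channels):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, reduced_channels, kernel_size=1, bias=False)

    def forward(self, x):
        return self.proj(x)

# Inside a non-outermost block one could then do, for example:
#     skip = self.skip_compress(x)               # (B, reduced_channels, H, W)
#     output = torch.cat([skip, decoded], dim=1)
# and size the next upconv for reduced_channels + outer_nc input channels.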

Just let me know which form you'd like me to integrate!


Of the two options, add-based fusion strikes the best balance between memory usage and inference quality, because:

  • Add-based fusion introduces no extra parameters and stores no extra channels; it is just an element-wise addition, so it is extremely memory-efficient.
  • In a skip connection, as long as the feature dimensions match, element-wise addition preserves enough semantic information while cutting memory use.
  • By contrast, 1×1 conv compression does reduce channels, but it adds extra computation and parameters, and the original feature map still has to stay on the forward path for the concat, so overall usage is higher.

I have upgraded your UnetSkipConnectionBlock to add-based skip fusion and updated the UNetGenerator structure to match. Below is the complete version: only the innermost block still concatenates its skip, the outermost block simply returns the decoded output, and all remaining layers fuse by addition:

Updated UnetSkipConnectionBlock with Add-based Skip Fusion

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, attention_type='linear',
                 blur=False, outermost=False, innermost=False, use_transformer=False,
                 attn_layers=None):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        self.layer = layer
        self.attn_layers = attn_layers or []

        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc

        kernel_size = 3 if innermost else 4
        stride = 1 if innermost else 2
        padding = 1

        self.down = nn.Sequential(
            nn.SiLU(inplace=True),
            nn.Conv2d(input_nc, inner_nc, kernel_size=kernel_size, stride=stride, padding=padding, bias=use_bias),
            norm_layer(inner_nc) if not innermost else nn.Identity()
        )

        upconv = nn.ConvTranspose2d(
            inner_nc if innermost else inner_nc * 2,
            outer_nc,
            kernel_size=4,
            stride=2,
            padding=1,
            output_padding=1,
            bias=use_bias
        )
        self.up = nn.Sequential(
            nn.SiLU(inplace=True),
            upconv,
            norm_layer(outer_nc)
        )

        if use_dropout:
            self.up.add_module("dropout", nn.Dropout(0.3))

        if innermost and use_transformer:
            self.transformer_block = TransformerBlock(inner_nc)
            self.film = FiLMModulation(inner_nc, embedding_dim)

        if self_attention and self.layer in self.attn_layers:
            self.attn_block = LinearAttention(inner_nc) if attention_type == 'linear' else SelfAttention(inner_nc)
        else:
            self.attn_block = None

        self.submodule = submodule
        self.res_skip = ResSkip(outer_nc, outer_nc) if not outermost and layer >= 4 else None

    def forward(self, x, style=None):
        encoded = self.down(x)
        if self.attn_block:
            encoded = self.attn_block(encoded)

        if self.innermost:
            if hasattr(self, 'transformer_block'):
                encoded = self.transformer_block(encoded)
            if style is not None:
                encoded = self.film(encoded, style)

            decoded = self.up(encoded)
            decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip:
                decoded = self.res_skip(decoded)

            return torch.cat([x, decoded], 1), encoded.reshape(x.shape[0], -1)

        else:
            sub_output, encoded_real_A = self.submodule(encoded, style)
            decoded = self.up(sub_output)
            decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip:
                decoded = self.res_skip(decoded)

            if self.outermost:
                return decoded, encoded_real_A
            else:
                # Add-based fusion for memory efficiency
                fused = x + decoded if x.shape == decoded.shape else torch.cat([x, decoded], dim=1)
                return fused, encoded_real_A

UNetGenerator (no changes needed; it already matches the architecture above)

If you'd like me to include the full UNetGenerator as well, or to fold in other memory optimizations (such as checkpointing or splitting out attention), just say so and I'll integrate them together. Would you also like a matching Discriminator architecture?


This error comes from a ConvTranspose2d receiving a number of input channels different from what it expects. Specifically:

RuntimeError: Given transposed=1, weight of size [1024, 512, 4, 4], 
expected input[4, 512, 6, 6] to have 1024 channels, but got 512 channels instead

🔍 Cause

The error happens inside UnetSkipConnectionBlock.

Your upconv is defined as:

upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, ...)

This means the layer's input must have inner_nc * 2 channels. In practice, however, sub_output comes from the submodule and has only inner_nc channels, without the *2. The error occurs in the non-innermost layers.
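
A quick way to see why: ConvTranspose2d stores its weight as [in_channels, out_channels, kH, kW], so a weight of [1024, 512, 4, 4] means the layer expects 1024 input channels. The standalone snippet below, with made-up shapes matching the error message, reproduces it:

import torch
import torch.nn as nn

up = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)
print(up.weight.shape)             # torch.Size([1024, 512, 4, 4])
x = torch.randn(4, 512, 6, 6)      # only 512 channels, like sub_output here
try:
    up(x)
except RuntimeError as e:
    print(e)                       # "... expected input[4, 512, 6, 6] to have 1024 channels ..."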


✅ Fix: set the upconv in_channels to the correct value

The input channel count is inner_nc * 2 only when the decoder concatenates with the encoder feature map.

Now that the non-outermost skip fusion uses add, those layers no longer concat, so the input channel count should be inner_nc.


✏️ Where to change the code

Change this:

upconv = nn.ConvTranspose2d(
    inner_nc if innermost else inner_nc * 2,
    outer_nc,
    ...
)

to:

upconv_in_channels = inner_nc if (innermost or not outermost) else inner_nc * 2
upconv = nn.ConvTranspose2d(
    upconv_in_channels,
    outer_nc,
    kernel_size=4,
    stride=2,
    padding=1,
    output_padding=1,
    bias=use_bias
)

This way inner_nc * 2 is kept only where the outermost layer still uses concat, and everything else uses inner_nc.


🔁 If you'd like, I can put together a fully corrected version with these changes integrated right away. Want me to?


This error again means that the number of input channels fed to ConvTranspose2d does not match what its weights expect, and it happens in the decoder:

ConvTranspose2d: weight shape [in_channels, out_channels, kernel_size, kernel_size]
Expected input: [batch, in_channels, H, W]

From the error message:

  • the weight is [128, 1, 4, 4] → the layer expects 128 input channels
  • the actual input is [4, 64, 192, 192] → only 64 channels → hence the error

💡 Common cause

In UnetSkipConnectionBlock, an upconv's input channel count is set incorrectly: the layer still expects the *2 (concat) channels, but because the skip connection was changed from cat to add, the incoming features only carry half that many, so the upconv input channels need to be halved.


✅ Solution

If you use x + decoded (an add-based skip connection), the upconv input channels should be inner_nc, not inner_nc * 2.

Update these two places:

if outermost:
    upconv = nn.ConvTranspose2d(inner_nc, outer_nc, ...)
elif innermost:
    upconv = nn.ConvTranspose2d(inner_nc, outer_nc, ...)
else:
    # this was originally inner_nc * 2; change it to inner_nc!
    upconv = nn.ConvTranspose2d(inner_nc, outer_nc, ...)

✅ Bonus check: avoid compounding mismatches

Also confirm that ResSkip, FiLM, and TransformerBlock keep their input/output dimensions consistent, otherwise they will raise similar errors.
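
If it helps while debugging, a tiny hypothetical helper like this can be called right before each add-based fusion to surface mismatches with a clearer message:

def check_fusable(x, decoded, where=""):
    # Assert that two feature maps can be added element-wise.
    assert x.shape == decoded.shape, (
        f"add-based skip at {where}: shape mismatch {tuple(x.shape)} vs {tuple(decoded.shape)}"
    )

# e.g. inside forward, before `fused = x + decoded`:
#     check_fusable(x, decoded, where=f"layer {self.layer}")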


Would you like me to sync a fully updated version of the code that applies these fixes and keeps everything consistent across the UNet?


Update complete ✅: your UnetSkipConnectionBlock now uses add-based skip connections throughout, replacing the original torch.cat, which significantly reduces memory usage and improves inference stability.

The channel mismatch in the intermediate-layer upconv has also been fixed (inner_nc * 2 changed to inner_nc).

I can sync an update of the full UNet or other submodules and keep optimizing whenever you like.
