Zi2ZiModel font style transfer: optimizing the forward method of UNetSkipConnectionBlock

before:

import torch
import torch.nn as nn
import torch.nn.functional as F

# ResSkip and SelfAttention are defined elsewhere in the project.

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, 
                 norm_layer=nn.BatchNorm2d, layer=0, embedding_dim=128, 
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        
        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d  # bias=False when using BatchNorm

        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)  # must be inplace=False here
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.5))

        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None

    def forward(self, x, style=None):
        encoded = self.down(x)

        if self.self_attn:
            encoded = self.self_attn(encoded)

        if self.innermost:
            if style is not None:
                encoded = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encoded], dim=1)
            decoded = self.up(encoded)
            if decoded.shape[2:] != x.shape[2:]:
                decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip:
                decoded = self.res_skip(decoded)
            return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)

        elif self.outermost:
            if self.submodule:
                sub_output, encoded_real_A = self.submodule(encoded, style)
            else:
                sub_output = encoded
            decoded = self.up(sub_output)
            if self.res_skip:
                decoded = self.res_skip(decoded)

            return decoded, encoded_real_A

        else:
            if self.submodule:
                sub_output, encoded_real_A = self.submodule(encoded, style)
            else:
                sub_output = encoded
            decoded = self.up(sub_output)
            if decoded.shape[2:] != x.shape[2:]:
                decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
            if self.res_skip:
                decoded = self.res_skip(decoded)
            
            return torch.cat([x, decoded], 1), encoded_real_A

after, chatgpt:

class UNetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, submodule=None, norm_layer=nn.BatchNorm2d):
        super(UNetSkipConnectionBlock, self).__init__()
        self.submodule = submodule
        self.down = nn.Sequential(
            nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1),
            norm_layer(inner_nc),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # When there is no submodule (innermost block), self.down outputs inner_nc
        # channels; a nested submodule returns inner_nc * 2 channels (its skip concat).
        up_in_channels = inner_nc * 2 if submodule is not None else inner_nc
        self.up = nn.Sequential(
            nn.ConvTranspose2d(up_in_channels, outer_nc, kernel_size=4, stride=2, padding=1),
            norm_layer(outer_nc),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        if self.submodule:
            return torch.cat([self.up(self.submodule(self.down(x))), x], 1)
        else:
            return torch.cat([self.up(self.down(x)), x], 1)

after, gemini:

Alright, let's optimize the forward method of UnetSkipConnectionBlock. The main goals are efficiency and conciseness.

Problem analysis:

  • There is duplicated logic in the code, particularly in the handling of submodule.
  • The decoded.shape[2:] != x.shape[2:] check and the F.interpolate call appear in several places and can be factored out.
  • The logic for returning encoded_real_A can be simplified.

Optimization strategy:

  1. Extract shared logic: move the submodule handling and the F.interpolate size check into separate helper methods (see the sketch after this list).
  2. Simplify the return logic: return encoded_real_A in a uniform way.
  3. Reduce redundant computation: avoid unnecessary shape checks and recalculation.
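
Why the interpolation guard matters: with an odd spatial size, a stride-2 convolution followed by a stride-2 transposed convolution does not land back on the original resolution, so the decoder output must be resized before torch.cat with the skip input. A minimal, self-contained sketch (the 8/16 channel counts and the 65x65 input are made up, chosen only to trigger the mismatch):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 8, 65, 65)                                       # odd spatial size
down = nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1)         # 65 -> 32
up = nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1)  # 32 -> 64

decoded = up(down(x))
print(decoded.shape[2:])  # torch.Size([64, 64]) -- no longer matches 65x65

if decoded.shape[2:] != x.shape[2:]:
    decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
print(decoded.shape[2:])  # torch.Size([65, 65]) -- safe to concatenate with x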

Optimized code:


import torch
import torch.nn as nn
import torch.nn.functional as F

# ... (ResSkip and SelfAttention classes remain unchanged)

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.BatchNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()

        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d

        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.5))

        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None

    def _process_submodule(self, encoded, style):
        if self.submodule:
            sub_output, encoded_real_A = self.submodule(encoded, style)
        else:
            sub_output = encoded
            encoded_real_A = None
        return sub_output, encoded_real_A

    def _interpolate_if_needed(self, decoded, x):
        if decoded.shape[2:] != x.shape[2:]:
            decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
        return decoded

    def forward(self, x, style=None):
        encoded = self.down(x)

        if self.self_attn:
            encoded = self.self_attn(encoded)

        if self.innermost:
            if style is not None:
                encoded = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encoded], dim=1)
            decoded = self.up(encoded)
            decoded = self._interpolate_if_needed(decoded, x)
            if self.res_skip:
                decoded = self.res_skip(decoded)
            return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)

        sub_output, encoded_real_A = self._process_submodule(encoded, style)
        decoded = self.up(sub_output)
        decoded = self._interpolate_if_needed(decoded, x)

        if self.res_skip:
            decoded = self.res_skip(decoded)

        if self.outermost:
            return decoded, encoded_real_A
        else:
            return torch.cat([x, decoded], 1), encoded_real_A

Main improvements:

  • _process_submodule helper: extracts the submodule handling logic, reducing duplicated code.
  • _interpolate_if_needed helper: extracts the F.interpolate check-and-resize logic.
  • Simplified return logic: encoded_real_A is returned in a uniform way.
  • Clearer code structure: factoring out these helpers improves readability and maintainability.

These optimizations improve the efficiency of UnetSkipConnectionBlock and the clarity of its code.
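
For reference, here is a minimal smoke test of the refactored block. The sizes are hypothetical: a 1-channel 4x4 input with only two nesting levels, so the bottleneck reaches 1x1, which is what the style-vector concatenation in the innermost block expects; the actual generator stacks more blocks on larger images.

import torch

style_dim = 128  # matches the default embedding_dim

inner = UnetSkipConnectionBlock(outer_nc=64, inner_nc=128,
                                embedding_dim=style_dim, innermost=True)
outer = UnetSkipConnectionBlock(outer_nc=1, inner_nc=64, input_nc=1,
                                submodule=inner, outermost=True)

x = torch.randn(2, 1, 4, 4)        # two single-channel 4x4 toy "glyphs"
style = torch.randn(2, style_dim)  # one style embedding per sample

fake, encoded_real_A = outer(x, style)
print(fake.shape)            # torch.Size([2, 1, 4, 4])
print(encoded_real_A.shape)  # torch.Size([2, 256]) -- flattened style + bottleneck features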
