Modify the UNet to add AdaIN (Adaptive Instance Normalization), inserting it in the middle layers (especially layer = 4, 5, 6) for finer-grained control over the style's influence

Here is the modified UNet code with AdaIN (Adaptive Instance Normalization) added in the middle layers (specifically layer = 4, 5, 6) to give finer-grained control over the style's influence. SelfAttention is kept as well, so the model can still capture longer-range relationships among style features.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaIN(nn.Module):
    """Adaptive Instance Normalization: instance-normalize the content features,
    then re-scale and shift them with per-channel parameters predicted from the style embedding."""
    def __init__(self, num_features, style_dim):
        super(AdaIN, self).__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, style):
        # Predict per-channel scale (gamma) and shift (beta) from the style vector.
        style_params = self.fc(style).view(style.shape[0], -1, 1, 1)
        gamma, beta = style_params.chunk(2, dim=1)
        return gamma * self.norm(x) + beta

class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch, C, height, width = x.size()
        # Flatten the spatial dimensions and attend over all positions.
        proj_query = self.query(x).view(batch, -1, height * width).permute(0, 2, 1)
        proj_key = self.key(x).view(batch, -1, height * width)
        attention = self.softmax(torch.bmm(proj_query, proj_key))
        proj_value = self.value(x).view(batch, -1, height * width)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch, C, height, width)
        return out + x

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()

        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d

        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.3))

        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        self.adain = AdaIN(inner_nc, embedding_dim) if layer in [4, 5, 6] else None

    def _process_submodule(self, encoded, style):
        if self.submodule:
            return self.submodule(encoded, style)
        else:
            return encoded, None

    def _interpolate_if_needed(self, decoded, x):
        return F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False) if decoded.shape[2:] != x.shape[2:] else decoded

    def forward(self, x, style=None):
        encoded = self.down(x)

        if self.self_attn:
            encoded = self.self_attn(encoded)
        if self.adain and style is not None:
            encoded = self.adain(encoded, style)

        if self.innermost:
            # The innermost upconv expects inner_nc + embedding_dim input channels,
            # so a style embedding should always be supplied when this block is used.
            if style is not None:
                encoded = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encoded], dim=1)
            decoded = self.up(encoded)
            decoded = self._interpolate_if_needed(decoded, x)
            return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)

        sub_output, encoded_real_A = self._process_submodule(encoded, style)
        decoded = self.up(sub_output)
        decoded = self._interpolate_if_needed(decoded, x)

        if self.outermost:
            return decoded, encoded_real_A
        else:
            return torch.cat([x, decoded], 1), encoded_real_A

class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, embedding_num=40, embedding_dim=128,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False, self_attention=False, blur=False):
        super(UNetGenerator, self).__init__()

        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=None,
                                             norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur, innermost=True)

        for index in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=unet_block,
                                                 norm_layer=norm_layer, layer=index+2, use_dropout=use_dropout,
                                                 self_attention=self_attention, blur=blur)

        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block,
                                             norm_layer=norm_layer, layer=5, self_attention=self_attention, blur=blur)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block,
                                             norm_layer=norm_layer, layer=6, self_attention=self_attention, blur=blur)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block,
                                             norm_layer=norm_layer, layer=7, self_attention=self_attention, blur=blur)

        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             norm_layer=norm_layer, layer=8, self_attention=self_attention, blur=blur, outermost=True)

        self.embedder = nn.Embedding(embedding_num, embedding_dim)

    def forward(self, x, style_or_label=None):
        style = self.embedder(style_or_label) if style_or_label is not None else None
        return self.model(x, style)

Main changes:

  1. Added AdaIN (Adaptive Instance Normalization)
    • Style modulation is inserted in the middle blocks (layer = 4, 5, 6 of the skip-connection stack).
    • AdaIN lets the model use the style embedding to modulate intermediate features.
  2. Added a SelfAttention mechanism
    • SelfAttention is applied at layer = 4 and 6, so the model can capture long-range dependencies.
  3. Extended the UNet architecture
    • style_dim (set to embedding_dim) is the dimensionality of the style information fed into AdaIN, making style control more flexible.
    • Bilinear interpolation in the decoder keeps skip connections aligned, so style information injected in the middle layers propagates cleanly to the decoding path.

With this, the UNet can learn style information and apply it effectively to the output image. Note that this is a conventional image-to-image UNet generator; reusing the same idea inside a Stable Diffusion or ControlNet UNet would require further adaptation. Which part do you think still needs further optimization?
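As a quick sanity check of the interfaces above, here is a minimal, hypothetical usage sketch. The input resolution (256×256, single channel) and the style labels are arbitrary assumptions for illustration, not values from the original project:

import torch

# Build the generator exactly as defined above.
model = UNetGenerator(input_nc=1, output_nc=1, num_downs=8, ngf=64,
                      embedding_num=40, embedding_dim=128, self_attention=True)

x = torch.randn(2, 1, 256, 256)        # batch of content (source glyph) images
labels = torch.tensor([3, 17])         # integer style labels, looked up by the internal nn.Embedding
fake, encoded = model(x, labels)       # returns (generated images, flattened bottleneck code)
print(fake.shape, encoded.shape)       # expected: torch.Size([2, 1, 256, 256]) and torch.Size([2, 640])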


gemini

OK. To add AdaIN to the middle layers of UnetSkipConnectionBlock, so that the style does not only affect the bottom-level features and style control becomes more fine-grained, we need the following changes:

1. Implement the AdaIN layer:

First, we implement the AdaIN layer. AdaIN (Adaptive Instance Normalization) adjusts the mean and variance of the features according to the style embedding.
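For reference, the operation the code below implements can be written as follows; unlike the classical AdaIN formulation (which takes the target statistics from a style feature map), the scale and shift here are predicted from the style embedding s by two linear layers:

$$\mathrm{AdaIN}(x, s) = \sigma_s \cdot \frac{x - \mu(x)}{\sigma(x)} + \mu_s, \qquad \mu_s = W_{\mu}\, s, \quad \sigma_s = W_{\sigma}\, s$$

where μ(x) and σ(x) are the per-channel mean and standard deviation of x over its spatial dimensions.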

Python

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaIN(nn.Module):
    def __init__(self, style_dim, num_features):
        super(AdaIN, self).__init__()
        self.style_dim = style_dim
        self.num_features = num_features
        self.fc_mean = nn.Linear(style_dim, num_features)
        self.fc_std = nn.Linear(style_dim, num_features)

    def forward(self, x, style):
        mean = self.fc_mean(style).unsqueeze(-1).unsqueeze(-1)
        std = self.fc_std(style).unsqueeze(-1).unsqueeze(-1)
        x = F.instance_norm(x, running_mean=None, running_var=None, use_input_stats=True)
        return std * x + mean

2. Modify UnetSkipConnectionBlock

In UnetSkipConnectionBlock, we add an AdaIN layer and apply it at a suitable point in the forward pass.

Python

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()

        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d

        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.3))

        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        # ResSkip is a residual refinement block assumed to be defined elsewhere in the project.
        self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None

        # Add the AdaIN layer. In forward() it is applied to sub_output, which is the
        # concatenation of this block's skip features and the submodule's output,
        # so it has inner_nc * 2 channels.
        if not outermost and not innermost:
            self.adain = AdaIN(embedding_dim, inner_nc * 2)
        else:
            self.adain = None

    def _process_submodule(self, encoded, style):
        if self.submodule:
            return self.submodule(encoded, style)
        else:
            return encoded, None

    def _interpolate_if_needed(self, decoded, x):
        return F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False) if decoded.shape[2:] != x.shape[2:] else decoded

    def forward(self, x, style=None):
        encoded = self.down(x)

        if self.self_attn:
            encoded = self.self_attn(encoded)

        if self.innermost:
            if style is not None:
                encoded = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encoded], dim=1)
            decoded = self.up(encoded)
            decoded = self._interpolate_if_needed(decoded, x)
            if self.res_skip:
                decoded = self.res_skip(decoded)
            return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)

        sub_output, encoded_real_A = self._process_submodule(encoded, style)

        # Apply AdaIN so the style also modulates the middle blocks, not just the bottleneck.
        if self.adain is not None and style is not None:
            sub_output = self.adain(sub_output, style)

        decoded = self.up(sub_output)
        decoded = self._interpolate_if_needed(decoded, x)

        if self.res_skip:
            decoded = self.res_skip(decoded)

        if self.outermost:
            return decoded, encoded_real_A
        else:
            return torch.cat([x, decoded], 1), encoded_real_A

3. Make sure style is passed correctly in UNetGenerator:

Ensure that in UNetGenerator, style is passed down to every UnetSkipConnectionBlock.

Python

class UNetGenerator(nn.Module):
    # ... (the rest of the class is unchanged)

    def forward(self, x, style_or_label=None):
        style = self._prepare_style(style_or_label)
        fake_B, encoded_real_A = self.model(x, style)  # pass the style vector down through the blocks
        return fake_B, encoded_real_A
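_prepare_style itself is not shown above. A minimal sketch, assuming the nn.Embedding-based embedder from the first version (self.embedder = nn.Embedding(embedding_num, embedding_dim)) and that callers may pass either integer labels or an already-computed style vector, could look like this:

    def _prepare_style(self, style_or_label):
        # Hypothetical helper: turn integer style labels into embeddings,
        # pass through anything that already looks like a style vector.
        if style_or_label is None:
            return None
        if not torch.is_floating_point(style_or_label):
            return self.embedder(style_or_label)
        return style_or_label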

Summary:

By adding an AdaIN layer to UnetSkipConnectionBlock and applying it at a suitable point, style control becomes more fine-grained and the style no longer influences only the bottom-level features. The font style transfer results should therefore come out closer to what we want.
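As a quick, hypothetical smoke test of that claim, one can verify that the generated glyph changes when only the style label changes while the content image stays fixed (shapes and label values are arbitrary assumptions, and the generator is assumed to accept integer labels as in the sketches above):

import torch

model = UNetGenerator(input_nc=1, output_nc=1, num_downs=8, ngf=64,
                      embedding_num=40, embedding_dim=128)
model.eval()

x = torch.randn(1, 1, 256, 256)                  # one content image
with torch.no_grad():
    out_a, _ = model(x, torch.tensor([0]))       # style label 0
    out_b, _ = model(x, torch.tensor([1]))       # style label 1
print((out_a - out_b).abs().mean())              # should be clearly non-zero if the style reaches the features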
