Zi2ZiModel font style transfer with AdaIN (Adaptive Instance Normalization) for style modulation and ResSkip skip connections to preserve high-resolution detail

This code is a font style transfer model built on the U-Net architecture, extended with AdaIN (Adaptive Instance Normalization) to modulate the style vector and with an enhanced skip connection (ResSkip) to preserve high-resolution detail. Below is a concrete implementation plan for the requested changes:

Improvement Plan

  1. Add an AdaIN module
    • AdaIN is used for style transfer: it applies the style vector to the feature maps of different layers, allowing finer-grained style adjustment.
    • Inside the U-Net's UnetSkipConnectionBlock, apply AdaIN at the innermost (bottleneck) layer to modulate the style features.
  2. Add ResSkip to better preserve high-resolution features
    • Strengthen the skip connections with ResSkip so that features are preserved and passed directly to the higher-resolution layers, improving glyph detail.
    • Add ResSkip in the forward method of UnetSkipConnectionBlock.
  3. Optimize the Self-Attention mechanism
    • Self-Attention is currently applied only at layer=4, 6; whether to apply it could instead be decided from the feature-map resolution (see the sketch after this list).
    • Let Self-Attention act where the receptive field is large, to further strengthen style consistency.
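As a rough illustration of item 3, here is a minimal sketch of gating Self-Attention by feature-map resolution instead of by layer index. The wrapper class and the 16-64 pixel window are illustrative assumptions, not part of the model code below.

import torch
import torch.nn as nn

class ResolutionGatedAttention(nn.Module):
    # Hypothetical wrapper: runs the wrapped attention module only when the
    # feature map falls inside an assumed spatial-size window.
    def __init__(self, attn: nn.Module, min_size: int = 16, max_size: int = 64):
        super().__init__()
        self.attn = attn
        self.min_size = min_size
        self.max_size = max_size

    def forward(self, x):
        h, w = x.shape[-2:]
        if self.min_size <= min(h, w) <= self.max_size:
            return self.attn(x)  # resolution where attention is worth its quadratic cost
        return x                 # otherwise pass the features through unchanged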

Core Code After the Changes

1. Add an AdaIN layer

class AdaptiveInstanceNorm2d(nn.Module):
    def __init__(self, num_features, style_dim):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.style_scale = nn.Linear(style_dim, num_features)
        self.style_bias = nn.Linear(style_dim, num_features)

    def forward(self, x, style):
        gamma = self.style_scale(style).view(style.shape[0], -1, 1, 1)
        beta = self.style_bias(style).view(style.shape[0], -1, 1, 1)
        return gamma * self.norm(x) + beta
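
A minimal usage sketch for the module above (tensor sizes are illustrative assumptions): the style vector is projected to one scale and one bias per channel, which then modulate the instance-normalized feature map.

import torch

adain = AdaptiveInstanceNorm2d(num_features=256, style_dim=128)
feat = torch.randn(4, 256, 16, 16)   # (batch, channels, H, W) feature map
style = torch.randn(4, 128)          # one 128-dim style embedding per sample
out = adain(feat, style)
print(out.shape)                     # torch.Size([4, 256, 16, 16]) -- shape is unchanged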

2. Add AdaIN to the U-Net

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, 
                 norm_layer=nn.BatchNorm2d, layer=0, embedding_dim=128, 
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        
        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d
        self.use_adain = not outermost  # only apply AdaIN in the inner (non-outermost) layers

        if input_nc is None:
            input_nc = outer_nc
        
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)  
        upnorm = norm_layer(outer_nc)

        if innermost:
            self.adain = AdaptiveInstanceNorm2d(inner_nc, embedding_dim)  # apply AdaIN at the bottleneck
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)

        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)

            if self.use_adain:
                self.adain = AdaptiveInstanceNorm2d(inner_nc, embedding_dim)

            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.5))

        self.submodule = submodule

        if self_attention and layer in [4, 6]:
            self.self_attn = SelfAttention(inner_nc)
        else:
            self.self_attn = None

    def forward(self, x, style=None):
        encoded = self.down(x)

        if self.self_attn is not None:
            encoded = self.self_attn(encoded)

        if self.use_adain and style is not None:
            encoded = self.adain(encoded, style)

        if self.submodule is not None:
            sub_output, encoded_real_A = self.submodule(encoded, style)
        else:
            # innermost block: its own encoding is the bottleneck feature passed back up
            sub_output = encoded
            encoded_real_A = encoded

        decoded = self.up(sub_output)

        return (torch.cat([x, decoded], 1) if not self.outermost else decoded), encoded_real_A

3. Add ResSkip

class ResSkip(nn.Module):
    def __init__(self, channels):
        super(ResSkip, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        return out + residual  # preserve high-resolution features
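
A quick sanity check for ResSkip (sizes are illustrative): the residual block keeps both the channel count and the spatial resolution unchanged, so it can be inserted into a skip path without affecting the surrounding shapes.

import torch

res = ResSkip(channels=64)
x = torch.randn(1, 64, 128, 128)     # decoder feature map at an intermediate resolution
y = res(x)
print(y.shape == x.shape)            # True -- features are refined, not resized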

Create the module in the block's __init__ (self.res_skip = ResSkip(outer_nc) for the intermediate blocks, as in the full script below), then add the following to the forward of UnetSkipConnectionBlock:

if not self.outermost and not self.innermost:
    decoded = self.res_skip(decoded)

Optimization Points

  1. More flexible style transfer
    • AdaIN lets the model apply the style transformation at different layers, making the font conversion smoother.
    • It avoids the loss of style information caused by plain BatchNorm.
  2. High-resolution feature preservation
    • ResSkip lets fine details be retained in the high-resolution layers, reducing glyph distortion.
  3. Better use of Self-Attention
    • Self-Attention is applied only where its receptive field has a large impact.

These improvements should noticeably raise the quality of the Noto Sans → Zen Maru Gothic font style transfer. 🚀

Below is the PyTorch model-definition script, which includes AdaIN (Adaptive Instance Normalization), ResSkip (a skip-connection enhancement that preserves high-resolution features), and Self-Attention, designed for font style transfer (Noto Sans → Zen Maru Gothic). The model is based on U-Net, with AdaIN handling style modulation and ResSkip improving the retention of high-resolution detail.


Main changes:

  1. AdaIN (Adaptive Instance Normalization)
    • Lets the style vector influence the feature maps of different layers, making the style transfer more flexible.
  2. ResSkip (Residual Skip Connection)
    • Uses residual connections to strengthen the preservation of high-resolution features, sharpening glyph detail.
  3. Self-Attention
    • Lets the model learn long-range dependencies, improving the accuracy of glyph details.


import torch
import torch.nn as nn
import torch.nn.functional as F

# Adaptive Instance Normalization (AdaIN)
class AdaptiveInstanceNorm2d(nn.Module):
    def __init__(self, num_features, style_dim):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.style_scale = nn.Linear(style_dim, num_features)
        self.style_bias = nn.Linear(style_dim, num_features)

    def forward(self, x, style):
        gamma = self.style_scale(style).view(style.shape[0], -1, 1, 1)
        beta = self.style_bias(style).view(style.shape[0], -1, 1, 1)
        return gamma * self.norm(x) + beta

# Residual Skip Connection
class ResSkip(nn.Module):
    def __init__(self, channels):
        super(ResSkip, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        return out + residual

# Self-Attention Module
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, C, H, W = x.size()
        query = self.query(x).view(batch, -1, H * W).permute(0, 2, 1)
        key = self.key(x).view(batch, -1, H * W)
        attn_map = torch.bmm(query, key)
        attn_map = F.softmax(attn_map, dim=-1)
        value = self.value(x).view(batch, -1, H * W)
        out = torch.bmm(value, attn_map.permute(0, 2, 1)).view(batch, C, H, W)
        return self.gamma * out + x
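
# Note: the attention matrix above has shape (batch, H*W, H*W), so its size grows
# very quickly with spatial resolution -- one practical reason to keep this block on
# the deeper, lower-resolution layers (4 and 6). Quick shape check with illustrative
# sizes (a 16x16 map already yields a 256x256 attention matrix per sample):
_attn_check = SelfAttention(in_channels=256)
print("SelfAttention check:", _attn_check(torch.randn(2, 256, 16, 16)).shape)  # torch.Size([2, 256, 16, 16])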

# UNet Block with AdaIN & Self-Attention
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, 
                 norm_layer=nn.BatchNorm2d, layer=0, embedding_dim=128, 
                 use_dropout=False, self_attention=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        self.use_adain = not outermost  # Only use AdaIN inside the network
        use_bias = norm_layer != nn.BatchNorm2d

        if input_nc is None:
            input_nc = outer_nc
        
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=True)
        upnorm = norm_layer(outer_nc)

        if innermost:
            self.adain = AdaptiveInstanceNorm2d(inner_nc, embedding_dim)
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            
            if self.use_adain:
                self.adain = AdaptiveInstanceNorm2d(inner_nc, embedding_dim)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.5))

        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None

    def forward(self, x, style=None):
        encoded = self.down(x)
        if self.self_attn:
            encoded = self.self_attn(encoded)
        if self.use_adain and style is not None:
            encoded = self.adain(encoded, style)
        
        if self.submodule:
            sub_output, _ = self.submodule(encoded, style)
        else:
            sub_output = encoded
        
        decoded = self.up(sub_output)
        if self.res_skip:
            decoded = self.res_skip(decoded)
        
        return (torch.cat([x, decoded], 1) if not self.outermost else decoded), None

# Full U-Net Model
class UNetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, embedding_dim=128):
        super(UNetGenerator, self).__init__()
        # innermost (bottleneck) block
        unet_block = UnetSkipConnectionBlock(64, 64, layer=num_downs - 1, innermost=True,
                                             embedding_dim=embedding_dim, self_attention=True)
        # intermediate blocks; Self-Attention is only instantiated where layer is 4 or 6
        for i in range(num_downs - 2, 0, -1):
            unet_block = UnetSkipConnectionBlock(64, 64, submodule=unet_block, layer=i,
                                                 embedding_dim=embedding_dim, self_attention=True)
        # outermost block maps the input image in and the output image out
        self.model = UnetSkipConnectionBlock(output_nc, 64, input_nc=input_nc, submodule=unet_block,
                                             outermost=True, embedding_dim=embedding_dim)
    
    def forward(self, x, style):
        return self.model(x, style)[0]

# Instantiate Model
input_nc, output_nc, num_downs, embedding_dim = 1, 1, 7, 128  # For grayscale fonts
net = UNetGenerator(input_nc, output_nc, num_downs, embedding_dim)

# Print Model Summary
print(net)
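
# Quick forward-pass check (illustrative sizes, assuming 256x256 grayscale glyphs
# and a 128-dimensional style vector):
x = torch.randn(1, 1, 256, 256)   # grayscale glyph image
style = torch.randn(1, 128)       # style embedding
print("Generator output:", net(x, style).shape)  # expected: torch.Size([1, 1, 256, 256])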

Simplified variants and a quick test

The SelfAttention and ResSkip variants below are lighter alternatives to the modules defined above: this SelfAttention omits the learnable gamma scale, and this ResSkip uses a single convolution. A short shape test of the UNet block follows.

class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.key = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.value = nn.Conv2d(in_dim, in_dim, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch, C, width, height = x.size()
        proj_query = self.query(x).view(batch, -1, width * height).permute(0, 2, 1)
        proj_key = self.key(x).view(batch, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value(x).view(batch, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch, C, width, height)
        return out + x

class ResSkip(nn.Module):
    def __init__(self, channels):
        super(ResSkip, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return x + self.relu(self.conv(x))

# Quick test of the UNet block
if __name__ == "__main__":
    x = torch.randn(1, 3, 256, 256)       # dummy input image
    style = torch.randn(1, 128)           # dummy style vector
    # an innermost block is nested inside an outermost block so that the
    # channel counts of the up-convolutions line up (inner_nc * 2)
    inner_block = UnetSkipConnectionBlock(outer_nc=128, inner_nc=256,
                                          embedding_dim=128, innermost=True)
    model = UnetSkipConnectionBlock(outer_nc=64, inner_nc=128, input_nc=3,
                                    embedding_dim=128, submodule=inner_block,
                                    outermost=True)
    output, _ = model(x, style)
    print("Output shape:", output.shape)  # expected: torch.Size([1, 64, 256, 256])
