Adding Residual Blocks to the Generator

Below is the complete UNetGenerator Python script with Residual Blocks, intended for Noto Sans → Zen Maru Gothic font style transfer.

Main Changes

  1. Add Residual Blocks
    • A ResNet-style skip connection is added to the encoder (downsampling) part of UnetSkipConnectionBlock to improve feature propagation.
    • Residual connections help gradients flow through deep networks and prevent information loss (see the toy gradient check after this list).
  2. Improve the UNet architecture
    • Residual Blocks are used in the middle layers (layers 2-4 and layer 6 in the code below).
    • Self-Attention is added selectively to strengthen the learning of style details.
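A quick toy check (not from the original script) of why the identity path matters: with y = x + F(x), the gradient at x always contains an identity term, so it does not vanish even when the learned branch F contributes almost nothing.

import torch

# y = x + F(x): the identity path alone carries the gradient back to x.
x = torch.randn(4, requires_grad=True)
f_x = 0.0 * x            # pretend the learned branch F(x) contributes (almost) nothing
y = (x + f_x).sum()
y.backward()
print(x.grad)            # tensor([1., 1., 1., 1.])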

Complete UNet Generator (with Residual Blocks)

import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------
# Residual Block
# ------------------------
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm1 = nn.InstanceNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm2 = nn.InstanceNorm2d(in_channels)

    def forward(self, x):
        identity = x  # keep the original input
        out = self.conv1(x)
        out = self.norm1(out)
        out = F.relu(out, inplace=True)
        out = self.conv2(out)
        out = self.norm2(out)
        return out + identity  # residual connection

# ------------------------
# Self-Attention Module
# ------------------------
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, height, width = x.size()
        proj_query = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, N, C//8)
        proj_key = self.key(x).view(batch_size, -1, height * width)  # (B, C//8, N)
        energy = torch.bmm(proj_query, proj_key)  # (B, N, N)
        attention = F.softmax(energy, dim=-1)

        proj_value = self.value(x).view(batch_size, -1, height * width)  # (B, C, N)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)
        out = out.view(batch_size, C, height, width)

        out = self.gamma * out + x  # residual connection (gamma starts at 0)
        return out

# ------------------------
# UNet Skip Connection Block
# ------------------------
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, 
                 use_attention=False, use_residual=False, norm_layer=nn.InstanceNorm2d, layer=0):
        super(UnetSkipConnectionBlock, self).__init__()
        outermost = (layer == 8)
        innermost = (layer == 1)
        self.outermost = outermost
        self.innermost = innermost
        self.layer = layer
        self.use_attention = use_attention
        self.use_residual = use_residual

        use_bias = norm_layer == nn.InstanceNorm2d  # use bias only with InstanceNorm2d (it has no learnable affine by default)
        self.norm_layer = norm_layer

        if input_nc is None:
            input_nc = outer_nc

        # Downsampling
        self.down = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        # Skip the norm at the outermost layer and at the innermost (bottleneck) layer:
        # InstanceNorm over a 1x1 feature map normalizes a single value to zero
        # (and recent PyTorch versions raise an error for it).
        self.down_norm = norm_layer(inner_nc) if not (outermost or innermost) else nn.Identity()

        # Upsampling
        if outermost:
            self.up = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            self.up_norm = nn.Identity()
        elif innermost:
            self.up = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.up_norm = norm_layer(outer_nc)
        else:
            self.up = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.up_norm = norm_layer(outer_nc)

        # Optional Residual Block on the downsampled features
        if use_residual and not outermost:
            self.res_block = ResidualBlock(inner_nc)

        if use_attention:
            self.attn = SelfAttention(inner_nc)

        self.submodule = submodule

    def forward(self, x):
        down_x = self.down(x)
        if not self.outermost:
            down_x = self.down_norm(down_x)
            down_x = F.leaky_relu(down_x, 0.2, inplace=True)

        # Apply the Residual Block
        if self.use_residual and not self.outermost:
            down_x = self.res_block(down_x)

        # Apply Self-Attention
        if self.use_attention and not self.outermost:
            down_x = self.attn(down_x)

        if self.submodule is not None:
            down_x = self.submodule(down_x)

        up_x = self.up(down_x)
        up_x = self.up_norm(up_x)

        if self.outermost:
            return up_x
        else:
            up_x = F.relu(up_x, inplace=True)
            return torch.cat([up_x, x], dim=1)

# ------------------------
# UNet Generator
# ------------------------
class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, 
                 use_attention=True, use_residual=True, norm_layer=nn.InstanceNorm2d):
        super(UNetGenerator, self).__init__()

        # Innermost layer (bottleneck)
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, layer=1)

        # Middle layers (layers 2-4, assuming num_downs=8): residual blocks here, self-attention only at layer 4
        for index in range(num_downs - 5):
            loop_use_attention = use_attention and (index + 2 == 4)
            loop_use_residual = use_residual
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=unet_block, 
                                                 norm_layer=norm_layer, use_attention=loop_use_attention, 
                                                 use_residual=loop_use_residual, layer=index + 2)

        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block, norm_layer=norm_layer, layer=5)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block, norm_layer=norm_layer, 
                                             use_attention=use_attention, use_residual=use_residual, layer=6)
        unet_block = UnetSkipConnectionBlock(ngf * 1, ngf * 2, submodule=unet_block, norm_layer=norm_layer, layer=7)

        # Outermost layer
        self.model = UnetSkipConnectionBlock(output_nc, ngf * 1, input_nc=input_nc, submodule=unet_block, 
                                             norm_layer=norm_layer, layer=8)

    def forward(self, x):
        output = self.model(x)
        output = torch.tanh(output)  # Tanh scales the output to [-1, 1]
        return output
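A minimal smoke test (not part of the original post), assuming 256×256 grayscale glyph images as in a typical pix2pix-style setup: the generator should return a tensor of the same shape, with values in [-1, 1].

import torch

# Dummy batch of two 256x256 grayscale glyphs.
netG = UNetGenerator(input_nc=1, output_nc=1, num_downs=8, ngf=64,
                     use_attention=True, use_residual=True)
dummy = torch.randn(2, 1, 256, 256)          # (batch, channels, height, width)
with torch.no_grad():
    fake = netG(dummy)
print(fake.shape)                            # torch.Size([2, 1, 256, 256])
print(fake.min().item(), fake.max().item())  # values lie in [-1, 1] because of the final tanh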

Improvements

Adds Residual Blocks to reduce vanishing gradients
Adds Self-Attention selectively in the encoder layers (see the parameter-count comparison below)
More stable gradient flow, giving better style transfer results 🚀

This UNet structure is well suited to font style transfer and learns Zen Maru Gothic's smooth stroke style more effectively!
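To make the "selectively add" point concrete, here is a small optional comparison (not from the original post) of parameter counts with and without the Residual Blocks and Self-Attention; the exact numbers depend on ngf.

def count_params(model):
    # Total number of parameters in the model.
    return sum(p.numel() for p in model.parameters())

base = UNetGenerator(use_attention=False, use_residual=False)
full = UNetGenerator(use_attention=True, use_residual=True)
print(f"plain UNet:             {count_params(base):,} parameters")
print(f"+ residual + attention: {count_params(full):,} parameters")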
