This code is a font style transfer model based on a U-Net architecture, extended with AdaIN (Adaptive Instance Normalization) to inject the style vector and with an enhanced skip connection (ResSkip) to preserve high-resolution detail. Below is a concrete implementation plan for the requested changes:
Improvement plan
- Add an AdaIN module
  - AdaIN applies the style vector to feature maps at different layers, giving finer-grained control over the style transfer.
  - Inside the U-Net's UnetSkipConnectionBlock, apply AdaIN at the innermost and intermediate layers to modulate the style features.
- Add ResSkip to better preserve high-resolution features
  - The enhanced skip connection (ResSkip) retains detail features and passes them directly to the high-resolution layers, improving glyph detail.
  - Add ResSkip in the forward pass of UnetSkipConnectionBlock.
- Optimize the Self-Attention mechanism
  - Self-Attention is currently applied only at layer=4,6; instead, whether to apply it can be decided from the feature-map resolution (see the sketch after this list).
  - Let Self-Attention act where the receptive field is large, further strengthening style consistency.
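A minimal sketch of that resolution-based gating (an assumption, not part of the original code; the size thresholds are illustrative, and `attn` stands for the SelfAttention module defined later in this post):

```python
import torch
import torch.nn as nn

def maybe_apply_attention(attn: nn.Module, x: torch.Tensor,
                          min_size: int = 8, max_size: int = 64) -> torch.Tensor:
    """Apply self-attention only on mid-resolution feature maps, where the
    receptive field is large enough to matter but the (H*W) x (H*W) attention
    map is still cheap to compute."""
    if attn is not None and min_size <= min(x.shape[-2:]) <= max_size:
        return attn(x)
    return x
```

Inside UnetSkipConnectionBlock.forward this would replace the hard-coded `layer in [4, 6]` check.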
Core code after the changes
1. Add an AdaIN layer
```python
class AdaptiveInstanceNorm2d(nn.Module):
    def __init__(self, num_features, style_dim):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.style_scale = nn.Linear(style_dim, num_features)
        self.style_bias = nn.Linear(style_dim, num_features)

    def forward(self, x, style):
        gamma = self.style_scale(style).view(style.shape[0], -1, 1, 1)
        beta = self.style_bias(style).view(style.shape[0], -1, 1, 1)
        return gamma * self.norm(x) + beta
```
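A quick shape check (the channel and feature-map sizes here are illustrative assumptions):

```python
import torch

adain = AdaptiveInstanceNorm2d(num_features=256, style_dim=128)
feat = torch.randn(4, 256, 16, 16)   # a batch of encoder feature maps
style = torch.randn(4, 128)          # one style vector per sample
out = adain(feat, style)
print(out.shape)                     # torch.Size([4, 256, 16, 16]); only the statistics change
```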
2. Add AdaIN to the U-Net
```python
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.BatchNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d
        self.use_adain = not outermost  # only apply AdaIN in the inner layers
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)
        if innermost:
            self.adain = AdaptiveInstanceNorm2d(inner_nc, embedding_dim)  # add AdaIN
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if self.use_adain:
                self.adain = AdaptiveInstanceNorm2d(inner_nc, embedding_dim)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.5))
        self.submodule = submodule
        if self_attention and layer in [4, 6]:
            self.self_attn = SelfAttention(inner_nc)
        else:
            self.self_attn = None

    def forward(self, x, style=None):
        encoded = self.down(x)
        if self.self_attn is not None:
            encoded = self.self_attn(encoded)
        if self.use_adain and style is not None:
            encoded = self.adain(encoded, style)
        if self.submodule is not None:
            sub_output, encoded_real_A = self.submodule(encoded, style)
        else:
            # innermost block: the bottleneck features double as the content encoding
            sub_output, encoded_real_A = encoded, encoded
        decoded = self.up(sub_output)
        if self.outermost:
            return decoded, encoded_real_A
        return torch.cat([x, decoded], 1), encoded_real_A
```
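The forward pass expects a style vector of size embedding_dim (128 here), but the post does not show where it comes from. One possible source, sketched purely as an assumption, is a small CNN that encodes a reference glyph of the target font (e.g. Zen Maru Gothic):

```python
import torch
import torch.nn as nn

class StyleEncoder(nn.Module):
    """Hypothetical encoder: reference glyph image -> 128-dim style vector."""
    def __init__(self, in_channels=1, style_dim=128):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, stride=2, padding=1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.LeakyReLU(0.2, True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(128, style_dim)

    def forward(self, ref_glyph):
        return self.fc(self.features(ref_glyph).flatten(1))

style = StyleEncoder()(torch.randn(4, 1, 64, 64))   # (4, 128), fed to every AdaIN layer
```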
3. Add ResSkip
```python
class ResSkip(nn.Module):
    def __init__(self, channels):
        super(ResSkip, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        return out + residual  # preserve high-resolution features
```
Then register the module in UnetSkipConnectionBlock.__init__ (e.g. `self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None`) and apply it in the U-Net's forward method:

```python
if not self.outermost and not self.innermost:
    decoded = self.res_skip(decoded)
```
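A quick check that ResSkip keeps the spatial resolution of the decoder features (channel and size values are illustrative):

```python
import torch

skip = ResSkip(64)
decoded = torch.randn(2, 64, 128, 128)   # decoder output at an intermediate layer
print(skip(decoded).shape)               # torch.Size([2, 64, 128, 128]); resolution unchanged
```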
Optimization points
- More flexible style transfer
  - AdaIN lets the model apply the style transformation at different layers, making the font conversion smoother.
  - It avoids the loss of style information caused by plain BatchNorm.
- High-resolution feature preservation
  - ResSkip keeps fine detail in the high-resolution layers, reducing glyph distortion.
- Better use of Self-Attention
  - Self-Attention is applied only where the receptive field has the most impact.
These improvements should noticeably raise the quality of the Noto Sans → Zen Maru Gothic font style transfer! 🚀
Below is the PyTorch model code with AdaIN (Adaptive Instance Normalization), ResSkip (a residual skip connection that strengthens high-resolution feature preservation) and Self-Attention, designed for font style transfer (Noto Sans → Zen Maru Gothic). The model is based on a U-Net, with AdaIN for style modulation and ResSkip to better preserve high-resolution features. Only the generator is defined here; a minimal training sketch follows the code.
Key changes:
- AdaIN (Adaptive Instance Normalization)
  - Lets the style vector influence the feature maps at different layers, making the style transfer more flexible.
- ResSkip (Residual Skip Connection)
  - Uses a residual connection to retain more high-resolution features, sharpening glyph detail.
- Self-Attention
  - Lets the model capture long-range dependencies, improving the accuracy of glyph details.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Adaptive Instance Normalization (AdaIN)
class AdaptiveInstanceNorm2d(nn.Module):
    def __init__(self, num_features, style_dim):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.style_scale = nn.Linear(style_dim, num_features)
        self.style_bias = nn.Linear(style_dim, num_features)

    def forward(self, x, style):
        gamma = self.style_scale(style).view(style.shape[0], -1, 1, 1)
        beta = self.style_bias(style).view(style.shape[0], -1, 1, 1)
        return gamma * self.norm(x) + beta

# Residual Skip Connection
class ResSkip(nn.Module):
    def __init__(self, channels):
        super(ResSkip, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        return out + residual

# Self-Attention Module
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, C, H, W = x.size()
        query = self.query(x).view(batch, -1, H * W).permute(0, 2, 1)
        key = self.key(x).view(batch, -1, H * W)
        attn_map = torch.bmm(query, key)
        attn_map = F.softmax(attn_map, dim=-1)
        value = self.value(x).view(batch, -1, H * W)
        out = torch.bmm(value, attn_map.permute(0, 2, 1)).view(batch, C, H, W)
        return self.gamma * out + x
# UNet Block with AdaIN & Self-Attention
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.BatchNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        self.use_adain = not outermost  # Only use AdaIN inside the network
        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=True)
        upnorm = norm_layer(outer_nc)
        if innermost:
            self.adain = AdaptiveInstanceNorm2d(inner_nc, embedding_dim)
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if self.use_adain:
                self.adain = AdaptiveInstanceNorm2d(inner_nc, embedding_dim)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.5))
        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None

    def forward(self, x, style=None):
        encoded = self.down(x)
        if self.self_attn:
            encoded = self.self_attn(encoded)
        if self.use_adain and style is not None:
            encoded = self.adain(encoded, style)
        if self.submodule:
            sub_output, _ = self.submodule(encoded, style)
        else:
            sub_output = encoded
        decoded = self.up(sub_output)
        if self.res_skip:
            decoded = self.res_skip(decoded)
        return (torch.cat([x, decoded], 1) if not self.outermost else decoded), None
# Full U-Net Model
class UNetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, embedding_dim=128):
        super(UNetGenerator, self).__init__()
        unet_block = UnetSkipConnectionBlock(64, 64, layer=num_downs - 1, embedding_dim=embedding_dim, innermost=True)
        for i in range(num_downs - 2, 0, -1):
            unet_block = UnetSkipConnectionBlock(64, 64, submodule=unet_block, layer=i, embedding_dim=embedding_dim)
        self.model = UnetSkipConnectionBlock(output_nc, 64, input_nc=input_nc, submodule=unet_block,
                                             embedding_dim=embedding_dim, outermost=True)

    def forward(self, x, style):
        return self.model(x, style)[0]

# Instantiate Model
input_nc, output_nc, num_downs, embedding_dim = 1, 1, 7, 128  # For grayscale fonts
net = UNetGenerator(input_nc, output_nc, num_downs, embedding_dim)

# Print Model Summary
print(net)
```
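The code above only defines and prints the generator. A minimal training sketch, under the assumption that a data loader yields (Noto Sans glyph, Zen Maru Gothic glyph, style vector) triples and that a plain L1 reconstruction loss is enough to illustrate the loop:

```python
import torch
import torch.nn as nn

def train_one_epoch(net, loader, optimizer, device="cuda"):
    """One pass over hypothetical (source, target, style) batches."""
    net.train()
    l1 = nn.L1Loss()
    for source, target, style in loader:      # Noto Sans input, Zen Maru Gothic target, 128-dim style
        source, target, style = source.to(device), target.to(device), style.to(device)
        fake = net(source, style)             # U-Net forward pass with AdaIN style modulation
        loss = l1(fake, target)               # pixel-wise reconstruction loss (GAN losses omitted)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()

# Example wiring (assumes `loader` exists):
# optimizer = torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))
# for epoch in range(100):
#     train_one_epoch(net.to("cuda"), loader, optimizer)
```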
A standalone version of the building blocks, ending with a quick shape test:

```python
import torch
import torch.nn as nn
class AdaptiveInstanceNorm2d(nn.Module):
    def __init__(self, num_features, style_dim):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.style_scale = nn.Linear(style_dim, num_features)
        self.style_bias = nn.Linear(style_dim, num_features)

    def forward(self, x, style):
        gamma = self.style_scale(style).view(style.shape[0], -1, 1, 1)
        beta = self.style_bias(style).view(style.shape[0], -1, 1, 1)
        return gamma * self.norm(x) + beta

class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.key = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.value = nn.Conv2d(in_dim, in_dim, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch, C, width, height = x.size()
        proj_query = self.query(x).view(batch, -1, width * height).permute(0, 2, 1)
        proj_key = self.key(x).view(batch, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value(x).view(batch, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch, C, width, height)
        return out + x

class ResSkip(nn.Module):
    def __init__(self, channels):
        super(ResSkip, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return x + self.relu(self.conv(x))
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.BatchNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        self.use_adain = not outermost  # Only use AdaIN inside the network
        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=True)
        upnorm = norm_layer(outer_nc)
        if innermost:
            self.adain = AdaptiveInstanceNorm2d(inner_nc, embedding_dim)
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            if self.use_adain:
                self.adain = AdaptiveInstanceNorm2d(inner_nc, embedding_dim)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.5))
        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None

    def forward(self, x, style=None):
        encoded = self.down(x)
        if self.self_attn:
            encoded = self.self_attn(encoded)
        if self.use_adain and style is not None:
            encoded = self.adain(encoded, style)
        if self.submodule:
            sub_output, _ = self.submodule(encoded, style)
        else:
            sub_output = encoded
        decoded = self.up(sub_output)
        if self.res_skip:
            decoded = self.res_skip(decoded)
        return (torch.cat([x, decoded], 1) if not self.outermost else decoded), None
# Test the UNet block
if __name__ == "__main__":
    x = torch.randn(1, 3, 256, 256)   # dummy input image
    style = torch.randn(1, 128)       # dummy style vector
    # An outermost block needs a submodule, otherwise its upconv expects
    # inner_nc * 2 channels; wrap an innermost block to form a minimal 2-level U-Net.
    inner_block = UnetSkipConnectionBlock(outer_nc=64, inner_nc=128,
                                          embedding_dim=128, innermost=True)
    model = UnetSkipConnectionBlock(outer_nc=3, inner_nc=64, input_nc=3,
                                    submodule=inner_block, embedding_dim=128, outermost=True)
    output, _ = model(x, style)
    print("Output shape:", output.shape)  # expected: torch.Size([1, 3, 256, 256])
```