Here is the modified UNet code with AdaIN (Adaptive Instance Normalization) added. It is inserted in the middle layers (specifically layer=4, 5, 6) to give finer-grained control over how the style influences the features. SelfAttention is kept as well, so the model can still capture longer-range relationships between style features.
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaIN(nn.Module):
    """Style-conditioned affine modulation. The incoming features are expected to be
    instance-normalized already (the down path of each middle block ends with InstanceNorm2d),
    so this layer only predicts a per-channel scale and shift from the style embedding."""
    def __init__(self, num_features, style_dim):
        super(AdaIN, self).__init__()
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, style):
        # Predict per-channel (gamma, beta) from the style vector and broadcast over H x W.
        style_params = self.fc(style).view(style.shape[0], -1, 1, 1)
        gamma, beta = style_params.chunk(2, dim=1)
        return gamma * x + beta
class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch, C, height, width = x.size()
        # (B, HW, C//8) x (B, C//8, HW) -> (B, HW, HW) attention map over spatial positions.
        proj_query = self.query(x).view(batch, -1, height * width).permute(0, 2, 1)
        proj_key = self.key(x).view(batch, -1, height * width)
        attention = self.softmax(torch.bmm(proj_query, proj_key))
        proj_value = self.value(x).view(batch, -1, height * width)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch, C, height, width)
        # Residual connection keeps the original features alongside the attended ones.
        return out + x
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            # The style embedding is concatenated at the bottleneck, hence the extra embedding_dim channels.
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        if use_dropout:
            self.up.add_module("dropout", nn.Dropout(0.3))
        self.submodule = submodule
        # Self-attention at layers 4 and 6; style modulation (AdaIN) at layers 4-6.
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        self.adain = AdaIN(inner_nc, embedding_dim) if layer in [4, 5, 6] else None

    def _process_submodule(self, encoded, style):
        if self.submodule:
            return self.submodule(encoded, style)
        else:
            return encoded, None

    def _interpolate_if_needed(self, decoded, x):
        return F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False) if decoded.shape[2:] != x.shape[2:] else decoded

    def forward(self, x, style=None):
        encoded = self.down(x)
        if self.self_attn:
            encoded = self.self_attn(encoded)
        if self.adain and style is not None:
            encoded = self.adain(encoded, style)
        if self.innermost:
            if style is not None:
                # Inject the style embedding as extra channels at the bottleneck.
                encoded = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encoded], dim=1)
            decoded = self.up(encoded)
            decoded = self._interpolate_if_needed(decoded, x)
            return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)
        sub_output, encoded_real_A = self._process_submodule(encoded, style)
        decoded = self.up(sub_output)
        decoded = self._interpolate_if_needed(decoded, x)
        if self.outermost:
            return decoded, encoded_real_A
        else:
            return torch.cat([x, decoded], 1), encoded_real_A
class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, embedding_num=40, embedding_dim=128,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False, self_attention=False, blur=False):
        super(UNetGenerator, self).__init__()
        # Build the U-Net from the innermost block outwards; `layer` counts up from the bottleneck.
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=None,
                                             norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur, innermost=True)
        for index in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=unet_block,
                                                 norm_layer=norm_layer, layer=index + 2, use_dropout=use_dropout,
                                                 self_attention=self_attention, blur=blur)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block,
                                             norm_layer=norm_layer, layer=5, self_attention=self_attention, blur=blur)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block,
                                             norm_layer=norm_layer, layer=6, self_attention=self_attention, blur=blur)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block,
                                             norm_layer=norm_layer, layer=7, self_attention=self_attention, blur=blur)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             norm_layer=norm_layer, layer=8, self_attention=self_attention, blur=blur, outermost=True)
        # Maps an integer style label to a style embedding vector.
        self.embedder = nn.Embedding(embedding_num, embedding_dim)

    def forward(self, x, style_or_label=None):
        # Accept either an integer label tensor (embedded here) or a precomputed style vector.
        if style_or_label is not None and 'LongTensor' in style_or_label.type():
            style = self.embedder(style_or_label)
        else:
            style = style_or_label
        return self.model(x, style)
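A minimal usage sketch follows; the 256x256 input size, batch size, and the two style labels are illustrative assumptions, not requirements of the code:

# Quick smoke test for the modified generator defined above.
import torch

model = UNetGenerator(input_nc=1, output_nc=1, num_downs=8, ngf=64,
                      embedding_num=40, embedding_dim=128, self_attention=True)
x = torch.randn(2, 1, 256, 256)                   # two single-channel glyph images
labels = torch.tensor([3, 17], dtype=torch.long)  # one style label per sample
fake_B, encoded_real_A = model(x, labels)
print(fake_B.shape)          # torch.Size([2, 1, 256, 256])
print(encoded_real_A.shape)  # flattened bottleneck features (style embedding + encoder code)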
Main changes:
- Added AdaIN (Adaptive Instance Normalization): style modulation is inserted in the middle blocks (layer 4, 5, 6), so the style embedding directly reshapes the intermediate features instead of only the bottleneck.
- Kept the SelfAttention mechanism: it is applied at layers 4 and 6, so the model can still capture long-range dependencies.
- Extended the UNet architecture: embedding_dim serves as the style dimension fed into AdaIN, making style control more flexible, and the bilinear upsampling in _interpolate_if_needed keeps the style-modulated middle-layer features aligned with the decoder's skip connections as they propagate upward.
This way the UNet learns the style information and applies it effectively to the output image; the same conditioning idea could, with adaptation, also be used in Stable Diffusion or ControlNet style-transfer pipelines. Which part do you think still needs further optimization?
gemini
OK. To add AdaIN across the UnetSkipConnectionBlock layers (especially the middle ones), so that the style does not only affect the low-level features and style control becomes more fine-grained, we need the following changes:
1. Implement the AdaIN layer:
First, we need to implement the AdaIN layer. AdaIN (Adaptive Instance Normalization) adjusts the mean and variance of the features according to the style embedding.
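For reference, the standard AdaIN operation (Huang & Belongie, 2017) is AdaIN(x, s) = sigma(s) * (x - mu(x)) / sigma(x) + mu(s), computed per channel; in the implementation below, mu(s) and sigma(s) are not taken from a style image but predicted from the style embedding by two linear layers (fc_mean, fc_std).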
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaIN(nn.Module):
    def __init__(self, style_dim, num_features):
        super(AdaIN, self).__init__()
        self.style_dim = style_dim
        self.num_features = num_features
        # Two linear heads predict the per-channel mean and std from the style embedding.
        self.fc_mean = nn.Linear(style_dim, num_features)
        self.fc_std = nn.Linear(style_dim, num_features)

    def forward(self, x, style):
        mean = self.fc_mean(style).unsqueeze(-1).unsqueeze(-1)
        std = self.fc_std(style).unsqueeze(-1).unsqueeze(-1)
        # Normalize the content features, then re-scale and re-shift them with the style statistics.
        x = F.instance_norm(x, running_mean=None, running_var=None, use_input_stats=True)
        return std * x + mean
2. Modify UnetSkipConnectionBlock:
In UnetSkipConnectionBlock, we need to add an AdaIN layer and apply it at the appropriate point.
Python
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
        if use_dropout:
            self.up.add_module("dropout", nn.Dropout(0.3))
        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        # ResSkip is assumed to be defined elsewhere in the existing codebase.
        self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None
        # Add the AdaIN layer for the middle blocks. It is applied to the submodule output,
        # i.e. the skip concatenation [encoded, decoded], which has inner_nc * 2 channels.
        if not outermost and not innermost:
            self.adain = AdaIN(embedding_dim, inner_nc * 2)

    def _process_submodule(self, encoded, style):
        if self.submodule:
            return self.submodule(encoded, style)
        else:
            return encoded, None

    def _interpolate_if_needed(self, decoded, x):
        return F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False) if decoded.shape[2:] != x.shape[2:] else decoded

    def forward(self, x, style=None):
        encoded = self.down(x)
        if self.self_attn:
            encoded = self.self_attn(encoded)
        if self.innermost:
            if style is not None:
                encoded = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encoded], dim=1)
            decoded = self.up(encoded)
            decoded = self._interpolate_if_needed(decoded, x)
            if self.res_skip:
                decoded = self.res_skip(decoded)
            return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)
        sub_output, encoded_real_A = self._process_submodule(encoded, style)
        # Apply AdaIN so the style also modulates the mid-level features.
        if not self.outermost and not self.innermost and style is not None:
            sub_output = self.adain(sub_output, style)
        decoded = self.up(sub_output)
        decoded = self._interpolate_if_needed(decoded, x)
        if self.res_skip:
            decoded = self.res_skip(decoded)
        if self.outermost:
            return decoded, encoded_real_A
        else:
            return torch.cat([x, decoded], 1), encoded_real_A
3. Ensure style is passed correctly in UNetGenerator:
Make sure that in UNetGenerator, style is passed down to every UnetSkipConnectionBlock.
Python
class UNetGenerator(nn.Module):
    # ... (rest of the class unchanged)

    def forward(self, x, style_or_label=None):
        style = self._prepare_style(style_or_label)
        fake_B, encoded_real_A = self.model(x, style)  # the style is passed down through the blocks
        return fake_B, encoded_real_A
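_prepare_style is referenced above but not shown; a minimal sketch of what it might look like, assuming it simply mirrors the label-embedding logic of the first version (both the helper and its behavior are assumptions, not part of the original code):

    def _prepare_style(self, style_or_label):
        # Hypothetical helper: embed integer style labels, pass precomputed style vectors through as-is.
        if style_or_label is None:
            return None
        if 'LongTensor' in style_or_label.type():
            return self.embedder(style_or_label)
        return style_or_label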
Summary:
By adding an AdaIN layer to UnetSkipConnectionBlock and applying it in the right place, style control becomes more fine-grained and the style no longer affects only the low-level features. The font style transfer results should come out closer to the desired style as a result.