Adding a Self-Attention layer to UnetSkipConnectionBlock allows the model to capture global context better, which can improve the quality of the generated images. The Self-Attention mechanism was first proposed in "Attention Is All You Need" and has since been widely used in image generation tasks (e.g., SAGAN).
Below is the modified UnetSkipConnectionBlock with a Self-Attention layer added:
Modified UnetSkipConnectionBlock
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """ Self-Attention layer (SAGAN-style) """
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.in_channels = in_channels
        # 1x1 convolutions producing the query, key, and value projections
        # (guard against very small channel counts so the reduced dimension is never zero)
        inter_channels = max(in_channels // 8, 1)
        self.query = nn.Conv2d(in_channels, inter_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, inter_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # Learnable scaling parameter for the attention branch, initialized to zero
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, height, width = x.size()
        # Compute query, key, and value
        query = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, H*W, C//8)
        key = self.key(x).view(batch_size, -1, height * width)                        # (B, C//8, H*W)
        value = self.value(x).view(batch_size, -1, height * width)                    # (B, C, H*W)
        # Attention scores
        attention = torch.bmm(query, key)         # (B, H*W, H*W)
        attention = F.softmax(attention, dim=-1)  # softmax over the last dimension
        # Weighted sum of the values
        out = torch.bmm(value, attention.permute(0, 2, 1))  # (B, C, H*W)
        out = out.view(batch_size, C, height, width)        # restore spatial shape
        # Residual connection
        out = self.gamma * out + x
        return out
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, use_attention=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # The outermost block returns its output directly, so attention only
        # makes sense for inner blocks.
        self.use_attention = use_attention and not outermost
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up
        # Attach a Self-Attention layer when requested
        if self.use_attention:
            self.attention = SelfAttention(outer_nc)
        else:
            self.attention = None
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            out = self.model(x)
            # Apply Self-Attention to the block output before the skip concatenation
            if self.use_attention:
                out = self.attention(out)
            return torch.cat([x, out], 1)
```
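Because gamma is initialized to zero, the attention branch starts out as an identity mapping and only contributes once training moves gamma away from zero. A minimal sanity-check sketch of the layer (the tensor sizes here are arbitrary choices for illustration, not values from the network above):

```python
# Quick sanity check of the SelfAttention layer (illustrative sizes only)
attn = SelfAttention(in_channels=64)
x = torch.randn(2, 64, 16, 16)   # (B, C, H, W)
y = attn(x)

print(y.shape)               # torch.Size([2, 64, 16, 16]) -- shape is preserved
print(torch.allclose(y, x))  # True: gamma starts at 0, so the layer is initially an identity
```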
Explanation of the changes
- New SelfAttention layer:
  - 1×1 convolutions compute the query, key, and value projections.
  - Attention scores are obtained with a batched matrix multiplication followed by a softmax.
  - The scores weight the values, and the result is added back to the input through a residual connection scaled by the learnable gamma.
- New use_attention parameter on UnetSkipConnectionBlock:
  - When use_attention=True, the block attaches a SelfAttention layer (the outermost block skips it, since it returns its output directly). See the sketch after this list.
- Applying Self-Attention in the forward pass:
  - For every block except the outermost one, attention is applied to the block's output before it is concatenated with the skip input.
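As a concrete illustration of the flag and the skip concatenation, here is a small sketch using a single intermediate block that wraps an innermost block; the channel and spatial sizes are assumptions made for this example only:

```python
# One intermediate block with attention, wrapping an innermost block (illustrative sizes)
inner = UnetSkipConnectionBlock(outer_nc=128, inner_nc=256, innermost=True)
block = UnetSkipConnectionBlock(outer_nc=64, inner_nc=128, submodule=inner, use_attention=True)

x = torch.randn(1, 64, 32, 32)   # input_nc defaults to outer_nc (64)
y = block(x)

# torch.cat([x, out], 1): 64 skip channels + 64 processed channels = 128
print(y.shape)                   # torch.Size([1, 128, 32, 32])
```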
Usage example
```python
# Build a U-Net generator with Self-Attention enabled
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, use_attention=False):
        super(UnetGenerator, self).__init__()
        # Build the U-Net from the innermost block outwards
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, use_attention=use_attention)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_attention=use_attention)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer, use_attention=use_attention)

    def forward(self, input):
        return self.model(input)
```
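A hedged end-to-end sketch: the channel counts, image size, and num_downs below are assumptions for illustration (num_downs=7 so a 128×128 input is downsampled to a 1×1 bottleneck), not values fixed by the code above:

```python
# Instantiate the generator with attention enabled (assumed 128x128, RGB-to-RGB setup)
netG = UnetGenerator(input_nc=3, output_nc=3, num_downs=7, ngf=64,
                     use_dropout=True, use_attention=True)

fake = netG(torch.randn(1, 3, 128, 128))
print(fake.shape)                # torch.Size([1, 3, 128, 128])
```

Note that the attention score matrix is (H·W)×(H·W), so enabling use_attention on the outer, high-resolution blocks dominates the memory cost; if that becomes a problem, one option is to pass use_attention=True only when constructing the deeper, low-resolution blocks.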
Summary
- Adding a Self-Attention layer to UnetSkipConnectionBlock lets the model capture global context information better.
- The use_attention parameter controls whether Self-Attention is used.
- This modification is particularly suited to tasks that demand high-quality image generation, such as font style transfer.