Adding a Self-Attention layer to `UnetSkipConnectionBlock` lets the model capture global context, which can improve the quality of the generated images. The Self-Attention mechanism was first proposed in "Attention Is All You Need" and has since been widely used in image generation tasks (e.g. SAGAN).
Below is the modified `UnetSkipConnectionBlock` with a Self-Attention layer added:
Modified UnetSkipConnectionBlock
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Self-Attention layer."""

    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.in_channels = in_channels
        # 1x1 convolutions producing the query, key and value
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # Learnable scaling parameter for the attention branch
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, height, width = x.size()
        # Compute query, key and value
        query = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, H*W, C//8)
        key = self.key(x).view(batch_size, -1, height * width)      # (B, C//8, H*W)
        value = self.value(x).view(batch_size, -1, height * width)  # (B, C, H*W)

        # Attention scores
        attention = torch.bmm(query, key)         # (B, H*W, H*W)
        attention = F.softmax(attention, dim=-1)  # softmax over the last dimension

        # Weighted sum of the values
        out = torch.bmm(value, attention.permute(0, 2, 1))  # (B, C, H*W)
        out = out.view(batch_size, C, height, width)         # restore the spatial shape

        # Residual connection scaled by gamma
        out = self.gamma * out + x
        return out


class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 outermost=False, innermost=False, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, use_attention=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.use_attention = use_attention  # whether to apply Self-Attention
        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        # Attach a Self-Attention layer when requested
        if use_attention:
            self.attention = SelfAttention(outer_nc)
        else:
            self.attention = None

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            out = self.model(x)
            # Apply Self-Attention if enabled
            if self.use_attention:
                out = self.attention(out)
            return torch.cat([x, out], 1)
```
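As a quick sanity check, here is a minimal sketch (the tensor sizes are illustrative assumptions, not part of the original code) showing that `SelfAttention` preserves the input shape and that a non-outermost block still doubles its channel count through the skip concatenation:

```python
# Shape check for SelfAttention and a single innermost block (sizes are assumptions)
attn = SelfAttention(in_channels=64)
x = torch.randn(2, 64, 32, 32)
print(attn(x).shape)  # torch.Size([2, 64, 32, 32]) -- attention keeps the shape

block = UnetSkipConnectionBlock(outer_nc=64, inner_nc=128, innermost=True, use_attention=True)
y = block(x)
print(y.shape)  # torch.Size([2, 128, 32, 32]) -- 64 skip channels + 64 decoder channels
```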
Explanation of the changes
- Added a `SelfAttention` layer:
  - 1×1 convolutions compute the Query, Key and Value.
  - Attention scores are computed via matrix multiplication followed by a softmax.
  - The attention scores are applied to the Value, and the result is added back through a residual connection scaled by the learnable `gamma` (see the check right after this list).
- Added a `use_attention` parameter to `UnetSkipConnectionBlock`:
  - If `use_attention=True`, a Self-Attention layer is created and applied to the block's decoder output before it is concatenated with the skip connection.
  - If `use_attention=False`, no attention layer is created (`self.attention` stays `None`).
- Self-Attention is applied in the forward pass:
  - For every block except the outermost one, attention is applied to the decoder output just before the skip concatenation.
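Because `gamma` is initialized to zero, the attention branch contributes nothing at the very start of training, so a freshly constructed `SelfAttention` layer behaves as an identity mapping and only gradually learns to mix in global context. A minimal check of this (the input size is an assumption):

```python
# gamma starts at 0, so out = 0 * attention_output + x == x at initialization
attn = SelfAttention(in_channels=32)
x = torch.randn(1, 32, 16, 16)
with torch.no_grad():
    assert torch.allclose(attn(x), x)
```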
Usage example
```python
# Build a U-Net generator with Self-Attention
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, use_attention=False):
        super(UnetGenerator, self).__init__()
        # Construct the U-Net from the innermost block outwards
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, innermost=True,
                                             use_attention=use_attention)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_dropout=use_dropout,
                                                 use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, use_attention=use_attention)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             outermost=True, norm_layer=norm_layer,
                                             use_attention=use_attention)

    def forward(self, input):
        return self.model(input)
```
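A minimal usage sketch, assuming 3-channel images of size 64×64 and `num_downs=6` (these values are assumptions chosen to keep the attention matrices small, not values from the original code):

```python
# Instantiate the generator with Self-Attention enabled and run a dummy forward pass
netG = UnetGenerator(input_nc=3, output_nc=3, num_downs=6, ngf=64,
                     use_dropout=True, use_attention=True)
dummy = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    fake = netG(dummy)
print(fake.shape)  # torch.Size([1, 3, 64, 64])
```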
Summary
- Adding a Self-Attention layer to `UnetSkipConnectionBlock` lets the model capture global context more effectively.
- The `use_attention` parameter controls whether Self-Attention is used in each block (a memory-related variant is sketched after this summary).
- This modification is particularly useful for tasks that demand high-quality image generation (such as font style transfer).
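Note that the attention matrix has shape (H·W)×(H·W), so enabling `use_attention` in the high-resolution outer blocks can be memory-hungry. One possible variant, sketched below under my own assumptions rather than taken from the original design, is to enable attention only in the low-resolution inner blocks (the class name is hypothetical):

```python
# Sketch (an assumption, not part of the original code): enable attention only in the
# low-resolution inner blocks so the (H*W) x (H*W) attention matrices stay small.
class UnetGeneratorInnerAttention(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super().__init__()
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=None,
                                        norm_layer=norm_layer, innermost=True,
                                        use_attention=True)
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout,
                                            use_attention=True)
        # Outer, high-resolution blocks: no attention
        block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=block, norm_layer=norm_layer)
        block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=block, norm_layer=norm_layer)
        block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                             outermost=True, norm_layer=norm_layer)

    def forward(self, x):
        return self.model(x)
```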