Adding a Self-Attention Layer to UnetSkipConnectionBlock

Adding a Self-Attention layer to UnetSkipConnectionBlock helps the model capture global context information, which can improve the quality of the generated images. The self-attention mechanism was first proposed in Attention Is All You Need and has since been widely adopted in image generation tasks (e.g., SAGAN).

Below is the modified UnetSkipConnectionBlock with a Self-Attention layer added:


Modified UnetSkipConnectionBlock

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """ Self-Attention 層 """
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.in_channels = in_channels
        # 1x1 convolutions that project the input to query, key, and value
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # Learnable scale for the attention output (initialized to 0, so the
        # layer starts out as an identity mapping)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, height, width = x.size()
        # Compute the query, key, and value projections
        query = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, H*W, C//8)
        key = self.key(x).view(batch_size, -1, height * width)  # (B, C//8, H*W)
        value = self.value(x).view(batch_size, -1, height * width)  # (B, C, H*W)

        # Attention scores between all spatial positions
        attention = torch.bmm(query, key)  # (B, H*W, H*W)
        attention = F.softmax(attention, dim=-1)  # softmax over the last dimension

        # Weighted sum of the values
        out = torch.bmm(value, attention.permute(0, 2, 1))  # (B, C, H*W)
        out = out.view(batch_size, C, height, width)  # restore the spatial shape

        # Residual connection scaled by the learnable gamma
        out = self.gamma * out + x
        return out


class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, use_attention=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.use_attention = use_attention  # whether to apply Self-Attention to this block's output
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        # Optionally create a Self-Attention layer for the block's output channels
        if use_attention:
            self.attention = SelfAttention(outer_nc)
        else:
            self.attention = None

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            out = self.model(x)
            # Apply Self-Attention to the block output before the skip concatenation
            if self.use_attention:
                out = self.attention(out)
            return torch.cat([x, out], 1)
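
A quick sanity check of the SelfAttention layer on its own (hypothetical channel count and feature-map size, assuming the classes above have been defined): because gamma is initialized to zero, the layer starts out as an identity mapping, and the attention branch is blended in only as gamma is learned.

# Hypothetical sizes: 64 channels, 16x16 feature map
attn = SelfAttention(64)
x = torch.randn(2, 64, 16, 16)
y = attn(x)
print(y.shape)            # torch.Size([2, 64, 16, 16]) -- shape is preserved
print(torch.equal(y, x))  # True at initialization, since gamma starts at 0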

Summary of Changes

  1. New SelfAttention layer
    • 1×1 convolutions compute the query, key, and value.
    • Attention scores are obtained via a matrix multiplication followed by a softmax.
    • The scores are applied to the values, and the result is added back to the input through a learnable residual connection.
  2. New use_attention parameter on UnetSkipConnectionBlock
    • When use_attention=True, a SelfAttention layer is created for the block's output channels (outer_nc).
  3. Self-Attention applied in the forward pass
    • For every block except the outermost one, Self-Attention is applied to the block's output before it is concatenated with the skip input (see the shape-check sketch after this list).
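
To make points 2 and 3 concrete, here is a minimal shape check on a standalone innermost block (hypothetical channel sizes, assuming the code above has been run): the attention layer acts on the outer_nc-channel output, and the skip concatenation then doubles the channel count.

# Hypothetical sizes: outer_nc=128, inner_nc=256, 32x32 input feature map
block = UnetSkipConnectionBlock(outer_nc=128, inner_nc=256,
                                innermost=True, use_attention=True)
x = torch.randn(1, 128, 32, 32)
y = block(x)
print(y.shape)  # torch.Size([1, 256, 32, 32]) -- 128 skip channels + 128 attended channels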

Usage Example

# U-Net generator built from nested UnetSkipConnectionBlocks, with optional Self-Attention
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, use_attention=False):
        super(UnetGenerator, self).__init__()

        # Build the U-Net from the innermost block outward
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, use_attention=use_attention)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_attention=use_attention)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer, use_attention=use_attention)

    def forward(self, input):
        return self.model(input)
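
A minimal end-to-end sketch (hypothetical configuration, assuming all of the code above has been run):

# Hypothetical setup: RGB input/output, 64x64 images, num_downs=6 so the
# bottleneck feature map is 1x1
net = UnetGenerator(input_nc=3, output_nc=3, num_downs=6, use_attention=True)
x = torch.randn(1, 3, 64, 64)
y = net(x)
print(y.shape)  # torch.Size([1, 3, 64, 64])

One practical caveat: the attention map has (H*W) x (H*W) entries, so enabling use_attention on blocks that operate on high-resolution feature maps (for example the ngf-channel block of a 256x256 generator) can be very memory-intensive; it is common to enable attention only in the coarser, inner blocks. Note also that, as written, the outermost block creates an attention module when use_attention=True but never applies it, because its forward pass returns self.model(x) directly.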

Summary

  • Adding a Self-Attention layer to UnetSkipConnectionBlock helps the model capture global context information.
  • The use_attention parameter controls whether Self-Attention is applied.
  • This modification is especially useful for tasks that require high-quality image generation, such as font style transfer.
