Adding self-attention to the UnetSkipConnectionBlock in zi2zi-pytorch

1. Understanding UnetSkipConnectionBlock

UnetSkipConnectionBlock is the module in UNetGenerator that handles skip connections. It takes the output of the previous layer and the skip-connection input, concatenates them, processes the result, and passes it on to the next layer.
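
Concretely, UNetGenerator builds the U-Net by nesting these blocks, with each block wrapping the next deeper block as its submodule. A rough sketch with illustrative channel sizes (the real depth and widths come from UNetGenerator, not from these numbers), using the constructor shown in step 3:

Python

# Illustrative nesting only; actual channel counts come from UNetGenerator.
innermost = UnetSkipConnectionBlock(256, 512, inner_most=True)
middle = UnetSkipConnectionBlock(128, 256, submodule=innermost)
outermost = UnetSkipConnectionBlock(1, 128, input_channels=1,
                                    submodule=middle, outermost=True)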

2. Deciding where to place the self-attention layer

In UnetSkipConnectionBlock, a natural place for the self-attention layer is right after the skip connection has been concatenated and fused by self.up_conv. Applying attention to the fused feature map lets the model better capture long-range dependencies once the skip-connection information has been merged in.
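
If models/networks.py does not already define a SelfAttention module, a minimal SAGAN-style sketch such as the following can serve; the class name, the channels // 8 reduction, and the learnable gamma gate are conventional choices rather than anything mandated by zi2zi-pytorch, so adapt it to whatever attention module you prefer.

Python

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Minimal SAGAN-style self-attention; output has the same shape as the input."""

    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        # Reduce query/key channels; max(..., 1) keeps the layer valid for very
        # small channel counts (e.g. the outermost block of the generator).
        inter_channels = max(in_channels // 8, 1)
        self.query = nn.Conv2d(in_channels, inter_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, inter_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # starts as an identity mapping
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        b, c, h, w = x.size()
        q = self.query(x).view(b, -1, h * w).permute(0, 2, 1)  # (B, HW, C')
        k = self.key(x).view(b, -1, h * w)                      # (B, C', HW)
        attn = self.softmax(torch.bmm(q, k))                    # (B, HW, HW)
        v = self.value(x).view(b, -1, h * w)                    # (B, C, HW)
        out = torch.bmm(v, attn.permute(0, 2, 1)).view(b, c, h, w)
        return self.gamma * out + x  # residual: attention is blended in gradually

Because the output shape matches the input shape, this module can be dropped into the block without changing any surrounding channel counts.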

3. Modifying UnetSkipConnectionBlock

Open models/networks.py and find the definition of UnetSkipConnectionBlock. Initialize the self-attention layer in __init__ and apply it in forward:

Python

import torch
import torch.nn as nn

# SelfAttention is the attention module defined elsewhere in models/networks.py
# (see the sketch in step 2).


class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_channels, inner_channels, input_channels=None,
                 submodule=None, outermost=False, inner_most=False,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outer_most = outermost
        self.inner_most = inner_most  # forward() needs this flag as well

        if input_channels is None:
            input_channels = outer_channels
        downconv = nn.Conv2d(input_channels, inner_channels,
                             kernel_size=4, stride=2, padding=1)
        downnorm = norm_layer(inner_channels)
        downrelu = nn.ReLU(True)
        uprelu = nn.ReLU(True)

        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # forward() concatenates the block input (input_channels) with the
        # upsampled features (inner_channels), so upconv takes their sum.
        upconv = nn.Conv2d(input_channels + inner_channels, outer_channels,
                           kernel_size=1, stride=1)
        upnorm = norm_layer(outer_channels)

        if not inner_most:
            # Initialize the self-attention layer. It runs on the fused
            # (post-skip-connection) features, which have outer_channels channels.
            self.attn = SelfAttention(outer_channels)

        self.down_conv = nn.Sequential(downconv, downnorm, downrelu)
        self.up_conv = nn.Sequential(upconv, upnorm, uprelu)
        self.down_sample = submodule  # nested block; None for the innermost block
        self.up_sample = upsample

    def forward(self, x):
        x_down = self.down_conv(x)
        if not self.inner_most:
            # Recurse into the nested UnetSkipConnectionBlock before upsampling.
            x_down = self.down_sample(x_down)
        x_up = self.up_sample(x_down)
        x_out = torch.cat([x, x_up], 1)  # skip connection with the block input
        x_out = self.up_conv(x_out)
        if not self.inner_most:
            # Apply the self-attention layer to the fused features. In the
            # outermost block this runs at full image resolution, which can be
            # memory-intensive for large glyph images.
            x_out = self.attn(x_out)
        return x_out

4. Notes

  • Make sure a SelfAttention class definition has been added to models/networks.py (for example, the SAGAN-style sketch in step 2).
  • The channel counts in UnetSkipConnectionBlock need to be adjusted carefully so that the input and output channel counts of the self-attention layer match the surrounding layers.
  • After adding the self-attention layers, you may need to retune training hyperparameters such as the learning rate to get the best results; see the sketch below for one way to do this.
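
For the learning-rate point, one option (purely a sketch; netG stands for your generator instance, and the values are placeholders to tune) is to give the newly added attention parameters their own parameter group:

Python

import torch.optim as optim

# Split the generator parameters so the new attention layers can be trained
# with a different learning rate from the pre-existing layers.
attn_params = [p for n, p in netG.named_parameters() if 'attn' in n]
base_params = [p for n, p in netG.named_parameters() if 'attn' not in n]

optimizer_G = optim.Adam([
    {'params': base_params, 'lr': 1e-4},
    {'params': attn_params, 'lr': 1e-5},  # placeholder values; tune for your data
], betas=(0.5, 0.999))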

5. Testing

After making the changes above, retrain your model and check whether the quality of the generated glyphs improves.
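
Before committing to a full retraining run, it can help to push a dummy batch through a small nest of the modified blocks to confirm that channel counts and spatial sizes still line up (the channel sizes below are illustrative, not zi2zi-pytorch's real configuration):

Python

import torch

# Assumes the modified UnetSkipConnectionBlock and SelfAttention are defined.
# Build a small three-level U-Net and run a dummy batch through it; the output
# should have the same shape as the input.
innermost = UnetSkipConnectionBlock(256, 512, inner_most=True)
middle = UnetSkipConnectionBlock(128, 256, submodule=innermost)
outermost = UnetSkipConnectionBlock(1, 128, input_channels=1,
                                    submodule=middle, outermost=True)

dummy = torch.randn(1, 1, 64, 64)  # (batch, channels, height, width)
print(outermost(dummy).shape)      # expected: torch.Size([1, 1, 64, 64])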
