1. Understand UnetSkipConnectionBlock
UnetSkipConnectionBlock is the module in UNetGenerator that handles skip connections. It takes the output of the previous layer together with the skip-connection input, concatenates them, processes the result, and passes it on to the next layer.
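As a concrete illustration of the concatenation step, here is a minimal sketch; the tensor names and shapes are invented for the example and are not taken from the original code:

```python
import torch

# A decoder feature map and the matching encoder (skip-connection) feature map.
# Shapes are illustrative: (batch, channels, height, width).
decoder_feat = torch.randn(1, 128, 64, 64)
skip_feat = torch.randn(1, 128, 64, 64)

# Skip connection: concatenate along the channel dimension; the block's
# following convolution then reduces the doubled channel count back down.
fused = torch.cat([skip_feat, decoder_feat], dim=1)
print(fused.shape)  # torch.Size([1, 256, 64, 64])
```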
2. Decide where to place the self-attention layer
In UnetSkipConnectionBlock, the self-attention layer can be placed after the concatenation step, i.e. where the merged features go through self.up_conv in the code below. This lets the model capture long-range dependencies after the skip-connection information has been fused in.
3. Modify UnetSkipConnectionBlock
Open models/networks.py and find the definition of UnetSkipConnectionBlock. In __init__, add the initialization of the self-attention layer (and apply it in forward, as shown below):
```python
# These imports are normally already at the top of models/networks.py.
import torch
import torch.nn as nn


class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_channels, inner_channels, input_channels=None,
                 submodule=None, outermost=False, inner_most=False,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outer_most = outermost
        self.inner_most = inner_most  # needed by forward()
        if input_channels is None:
            input_channels = outer_channels
        downconv = nn.Conv2d(input_channels, inner_channels,
                             kernel_size=4, stride=2, padding=1)
        downnorm = norm_layer(inner_channels)
        downrelu = nn.ReLU(True)
        uprelu = nn.ReLU(True)
        if outermost:
            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            # forward() concatenates the skip input before this conv, so its input
            # is inner_channels * 2 (this assumes input_channels == inner_channels;
            # see the note in step 4).
            upconv = nn.Conv2d(inner_channels * 2, outer_channels,
                               kernel_size=1, stride=1)
            upnorm = norm_layer(outer_channels)
            self.attn = SelfAttention(outer_channels)  # initialize the self-attention layer
        elif inner_most:
            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            # forward() also concatenates here, so the input must be
            # inner_channels * 2 rather than inner_channels.
            upconv = nn.Conv2d(inner_channels * 2, outer_channels,
                               kernel_size=1, stride=1)
            upnorm = norm_layer(outer_channels)
        else:
            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            upconv = nn.Conv2d(inner_channels * 2, outer_channels,
                               kernel_size=1, stride=1)
            upnorm = norm_layer(outer_channels)
            self.attn = SelfAttention(outer_channels)  # initialize the self-attention layer
        self.down_conv = nn.Sequential(downconv, downnorm, downrelu)
        self.up_conv = nn.Sequential(upconv, upnorm, uprelu)
        self.down_sample = submodule
        self.up_sample = upsample

    def forward(self, x):
        if self.outer_most:
            x_down = self.down_conv(x)
            if self.down_sample is not None:
                x_down = self.down_sample(x_down)  # recurse into the nested U-Net submodule
            x_up = self.up_sample(x_down)
            x_out = torch.cat([x, x_up], 1)
            x_out = self.up_conv(x_out)
            x_out = self.attn(x_out)  # apply the self-attention layer
            return x_out
        elif self.inner_most:
            x_down = self.down_conv(x)
            x_up = self.up_sample(x_down)
            x_out = torch.cat([x, x_up], 1)
            x_out = self.up_conv(x_out)
            return x_out
        else:
            x_down = self.down_conv(x)
            if self.down_sample is not None:
                x_down = self.down_sample(x_down)  # recurse into the nested U-Net submodule
            x_up = self.up_sample(x_down)
            x_out = torch.cat([x, x_up], 1)
            x_out = self.up_conv(x_out)
            x_out = self.attn(x_out)  # apply the self-attention layer
            return x_out
```
4. Notes
- Make sure the SelfAttention class definition has been added to models/networks.py (a sketch is given right after this list).
- The channel counts in UnetSkipConnectionBlock need to be adjusted carefully so that the input and output channel counts of the self-attention layer match the features it receives.
- After adding the self-attention layer, you may also need to adjust training parameters such as the learning rate to get the best results.
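The guide relies on a SelfAttention class but does not show it. Below is a minimal sketch of a SAGAN-style self-attention module that could be added to models/networks.py; the class name matches the code above, but the internal design (1x1 query/key/value convolutions, a channel-reduction factor of 8, and a learnable gamma residual weight) is an assumption, not something specified in this guide.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """SAGAN-style self-attention over the spatial positions of a feature map."""

    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        # 1x1 convolutions produce query/key/value maps; the //8 reduction is a
        # common choice, and max(..., 1) guards against very small channel counts
        # (e.g. the outermost block, whose output may have only 1-3 channels).
        reduced = max(in_channels // 8, 1)
        self.query = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.key = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # residual weight, starts at 0

    def forward(self, x):
        b, c, h, w = x.size()
        q = self.query(x).view(b, -1, h * w).permute(0, 2, 1)  # B x HW x C'
        k = self.key(x).view(b, -1, h * w)                      # B x C' x HW
        attn = F.softmax(torch.bmm(q, k), dim=-1)               # B x HW x HW
        v = self.value(x).view(b, -1, h * w)                    # B x C x HW
        out = torch.bmm(v, attn.permute(0, 2, 1)).view(b, c, h, w)
        return self.gamma * out + x                             # residual connection
```

Keep in mind that the HW x HW attention map grows quadratically with the spatial size, so applying self-attention at the outermost (full-resolution) layer can be memory-hungry; many implementations only attach it to the deeper, lower-resolution blocks.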
5. Test
After completing the changes above, retrain your model and check whether the quality of the generated glyphs improves.
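Before launching a full training run, a quick forward-pass check can catch channel mismatches early. The sketch below is only illustrative: it assumes the SelfAttention class from step 4 is defined, and the channel sizes (128) are made up; match them to how your UnetGenerator actually builds its blocks.

```python
import torch

# Wrap an innermost block inside one intermediate block and push a dummy
# batch through it; a shape mismatch will surface here instead of mid-training.
inner = UnetSkipConnectionBlock(outer_channels=128, inner_channels=128,
                                inner_most=True)
middle = UnetSkipConnectionBlock(outer_channels=128, inner_channels=128,
                                 submodule=inner)

x = torch.randn(1, 128, 64, 64)  # (batch, channels, height, width)
y = middle(x)
print(y.shape)                   # expected: torch.Size([1, 128, 64, 64])
```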