Zi2ZiModel font style transfer: improving the SelfAttention computation
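The change in both revisions below is the same idea: divide the raw attention energies by a scale factor before the softmax, following the standard scaled dot-product attention formulation

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V

where d_k is the dimension of the query/key projections. The two versions differ only in which value they use for d_k.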

Before:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        # 1x1 convolutions project the feature map into query/key/value spaces;
        # query and key are reduced to channels // 8 to save memory.
        self.query = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)
        # Learnable residual weight, initialized to zero so the block starts as an identity.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, H, W = x.shape
        proj_query = self.query(x).view(B, -1, H * W).permute(0, 2, 1)  # (B, N, C // 8), N = H * W
        proj_key = self.key(x).view(B, -1, H * W)  # (B, C // 8, N)
        energy = torch.bmm(proj_query, proj_key)  # (B, N, N), unscaled dot products
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value(x).view(B, -1, H * W)  # (B, C, N)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(B, C, H, W)
        return self.gamma * out + x
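The weak point here is that energy goes into the softmax unscaled. The query/key dot products grow in magnitude with the projection dimension (here C // 8), which can push the softmax into a saturated region with small gradients. Both revisions below divide by a scale factor first.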

After (ChatGPT version):

class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # PyTorch tensors are (B, C, H, W): height comes before width.
        batch_size, C, height, width = x.size()
        proj_query = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        proj_key = self.key(x).view(batch_size, -1, height * width)
        # Scaled dot-product attention; the scale here is sqrt(C), the full channel count.
        attention = F.softmax(torch.bmm(proj_query, proj_key) / (C ** 0.5), dim=-1)
        proj_value = self.value(x).view(batch_size, -1, height * width)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(batch_size, C, height, width)
        return self.gamma * out + x
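One thing to note about this version: it scales by sqrt(C), the full channel count, even though the query and key projections only have C // 8 channels. The conventional choice is sqrt(d_k) with d_k equal to the actual query/key dimension, which is what the Gemini version below uses.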

After (Gemini version):

class SelfAttention(nn.Module):
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, H, W = x.shape
        proj_query = self.query(x).view(B, -1, H * W).permute(0, 2, 1)  # (B, N, C')
        proj_key = self.key(x).view(B, -1, H * W)  # (B, C', N)
        energy = torch.bmm(proj_query, proj_key)  # (B, N, N)
        scale = (C // 8) ** -0.5  # scaling factor: 1 / sqrt(d_k), with d_k = C // 8
        attention = F.softmax(energy * scale, dim=-1)  # (B, N, N)
        proj_value = self.value(x).view(B, -1, H * W)  # (B, C, N)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(B, C, H, W)  # (B, C, H, W)
        return self.gamma * out + x
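A quick smoke test runs against either version above (a minimal sketch; the channel count, batch size, and spatial size are arbitrary, and channels must be at least 8 so that channels // 8 is nonzero):

# Smoke test: the block preserves the input shape, and because gamma is
# initialized to zero it starts out as an exact identity mapping.
attn = SelfAttention(64)
x = torch.randn(2, 64, 32, 32)
out = attn(x)
assert out.shape == x.shape
assert torch.allclose(out, x)  # gamma == 0, so out == 0 * attn_out + x == x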
