Before:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, H, W = x.shape
        proj_query = self.query(x).view(B, -1, H * W).permute(0, 2, 1)
        proj_key = self.key(x).view(B, -1, H * W)
        energy = torch.bmm(proj_query, proj_key)  # unscaled dot-product logits
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value(x).view(B, -1, H * W)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(B, C, H, W)
        return self.gamma * out + x
After (ChatGPT):
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, width, height = x.size()
        proj_query = self.query(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key(x).view(batch_size, -1, width * height)
        # logits scaled by sqrt(C), i.e. the full input channel count
        attention = F.softmax(torch.bmm(proj_query, proj_key) / (C ** 0.5), dim=-1)
        proj_value = self.value(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(batch_size, C, width, height)
        return self.gamma * out + x
After (Gemini version):
class SelfAttention(nn.Module):
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, H, W = x.shape
        proj_query = self.query(x).view(B, -1, H * W).permute(0, 2, 1)  # B, N, C'
        proj_key = self.key(x).view(B, -1, H * W)                       # B, C', N
        energy = torch.bmm(proj_query, proj_key)                        # B, N, N
        scale = (C // 8) ** -0.5  # scaling factor: 1 / sqrt(query/key channel dim)
        attention = F.softmax(energy * scale, dim=-1)                   # B, N, N
        proj_value = self.value(x).view(B, -1, H * W)                   # B, C, N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(B, C, H, W)  # B, C, H, W
        return self.gamma * out + x
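
A minimal usage sketch, assuming the imports from the first snippet and any of the versions above; the channel count (64), batch size, and 32x32 spatial size are arbitrary values chosen here for illustration:

# Hypothetical smoke test: feed a random NCHW tensor through the block.
attn = SelfAttention(64)
x = torch.randn(2, 64, 32, 32)
y = attn(x)
print(y.shape)  # torch.Size([2, 64, 32, 32]) -- residual output keeps the input shape

Because gamma is initialized to zero, all three variants start out as an identity mapping and only gradually mix in the attention output during training.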