before:
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
norm_layer=nn.BatchNorm2d, layer=0, embedding_dim=128,
use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
self.innermost = innermost
use_bias = norm_layer != nn.BatchNorm2d # 若使用 BatchNorm,則 bias=False
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(inplace=False) # 這裡必須是 False
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
self.down = nn.Sequential(downconv)
self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
self.down = nn.Sequential(downrelu, downconv)
self.up = nn.Sequential(uprelu, upconv, upnorm)
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
self.down = nn.Sequential(downrelu, downconv, downnorm)
self.up = nn.Sequential(uprelu, upconv, upnorm)
if use_dropout:
self.up.add_module("dropout", nn.Dropout(0.5))
self.submodule = submodule
self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None
def forward(self, x, style=None):
encoded = self.down(x)
if self.self_attn:
encoded = self.self_attn(encoded)
if self.innermost:
if style is not None:
encoded = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encoded], dim=1)
decoded = self.up(encoded)
if decoded.shape[2:] != x.shape[2:]:
decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
if self.res_skip:
decoded = self.res_skip(decoded)
return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)
elif self.outermost:
if self.submodule:
sub_output, encoded_real_A = self.submodule(encoded, style)
else:
sub_output = encoded
decoded = self.up(sub_output)
if self.res_skip:
decoded = self.res_skip(decoded)
return decoded, encoded_real_A
else:
if self.submodule:
sub_output, encoded_real_A = self.submodule(encoded, style)
else:
sub_output = encoded
decoded = self.up(sub_output)
if decoded.shape[2:] != x.shape[2:]:
decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
if self.res_skip:
decoded = self.res_skip(decoded)
return torch.cat([x, decoded], 1), encoded_real_A
after, chatgpt:
class UNetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, submodule=None, norm_layer=nn.BatchNorm2d):
super(UNetSkipConnectionBlock, self).__init__()
self.submodule = submodule
self.down = nn.Sequential(
nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1),
norm_layer(inner_nc),
nn.LeakyReLU(0.2, inplace=True)
)
self.up = nn.Sequential(
nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1),
norm_layer(outer_nc),
nn.ReLU(inplace=True)
)
def forward(self, x):
if self.submodule:
return torch.cat([self.up(self.submodule(self.down(x))), x], 1)
else:
return torch.cat([self.up(self.down(x)), x], 1)
after, gemini:
好的,我們來優化 UnetSkipConnectionBlock
的 forward
方法,主要目標是提高效率和簡潔性。
問題分析:
- 程式碼中有重複的邏輯,特別是在處理
submodule
的部分。 decoded.shape[2:] != x.shape[2:]
的檢查和F.interpolate
操作在多個地方出現,可以考慮提取出來。encoded_real_A
的回傳邏輯可以簡化。
優化策略:
- 提取共用邏輯: 將
submodule
的處理和F.interpolate
的檢查和執行提取到單獨的函數中。 - 簡化回傳邏輯: 統一回傳
encoded_real_A
的方式。 - 減少重複計算: 減少不必要的形狀檢查和計算。
優化後的程式碼:
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
# ... (ResSkip 和 SelfAttention 類別保持不變)
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
norm_layer=nn.BatchNorm2d, layer=0, embedding_dim=128,
use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
self.innermost = innermost
use_bias = norm_layer != nn.BatchNorm2d
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(inplace=False)
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
self.down = nn.Sequential(downconv)
self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc + embedding_dim, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
self.down = nn.Sequential(downrelu, downconv)
self.up = nn.Sequential(uprelu, upconv, upnorm)
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
self.down = nn.Sequential(downrelu, downconv, downnorm)
self.up = nn.Sequential(uprelu, upconv, upnorm)
if use_dropout:
self.up.add_module("dropout", nn.Dropout(0.5))
self.submodule = submodule
self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None
def _process_submodule(self, encoded, style):
if self.submodule:
sub_output, encoded_real_A = self.submodule(encoded, style)
else:
sub_output = encoded
encoded_real_A = None
return sub_output, encoded_real_A
def _interpolate_if_needed(self, decoded, x):
if decoded.shape[2:] != x.shape[2:]:
decoded = F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
return decoded
def forward(self, x, style=None):
encoded = self.down(x)
if self.self_attn:
encoded = self.self_attn(encoded)
if self.innermost:
if style is not None:
encoded = torch.cat([style.view(style.shape[0], style.shape[1], 1, 1), encoded], dim=1)
decoded = self.up(encoded)
decoded = self._interpolate_if_needed(decoded, x)
if self.res_skip:
decoded = self.res_skip(decoded)
return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)
sub_output, encoded_real_A = self._process_submodule(encoded, style)
decoded = self.up(sub_output)
decoded = self._interpolate_if_needed(decoded, x)
if self.res_skip:
decoded = self.res_skip(decoded)
if self.outermost:
return decoded, encoded_real_A
else:
return torch.cat([x, decoded], 1), encoded_real_A
主要改進:
_process_submodule
函數: 提取了submodule
的處理邏輯,減少了重複程式碼。_interpolate_if_needed
函數: 提取了F.interpolate
的檢查和執行邏輯。- 簡化的回傳邏輯: 統一了
encoded_real_A
的回傳方式。 - 更清晰的程式碼結構: 通過提取函數,程式碼的可讀性和可維護性得到了提高。
這些優化可以提高 UnetSkipConnectionBlock
的效率和程式碼的清晰度。