Your current loss module already includes L1Loss, PerceptualLoss, CategoryLoss, FeatureMatchingLoss, and GradientPenalty. Now let's add an Edge-Aware Loss, which pushes the model to pay more attention to edge detail and is particularly effective for Chinese character style transfer.

✅ Adding the Edge-Aware Loss:

We can use Canny edge detection (or a Sobel filter) to produce edge maps, then compute the L1 difference between the edge maps of the generated and real images.
1. Add the EdgeAwareLoss class:
```python
import cv2
import numpy as np
import torch
import torch.nn as nn

class EdgeAwareLoss(nn.Module):
    def __init__(self):
        super(EdgeAwareLoss, self).__init__()
        self.l1 = nn.L1Loss()

    def get_edge(self, img):
        # Move the tensor to numpy, then run Canny per sample.
        # Note: cv2.Canny is not differentiable (and the tensor is detached),
        # so no gradient flows through the edge maps; a Sobel filter can be
        # swapped in for a fully differentiable version.
        img_np = img.detach().cpu().numpy()
        edges = []
        for i in range(img_np.shape[0]):
            gray = img_np[i, 0] * 255.0  # assumes single-channel input in [0, 1]
            gray = np.clip(gray, 0, 255).astype(np.uint8)
            edge = cv2.Canny(gray, 100, 200) / 255.0
            edges.append(edge)
        edge_tensor = torch.tensor(np.stack(edges), dtype=img.dtype, device=img.device).unsqueeze(1)
        return edge_tensor

    def forward(self, pred, target):
        pred_edge = self.get_edge(pred)
        target_edge = self.get_edge(target)
        return self.l1(pred_edge, target_edge)
```
2. Add it to the Zi2ZiLoss class:
```python
class Zi2ZiLoss:
    def __init__(self, model, device, lambda_L1=100, lambda_const=10, lambda_cat=1,
                 lambda_fm=10, lambda_perc=10, lambda_gp=10, lambda_edge=5):
        self.model = model
        self.device = device
        self.L1 = nn.L1Loss().to(device)
        self.const = nn.MSELoss().to(device)
        self.category = CategoryLoss(model.embedding_num).to(device)
        self.perceptual = PerceptualLoss().to(device)
        self.feature_match = nn.L1Loss().to(device)
        self.edge_loss = EdgeAwareLoss().to(device)

        self.lambda_L1 = lambda_L1
        self.lambda_const = lambda_const
        self.lambda_cat = lambda_cat
        self.lambda_fm = lambda_fm
        self.lambda_perc = lambda_perc
        self.lambda_gp = lambda_gp
        self.lambda_edge = lambda_edge  # weight for the new edge-aware loss
```
3. Compute it in backward_G:
```python
def backward_G(self, real_A, real_B, fake_B, encoded_real_A, encoded_fake_B, labels):
    real_AB = torch.cat([real_A, real_B], 1)
    fake_AB = torch.cat([real_A, fake_B], 1)

    fake_D, fake_cat = self.model.netD(fake_AB)
    real_D, _ = self.model.netD(real_AB)

    # Relativistic adversarial term for the generator.
    g_loss_adv = -torch.mean(F.logsigmoid(fake_D - real_D))

    const_loss = self.const(encoded_real_A, encoded_fake_B) * self.lambda_const
    l1_loss = self.L1(fake_B, real_B) * self.lambda_L1
    perc_loss = self.perceptual(fake_B, real_B) * self.lambda_perc
    cat_loss = self.category(fake_cat, labels) * self.lambda_cat
    # Assumes a feature_matching_loss helper defined elsewhere on this class.
    fm_loss = self.feature_matching_loss(real_AB, fake_AB) * self.lambda_fm
    edge_loss = self.edge_loss(fake_B, real_B) * self.lambda_edge  # edge-aware loss ✅

    total_G_loss = g_loss_adv + const_loss + l1_loss + perc_loss + cat_loss + fm_loss + edge_loss
    return total_G_loss, {
        "adv": g_loss_adv.item(),
        "const": const_loss.item(),
        "l1": l1_loss.item(),
        "perc": perc_loss.item(),
        "cat": cat_loss.item(),
        "fm": fm_loss.item(),
        "edge": edge_loss.item(),  # log the edge-aware loss
    }
```
With this in place, training encourages the model to generate sharper images whose edge structure matches the real samples. One caveat: because Canny runs on detached numpy arrays, this term is logged but carries no gradient back to the generator. You can instead use a Sobel edge (which is differentiable) in place of Canny if you prefer a fully end-to-end differentiable version; I can rewrite it for you.

Would you like me to also convert this into a Sobel-based differentiable edge loss?
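For reference, here is a minimal sketch of what that Sobel-based variant could look like (the class name SobelEdgeLoss and the fixed 3x3 kernels are my own choices, not from your existing module). It computes gradient magnitudes with fixed Sobel convolutions, so gradients flow back into the generator:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SobelEdgeLoss(nn.Module):
    """Differentiable edge loss: L1 between Sobel gradient magnitudes
    of prediction and target. Assumes single-channel images."""
    def __init__(self):
        super().__init__()
        kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]])
        ky = kx.t()
        # Buffers follow .to(device) automatically and are not trained.
        self.register_buffer("kx", kx.view(1, 1, 3, 3))
        self.register_buffer("ky", ky.view(1, 1, 3, 3))
        self.l1 = nn.L1Loss()

    def sobel(self, img):
        gx = F.conv2d(img, self.kx, padding=1)
        gy = F.conv2d(img, self.ky, padding=1)
        # Small epsilon keeps sqrt differentiable at zero gradient magnitude.
        return torch.sqrt(gx ** 2 + gy ** 2 + 1e-6)

    def forward(self, pred, target):
        return self.l1(self.sobel(pred), self.sobel(target))
```

Since this is built entirely from conv2d, it can replace EdgeAwareLoss in Zi2ZiLoss as a drop-in, keeping the whole objective differentiable end to end.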