Add an Edge-Aware Loss: this makes the model pay more attention to edge detail

Your loss module currently includes L1Loss, PerceptualLoss, CategoryLoss, FeatureMatchingLoss, and GradientPenalty. Now let's add an Edge-Aware Loss, which pushes the model to focus on edge detail and is particularly effective for Chinese character style transfer.

✅ Adding the Edge-Aware Loss:

We can use Canny edge detection (or a Sobel filter) to produce edge maps, then compute the L1 difference between the edge maps of the generated image and the real image.


1. Add the EdgeAwareLoss class:

import cv2
import numpy as np
import torch
import torch.nn as nn

class EdgeAwareLoss(nn.Module):
    def __init__(self):
        super(EdgeAwareLoss, self).__init__()
        self.l1 = nn.L1Loss()

    def get_edge(self, img):
        # Convert the tensor batch to numpy grayscale (assumes single-channel values in [0, 1];
        # rescale first if your generator outputs [-1, 1]). Canny is not differentiable,
        # so no gradient flows back through these edge maps; see the Sobel note at the end.
        img_np = img.detach().cpu().numpy()
        edges = []
        for i in range(img_np.shape[0]):
            gray = img_np[i, 0] * 255.0
            gray = np.clip(gray, 0, 255).astype(np.uint8)
            edge = cv2.Canny(gray, 100, 200) / 255.0
            edges.append(edge)
        edge_tensor = torch.tensor(np.stack(edges), dtype=img.dtype, device=img.device).unsqueeze(1)
        return edge_tensor

    def forward(self, pred, target):
        pred_edge = self.get_edge(pred)
        target_edge = self.get_edge(target)
        return self.l1(pred_edge, target_edge)
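
As a quick sanity check, the module can be exercised on dummy batches. The shapes and the [0, 1] value range here are assumptions for illustration, not values taken from the actual training pipeline:

edge_criterion = EdgeAwareLoss()
fake_B = torch.rand(4, 1, 256, 256)  # stand-in for generator output in [0, 1]
real_B = torch.rand(4, 1, 256, 256)  # stand-in for ground-truth glyph images
print(edge_criterion(fake_B, real_B).item())  # scalar L1 distance between Canny edge maps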

2. Add it to the Zi2ZiLoss class:

class Zi2ZiLoss:
    def __init__(self, model, device, lambda_L1=100, lambda_const=10, lambda_cat=1, lambda_fm=10, lambda_perc=10, lambda_gp=10, lambda_edge=5):
        self.model = model
        self.device = device

        self.L1 = nn.L1Loss().to(device)
        self.const = nn.MSELoss().to(device)
        self.category = CategoryLoss(model.embedding_num).to(device)
        self.perceptual = PerceptualLoss().to(device)
        self.feature_match = nn.L1Loss().to(device)
        self.edge_loss = EdgeAwareLoss().to(device)

        self.lambda_L1 = lambda_L1
        self.lambda_const = lambda_const
        self.lambda_cat = lambda_cat
        self.lambda_fm = lambda_fm
        self.lambda_perc = lambda_perc
        self.lambda_gp = lambda_gp
        self.lambda_edge = lambda_edge  # new: weight for the edge loss
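
For example, assuming your trainer already builds the model and device (the variable names here are illustrative), the edge weight is just another constructor argument:

losses = Zi2ZiLoss(model, device, lambda_L1=100, lambda_edge=5)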

3. Add the computation in backward_G:

def backward_G(self, real_A, real_B, fake_B, encoded_real_A, encoded_fake_B, labels):
    real_AB = torch.cat([real_A, real_B], 1)
    fake_AB = torch.cat([real_A, fake_B], 1)

    fake_D, fake_cat = self.model.netD(fake_AB)
    real_D, _ = self.model.netD(real_AB)

    g_loss_adv = -torch.mean(F.logsigmoid(fake_D - real_D))
    const_loss = self.const(encoded_real_A, encoded_fake_B) * self.lambda_const
    l1_loss = self.L1(fake_B, real_B) * self.lambda_L1
    perc_loss = self.perceptual(fake_B, real_B) * self.lambda_perc
    cat_loss = self.category(fake_cat, labels) * self.lambda_cat
    fm_loss = self.feature_matching_loss(real_AB, fake_AB) * self.lambda_fm
    edge_loss = self.edge_loss(fake_B, real_B) * self.lambda_edge  # Edge-aware loss ✅

    total_G_loss = g_loss_adv + const_loss + l1_loss + perc_loss + cat_loss + fm_loss + edge_loss
    return total_G_loss, {
        "adv": g_loss_adv.item(),
        "const": const_loss.item(),
        "l1": l1_loss.item(),
        "perc": perc_loss.item(),
        "cat": cat_loss.item(),
        "fm": fm_loss.item(),
        "edge": edge_loss.item()  # Log edge-aware loss
    }
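
In the generator update, the returned total is backpropagated as usual. This is a minimal sketch; the names loss_module, optimizer_G, and the encoded_* tensors are assumptions about your surrounding training loop, not part of the code above:

# Hypothetical generator step inside the training loop
optimizer_G.zero_grad()
total_G_loss, loss_dict = loss_module.backward_G(real_A, real_B, fake_B,
                                                 encoded_real_A, encoded_fake_B, labels)
total_G_loss.backward()
optimizer_G.step()
print({k: round(v, 4) for k, v in loss_dict.items()})  # per-term logging, including "edge"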

This encourages the model, during training, to generate sharper images whose edge structure matches the real glyphs. Note that because the Canny edges are computed on detached tensors, this term does not by itself backpropagate through the generator; you can replace Canny with a Sobel edge (which is differentiable), and if you prefer an end-to-end differentiable version I can rewrite it for you.

Do you want me to also rewrite it as a Sobel-based differentiable edge loss?
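
For reference, a minimal differentiable Sobel-based sketch might look like the following. It assumes single-channel inputs and keeps the same L1-on-edge-maps idea, so it could be dropped in wherever EdgeAwareLoss is used; treat it as an illustration rather than a finished implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SobelEdgeLoss(nn.Module):
    def __init__(self):
        super(SobelEdgeLoss, self).__init__()
        # Fixed 3x3 Sobel kernels; buffers move with .to(device) but are never trained
        kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]).view(1, 1, 3, 3)
        ky = torch.tensor([[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]]).view(1, 1, 3, 3)
        self.register_buffer("kx", kx)
        self.register_buffer("ky", ky)
        self.l1 = nn.L1Loss()

    def get_edge(self, img):
        # Gradient magnitude from horizontal/vertical Sobel responses (fully differentiable)
        gx = F.conv2d(img, self.kx, padding=1)
        gy = F.conv2d(img, self.ky, padding=1)
        return torch.sqrt(gx ** 2 + gy ** 2 + 1e-6)

    def forward(self, pred, target):
        return self.l1(self.get_edge(pred), self.get_edge(target))

To try it, replace self.edge_loss = EdgeAwareLoss() with SobelEdgeLoss() in the Zi2ZiLoss constructor; backward_G and the lambda_edge weight stay unchanged.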
