Adding Feature Matching Loss to the Zi2ZiModel Class

In this post we add Feature Matching Loss to the Zi2ZiModel class and wire the corresponding computation and optimization steps into the backward_G and optimize_parameters methods.

First, we initialize feature_matching_loss in Zi2ZiModel's __init__ method:

Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast  # used later in optimize_parameters

class Zi2ZiModel:
    def __init__(self, input_nc=1, embedding_num=40, embedding_dim=128, ngf=64, ndf=64,
                 Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100,
                 schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
                 image_size=256, self_attention=False, residual_block=False,
                 weight_decay=1e-5, final_channels=1, beta1=0.5, g_blur=False, d_blur=False, epoch=40):

        # ... (other initialization code)

        self.feature_matching_loss = nn.L1Loss()  # criterion for the Feature Matching Loss
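
As a quick sanity check of what this criterion will do, here is a small standalone sketch (the tensors are made up for illustration): nn.L1Loss on two feature tensors of the same shape computes the mean absolute difference over all elements, which is exactly how we will compare Discriminator features.

Python

import torch
import torch.nn as nn

criterion = nn.L1Loss()

# two batches of intermediate feature maps with shape (N, C, H, W)
fake_features = torch.randn(4, 512, 8, 8)
real_features = torch.randn(4, 512, 8, 8)

loss = criterion(fake_features, real_features)
print(loss.item())  # a scalar: the mean of |fake - real| over all elements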

Next, we modify compute_feature_matching_loss so that it compares intermediate-layer features of the Discriminator rather than the final logits:

Python

class Zi2ZiModel:
    # ... (other code)

    def compute_feature_matching_loss(self, real_AB, fake_AB):
        # slice off the last module to get intermediate Discriminator features
        real_features = self.netD.model[:-1](real_AB)
        fake_features = self.netD.model[:-1](fake_AB)
        return self.feature_matching_loss(real_features, fake_features)

Note: self.netD.model[:-1] returns the Discriminator's intermediate-layer features rather than the final logits. This relies on netD.model being an nn.Sequential whose last module maps those features to logits.
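
If you want to match features at several depths rather than only the penultimate output, in the spirit of pix2pixHD-style feature matching, a variant could look like the sketch below. This is optional and not part of the original model: it assumes self.netD.model is an nn.Sequential, layer_ids is a hypothetical choice of layers, and the real pair's features are detached so the loss only pushes the generator.

Python

class Zi2ZiModel:
    # ... (other code)

    def compute_feature_matching_loss_multi(self, real_AB, fake_AB, layer_ids=(2, 4, 6)):
        # hypothetical multi-layer variant; assumes self.netD.model is an nn.Sequential
        loss = 0.0
        real_feat, fake_feat = real_AB, fake_AB
        for idx, layer in enumerate(self.netD.model):
            real_feat = layer(real_feat)
            fake_feat = layer(fake_feat)
            if idx in layer_ids:
                loss = loss + self.feature_matching_loss(fake_feat, real_feat.detach())
        return loss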

Next, we add the Feature Matching Loss computation to the backward_G method:

Python

class Zi2ZiModel:
    # ... (other code)

    def backward_G(self, no_target_source=False):
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        real_AB = torch.cat([self.real_A, self.real_B], 1)

        fake_D_logits, fake_category_logits = self.netD(fake_AB)
        real_D_logits, _ = self.netD(real_AB)

        const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
        l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
        fake_category_loss = self.Lcategory_penalty * self.category_loss(fake_category_logits, self.labels)
        g_loss_adv = -torch.mean(F.logsigmoid(fake_D_logits - real_D_logits))

        # compute the Feature Matching Loss
        fm_loss = self.compute_feature_matching_loss(real_AB, fake_AB)

        self.g_loss = g_loss_adv + l1_loss + fake_category_loss + const_loss + fm_loss  # add fm_loss to the total
        return const_loss, l1_loss, g_loss_adv, fm_loss  # also return fm_loss
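
Note that fm_loss enters the total with an implicit weight of 1 here. If it dominates (or vanishes next to) the other terms, weighting it is a one-line change; Lfm_penalty below is a hypothetical attribute you would set yourself in __init__, not part of the original constructor:

Python

# hypothetical weighted variant of the self.g_loss line above, assuming
# something like self.Lfm_penalty = 10 was added to __init__
self.g_loss = g_loss_adv + l1_loss + fake_category_loss + const_loss + self.Lfm_penalty * fm_loss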

Finally, we handle the Feature Matching Loss in the optimize_parameters method. (Note that backward_G's third return value, g_loss_adv, is unpacked here under the name cheat_loss; the names differ, but both refer to the same adversarial term.)

Python

class Zi2ZiModel:
    # ... (other code)

    def optimize_parameters(self, use_autocast=False):
        # ... (other code; typically the Discriminator update comes first)

        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        const_loss, l1_loss, cheat_loss, fm_loss = 0, 0, 0, 0  # initialize fm_loss along with the other losses

        if use_autocast:
            with autocast():
                const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
                scaled_g_loss = self.scaler_G.scale(self.g_loss)
            scaled_g_loss.backward()
            self.scaler_G.step(self.optimizer_G)
            self.scaler_G.update()
        else:
            const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
            self.g_loss.backward()
            self.optimizer_G.step()

        # run a fresh forward pass and update the generator a second time
        # (zi2zi performs two G updates per iteration)
        self.forward()
        self.optimizer_G.zero_grad()

        if use_autocast:
            with autocast():
                const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
                scaled_g_loss = self.scaler_G.scale(self.g_loss)
            scaled_g_loss.backward()
            self.scaler_G.step(self.optimizer_G)
            self.scaler_G.update()
        else:
            const_loss, l1_loss, cheat_loss, fm_loss = self.backward_G()
            self.g_loss.backward()
            self.optimizer_G.step()
        return const_loss, l1_loss, cheat_loss, fm_loss  # also return fm_loss
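
To make the monitoring suggested below concrete, here is a minimal sketch of a training loop that logs fm_loss. The names model, dataloader, num_epochs and set_input are placeholders for your own setup, not part of the class shown above:

Python

# hypothetical training loop; `model` is a constructed Zi2ZiModel and
# `dataloader` yields batches in whatever format your input pipeline uses
for epoch in range(num_epochs):
    for batch in dataloader:
        model.set_input(batch)  # assumed input hook; adapt to your pipeline
        const_loss, l1_loss, cheat_loss, fm_loss = model.optimize_parameters()
        print(f"epoch {epoch}: fm_loss={fm_loss.item():.4f}, l1_loss={l1_loss.item():.4f}")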

With that, Feature Matching Loss is part of Zi2ZiModel and is computed and optimized during training. Remember to monitor fm_loss over the course of training to make sure it converges as expected.
