set_requires_grad 避免不必要的梯度計算,提高訓練效率並減少記憶體使用量

是的,這是一個常見的技巧,用來避免不必要的梯度計算,提高訓練效率並減少記憶體使用量。你可以在 trainer.py 中加入這個函數,然後在 GeneratorDiscriminator 訓練時適當地切換 requires_grad 屬性。


如何應用 set_requires_grad

  1. 當訓練 G
    • D 的梯度應該關閉 (requires_grad=False),這樣 D 的權重不會更新,避免不必要的計算。
  2. 當訓練 D
    • D 的梯度應該開啟 (requires_grad=True),確保 D 能夠學習區分真實與生成的影像。

修改 trainer.py

我會幫你更新 trainer.py,加入 set_requires_grad,並確保 D 在訓練 G 時不會更新梯度。

import torch
import torch.nn as nn
import torch.optim as optim
import functools
from torch.utils.data import DataLoader
from torchvision import transforms
import os

# BinaryLoss 類別
class BinaryLoss(nn.Module):
    def __init__(self, real):
        super(BinaryLoss, self).__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.real = real

    def forward(self, logits):
        labels = torch.ones(logits.shape[0], 1) if self.real else torch.zeros(logits.shape[0], 1)
        labels = labels.to(logits.device)  # 確保 labels 在正確的 device 上
        return self.bce(logits, labels)

# 設定 requires_grad
def set_requires_grad(nets, requires_grad=False):
    """設定 requires_grad,避免不必要的計算"""
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

# 設定 device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 影像轉換
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# 載入資料
dataset = DatasetFromObj(os.path.join(args.data_dir, 'train.obj'), transform=transform)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

# 初始化模型
generator = UNetGenerator().to(device)
discriminator = Discriminator().to(device)

# 續訓
if args.resume:
    load_checkpoint(device, generator, discriminator, args.checkpoint_dir, args.resume)

# 損失函數
criterion_l1 = nn.L1Loss()
criterion_mse = nn.MSELoss()
real_binary_loss = BinaryLoss(True).to(device)
fake_binary_loss = BinaryLoss(False).to(device)

# 優化器
optimizer_G = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

# 訓練迴圈
total_batches = len(dataloader)
for epoch in range(args.epoch):
    for bid, (_, zenmaru, noto) in enumerate(dataloader):
        noto, zenmaru = noto.to(device), zenmaru.to(device)

        # 訓練 Generator
        set_requires_grad(discriminator, False)  # 停用 D 的梯度
        optimizer_G.zero_grad()
        output = generator(noto)

        # Discriminator 判斷真實與生成影像
        binary_real, category_real = discriminator(zenmaru)
        binary_fake, category_fake = discriminator(output)

        # Loss 計算
        l1_loss = criterion_l1(output, zenmaru)
        category_loss = criterion_mse(category_fake, category_real)
        cheat_loss = real_binary_loss(binary_fake)
        const_loss = criterion_mse(output, output.detach())

        # G 的損失
        g_loss = l1_loss + 0.3 * cheat_loss + 0.1 * category_loss + 0.05 * const_loss
        g_loss.backward()
        optimizer_G.step()

        # 訓練 Discriminator
        if bid % 2 == 0:  # 每 2 個 batch 訓練一次
            set_requires_grad(discriminator, True)  # 啟用 D 的梯度
            optimizer_D.zero_grad()
            binary_fake = discriminator(output.detach())[0]
            real_loss = real_binary_loss(binary_real)
            fake_loss = fake_binary_loss(binary_fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

    print(f"Epoch [{epoch+1}/{args.epoch}], G Loss: {g_loss.item():.4f}, D Loss: {d_loss.item():.4f}")
  • 加入 set_requires_grad 函數,確保 D 在訓練 G 時不會更新梯度。
  • 調整 GD 訓練步驟
    • set_requires_grad(discriminator, False)G 訓練時關閉 D 的梯度。
    • set_requires_grad(discriminator, True) 只在 D 訓練時啟用梯度。
    • D2 個 batch 訓練一次,減少梯度不穩定。

這樣可以提高 resume 訓練後的穩定性!🚀 你可以試試看~

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *