使用 autocast 進行混合精度訓練(減少顯存使用 & 加速)

修改前:

for epoch in range(args.epoch):
    for bid, (_, zenmaru, noto) in enumerate(train_dataloader):
        noto, zenmaru = noto.to(device), zenmaru.to(device)

        # 訓練 Generator
        optimizer_G.zero_grad()
        output = generator(noto)
        fake_preds, fake_features = discriminator(output)  # 取得 Discriminator 預測 & 特徵
        real_preds, real_features = discriminator(zenmaru)  # 取得真實圖片的 Discriminator 特徵

        # 1. 主要的 Generator Loss (L1 損失)
        g_loss = criterion(output, zenmaru)

        # 2. Feature Matching Loss
        fm_loss = sum(F.l1_loss(fake_f, real_f) for fake_f, real_f in zip(fake_features, real_features)) / len(real_features)

        # 總 Generator Loss
        total_g_loss = g_loss + 0.1 * fm_loss  # 可調整權重

        total_g_loss.backward(retain_graph=True)  # 保留計算圖
        optimizer_G.step()

        # 訓練 Discriminator
        optimizer_D.zero_grad()
        real_labels = torch.full_like(real_preds, 0.9, device=device)  # One-Sided Label Smoothing
        fake_labels = torch.full_like(fake_preds, 0.1, device=device)

        real_loss = criterion_GAN(real_preds, real_labels)
        fake_loss = criterion_GAN(fake_preds.detach(), fake_labels)  # 確保 fake_preds 不回傳到 G

        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()  # 這裡不需要 retain_graph
        optimizer_D.step()

這裡是你修改後的 trainer,已經加入 混合精度訓練 (mixed precision training) with torch.cuda.amp.autocast(),以減少顯存使用並加速訓練。


修改點

  1. 使用 torch.cuda.amp.autocast():在 前向傳播 (forward pass) 時,自動使用混合精度來加速計算。
  2. 使用 torch.cuda.amp.GradScaler():確保 backward 時不會發生數值不穩定 (underflow),同時管理 loss.backward()
  3. 梯度縮放 (scaler.scale(loss).backward()):防止小數值梯度過小導致無法更新。
  4. 使用 scaler.step(optimizer):自動執行 optimizer.step(),確保縮放後的梯度能夠正確應用。
  5. 使用 scaler.update():調整縮放比例,避免 underflow。

修改後的 trainer

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

# 初始化 AMP 梯度縮放
scaler = GradScaler()

for epoch in range(args.epoch):
    for bid, (_, zenmaru, noto) in enumerate(train_dataloader):
        noto, zenmaru = noto.to(device), zenmaru.to(device)

        # 訓練 Generator
        optimizer_G.zero_grad()
        with autocast():  # 啟用混合精度
            output = generator(noto)
            fake_preds, fake_features = discriminator(output)  # 取得 Discriminator 預測 & 特徵
            real_preds, real_features = discriminator(zenmaru)  # 取得真實圖片的 Discriminator 特徵

            # 1. 主要的 Generator Loss (L1 損失)
            g_loss = criterion(output, zenmaru)

            # 2. Feature Matching Loss
            fm_loss = sum(F.l1_loss(fake_f, real_f) for fake_f, real_f in zip(fake_features, real_features)) / len(real_features)

            # 總 Generator Loss
            total_g_loss = g_loss + 0.1 * fm_loss  # 可調整權重
        
        # 反向傳播 & 更新 G (使用 AMP)
        scaler.scale(total_g_loss).backward(retain_graph=True)  # 保留計算圖
        scaler.step(optimizer_G)
        scaler.update()

        # 訓練 Discriminator
        optimizer_D.zero_grad()
        with autocast():  # 啟用混合精度
            real_labels = torch.full_like(real_preds, 0.9, device=device)  # One-Sided Label Smoothing
            fake_labels = torch.full_like(fake_preds, 0.1, device=device)

            real_loss = criterion_GAN(real_preds, real_labels)
            fake_loss = criterion_GAN(fake_preds.detach(), fake_labels)  # 確保 fake_preds 不回傳到 G

            d_loss = (real_loss + fake_loss) / 2

        # 反向傳播 & 更新 D (使用 AMP)
        scaler.scale(d_loss).backward()  # 這裡不需要 retain_graph
        scaler.step(optimizer_D)
        scaler.update()

優化效果

減少顯存使用:混合精度訓練可將 float32 計算轉換為 float16,降低 GPU 記憶體需求。
加速訓練float16 運算速度更快,特別適用於 NVIDIA RTX/A100 這類 GPU (支援 Tensor Core 計算)。
穩定梯度GradScaler 會動態調整梯度縮放,避免 underflow 問題。

這樣你的訓練應該會更快更穩定!🚀


不使用 autocast 的程式碼:

import torch
import torch.nn as nn
import torch.optim as optim
import os
from torch.utils.data import DataLoader

# 初始化資料集
train_dataset = DatasetFromObj(os.path.join(data_dir, 'train.obj'), transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)

# 設定 device
device = torch.device("cuda" if (torch.cuda.is_available() and args.gpu_ids) else "cpu")

# 初始化模型
generator = UNetGenerator().to(device)
discriminator = Discriminator().to(device)

# 使用 BCEWithLogitsLoss 來避免 sigmoid 重複計算
criterion_GAN = nn.BCEWithLogitsLoss()  # 取代 BCELoss
criterion_L1 = nn.L1Loss()  # 用於 Feature Matching Loss

# 優化器
optimizer_G = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

# 是否載入 checkpoint
if args.resume:
    target_filepath = os.path.join(checkpoint_dir, f"{args.resume}_net_G.pth")
    if os.path.exists(target_filepath):
        generator.load_state_dict(torch.load(target_filepath))
        print(f"✅ Generator({args.resume}) 模型已載入")
    
    target_filepath = os.path.join(checkpoint_dir, f"{args.resume}_net_D.pth")
    if os.path.exists(target_filepath):
        discriminator.load_state_dict(torch.load(target_filepath))
        print(f"✅ Discriminator({args.resume}) 模型已載入")

# 訓練迴圈
total_batches = len(train_dataloader)
global_steps = 0

for epoch in range(args.epoch):
    for bid, (_, zenmaru, noto) in enumerate(train_dataloader):
        noto, zenmaru = noto.to(device), zenmaru.to(device)

        # 訓練 Generator
        optimizer_G.zero_grad()
        output = generator(noto)
        fake_preds, fake_features = discriminator(output)  # 取得 Discriminator 預測 & 特徵
        real_preds, real_features = discriminator(zenmaru)  # 取得真實圖片的 Discriminator 特徵

        # 1. 主要的 Generator Loss (L1 損失)
        g_loss = criterion_L1(output, zenmaru)

        # 2. Feature Matching Loss
        fm_loss = sum(nn.functional.l1_loss(fake_f, real_f) for fake_f, real_f in zip(fake_features, real_features)) / len(real_features)

        # 總 Generator Loss
        total_g_loss = g_loss + 0.1 * fm_loss  # 可調整權重

        total_g_loss.backward(retain_graph=True)  # 保持計算圖,避免 D backward 時圖被釋放
        optimizer_G.step()

        # 訓練 Discriminator
        optimizer_D.zero_grad()

        # 修正 Label Smoothing:使用 `torch.ones_like()` 和 `torch.zeros_like()` 來避免 shape 問題
        real_labels = torch.ones_like(real_preds, device=device) * 0.9  # One-Sided Label Smoothing
        fake_labels = torch.zeros_like(fake_preds, device=device) + 0.1

        real_loss = criterion_GAN(real_preds, real_labels)
        fake_loss = criterion_GAN(fake_preds.detach(), fake_labels)  # `detach()` 確保 fake_preds 不回傳梯度到 G

        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()  # 這裡不需要 retain_graph
        optimizer_D.step()

使用 autocast 花費時間:

Epoch [1/9], [0/5] - G Loss: 1.1268 - D Loss: 0.7095 - time:  5.7
Epoch [2/9], [0/5] - G Loss: 0.5990 - D Loss: 0.5335 - time: 27.9
Epoch [3/9], [0/5] - G Loss: 0.4739 - D Loss: 0.4062 - time: 50.2
Epoch [4/9], [0/5] - G Loss: 0.4014 - D Loss: 0.4230 - time: 72.5
Epoch [5/9], [0/5] - G Loss: 0.3997 - D Loss: 0.3341 - time: 94.6
Epoch [6/9], [0/5] - G Loss: 0.3471 - D Loss: 0.3300 - time: 116.7
Epoch [7/9], [0/5] - G Loss: 0.2815 - D Loss: 1.3643 - time: 138.8
Epoch [8/9], [0/5] - G Loss: 0.2320 - D Loss: 0.6053 - time: 160.9
Epoch [9/9], [0/5] - G Loss: 0.1902 - D Loss: 0.8544 - time: 183.1

沒有使用 autocast 花費時間:

Checkpoint: checkpoint step 0, will save after 1
Epoch [1/8], [0/6] - G Loss: 0.8720 - D Loss: 0.7085 - time: 1.9
Epoch [2/8], [0/6] - G Loss: 0.3940 - D Loss: 1.3142 - time: 19.2
Epoch [3/8], [0/6] - G Loss: 0.3493 - D Loss: 1.2127 - time: 49.6
Epoch [4/8], [0/6] - G Loss: 0.3326 - D Loss: 1.1968 - time: 72.2
Epoch [5/8], [0/6] - G Loss: 0.2578 - D Loss: 1.1772 - time: 93.0
Epoch [6/8], [0/6] - G Loss: 0.2915 - D Loss: 1.2308 - time: 114.8
Epoch [7/8], [0/6] - G Loss: 0.2624 - D Loss: 1.2039 - time: 140.5
Epoch [8/8], [0/6] - G Loss: 0.1833 - D Loss: 1.2141 - time: 159.3

發現, 在4GB 的 RAM 的 GPU 之外, 花費的時間是一樣的, 誤差在1~2秒之內, 但記憶體是有使用比較少.

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *