是的,這是一個常見的技巧,用來避免不必要的梯度計算,提高訓練效率並減少記憶體使用量。你可以在 trainer.py
中加入這個函數,然後在 Generator
和 Discriminator
訓練時適當地切換 requires_grad
屬性。
如何應用 set_requires_grad
- 當訓練
G
時:D
的梯度應該關閉 (requires_grad=False
),這樣D
的權重不會更新,避免不必要的計算。
- 當訓練
D
時:D
的梯度應該開啟 (requires_grad=True
),確保D
能夠學習區分真實與生成的影像。
修改 trainer.py
我會幫你更新 trainer.py
,加入 set_requires_grad
,並確保 D
在訓練 G
時不會更新梯度。
import torch
import torch.nn as nn
import torch.optim as optim
import functools
from torch.utils.data import DataLoader
from torchvision import transforms
import os
# BinaryLoss 類別
class BinaryLoss(nn.Module):
def __init__(self, real):
super(BinaryLoss, self).__init__()
self.bce = nn.BCEWithLogitsLoss()
self.real = real
def forward(self, logits):
labels = torch.ones(logits.shape[0], 1) if self.real else torch.zeros(logits.shape[0], 1)
labels = labels.to(logits.device) # 確保 labels 在正確的 device 上
return self.bce(logits, labels)
# 設定 requires_grad
def set_requires_grad(nets, requires_grad=False):
"""設定 requires_grad,避免不必要的計算"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
# 設定 device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 影像轉換
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
# 載入資料
dataset = DatasetFromObj(os.path.join(args.data_dir, 'train.obj'), transform=transform)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
# 初始化模型
generator = UNetGenerator().to(device)
discriminator = Discriminator().to(device)
# 續訓
if args.resume:
load_checkpoint(device, generator, discriminator, args.checkpoint_dir, args.resume)
# 損失函數
criterion_l1 = nn.L1Loss()
criterion_mse = nn.MSELoss()
real_binary_loss = BinaryLoss(True).to(device)
fake_binary_loss = BinaryLoss(False).to(device)
# 優化器
optimizer_G = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
# 訓練迴圈
total_batches = len(dataloader)
for epoch in range(args.epoch):
for bid, (_, zenmaru, noto) in enumerate(dataloader):
noto, zenmaru = noto.to(device), zenmaru.to(device)
# 訓練 Generator
set_requires_grad(discriminator, False) # 停用 D 的梯度
optimizer_G.zero_grad()
output = generator(noto)
# Discriminator 判斷真實與生成影像
binary_real, category_real = discriminator(zenmaru)
binary_fake, category_fake = discriminator(output)
# Loss 計算
l1_loss = criterion_l1(output, zenmaru)
category_loss = criterion_mse(category_fake, category_real)
cheat_loss = real_binary_loss(binary_fake)
const_loss = criterion_mse(output, output.detach())
# G 的損失
g_loss = l1_loss + 0.3 * cheat_loss + 0.1 * category_loss + 0.05 * const_loss
g_loss.backward()
optimizer_G.step()
# 訓練 Discriminator
if bid % 2 == 0: # 每 2 個 batch 訓練一次
set_requires_grad(discriminator, True) # 啟用 D 的梯度
optimizer_D.zero_grad()
binary_fake = discriminator(output.detach())[0]
real_loss = real_binary_loss(binary_real)
fake_loss = fake_binary_loss(binary_fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
print(f"Epoch [{epoch+1}/{args.epoch}], G Loss: {g_loss.item():.4f}, D Loss: {d_loss.item():.4f}")
- 加入
set_requires_grad
函數,確保D
在訓練G
時不會更新梯度。 - 調整
G
和D
訓練步驟:set_requires_grad(discriminator, False)
在G
訓練時關閉D
的梯度。set_requires_grad(discriminator, True)
只在D
訓練時啟用梯度。D
每 2 個 batch 訓練一次,減少梯度不穩定。
這樣可以提高 resume
訓練後的穩定性!🚀 你可以試試看~