不知道為什麼 Gemini 2.5 Pro 都用英文回答，明明是用中文問它的。
Okay, let’s replace the standard `SelfAttention` with a more efficient `LinearAttention` implementation. We’ll keep the `StyleModulation` (FiLM-like) block introduced previously.
1. Define the `LinearAttention` Module
This version avoids the O(N^2) cost (N = H·W spatial positions) of standard self-attention. It uses the ELU feature map φ(x) = elu(x) + 1 to linearize the attention kernel.
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
# Assume other necessary imports like models, VGG16_Weights, T are present
from torchvision import models # Example import
from torchvision.models import VGG16_Weights # Example import
import torchvision.transforms as T # Example import
# --- Keep StyleModulation and ResSkip as defined before ---
class StyleModulation(nn.Module):
    """FiLM-style conditioning block.

    Instance-normalises the feature map ``x`` (no learned affine) and then
    applies a per-channel scale and shift predicted from a style vector:
    ``gamma(style) * norm(x) + beta(style)``.

    Args:
        channels: number of channels C of the modulated feature map.
        style_dim: length of the style vector passed to ``forward``.
    """

    def __init__(self, channels, style_dim):
        super(StyleModulation, self).__init__()
        self.norm = nn.InstanceNorm2d(channels, affine=False)
        # A single linear layer produces [gamma | beta] (2C values) from the style.
        self.style_fc = nn.Linear(style_dim, channels * 2)
        with torch.no_grad():
            nn.init.normal_(self.style_fc.weight, 0.0, 0.02)
            self.style_fc.bias.zero_()
            # Start gamma at ~1 rather than ~0. With an all-zero bias the block
            # would multiply every feature by ~0 at initialisation and erase
            # its input.
            self.style_fc.bias[:channels].fill_(1.0)

    def forward(self, x, style):
        """Modulate ``x`` (B, C, H, W) with ``style`` (B, style_dim).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # InstanceNorm2d raises for a single spatial element (a 1x1 map, as at
        # the innermost U-Net level). Such maps are modulated without normalising.
        if x.shape[-2] * x.shape[-1] > 1:
            x = self.norm(x)
        gamma, beta = self.style_fc(style).chunk(2, dim=1)  # each (B, C)
        # Broadcast the per-channel parameters over the spatial dimensions.
        return gamma[:, :, None, None] * x + beta[:, :, None, None]
class ResSkip(nn.Module):
    """Residual refinement block: ``x + SiLU(GroupNorm(Conv3x3(x)))``.

    Args:
        channels: input/output channel count.
        num_groups: preferred GroupNorm group count (default 8). It is lowered
            to the largest divisor of ``channels`` not exceeding it, so any
            channel width is accepted. Widths divisible by 8 behave as before.
    """

    def __init__(self, channels, num_groups=8):
        super(ResSkip, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        # GroupNorm requires num_groups to divide the channel count.
        groups = max(1, min(num_groups, channels))
        while channels % groups:
            groups -= 1
        self.norm = nn.GroupNorm(groups, channels)  # less batch-dependent than BatchNorm
        # The attribute name "relu" is kept for compatibility; it is a SiLU.
        self.relu = nn.SiLU(inplace=True)

    def forward(self, x):
        """Return ``x`` plus the conv -> norm -> SiLU branch (same shape as ``x``)."""
        branch = self.relu(self.norm(self.conv(x)))
        return x + branch
# --- Efficient Linear Attention ---
class LinearAttention(nn.Module):
    """Linear-complexity self-attention over the spatial positions of a feature map.

    Replaces ``softmax(Q K^T) V`` with
    ``phi(Q) (phi(K)^T V) / (phi(Q) phi(K)^T 1)``, using the positive feature
    map ``phi(x) = elu(x) + 1``. The N x N attention matrix (N = H * W) is
    never materialised. The result is blended in residually as
    ``gamma * attn + x``. ``gamma`` starts at 0, so the block begins as the
    identity.

    Args:
        channels: input/output channel count C.
        key_channels: width of the query/key projections
            (default ``max(1, C // 8)``).
    """

    def __init__(self, channels, key_channels=None):
        super(LinearAttention, self).__init__()
        # Floor at 1. For channels < 8, C // 8 == 0 would create zero-width
        # Q/K projections and silently turn the attention into a no-op.
        self.key_channels = key_channels if key_channels is not None else max(1, channels // 8)
        self.value_channels = channels  # values keep the input width
        self.query = nn.Conv2d(channels, self.key_channels, kernel_size=1)
        self.key = nn.Conv2d(channels, self.key_channels, kernel_size=1)
        self.value = nn.Conv2d(channels, self.value_channels, kernel_size=1)
        # value_channels == channels, so no output projection is needed.
        self.out_proj = nn.Identity()
        # Residual weight, zero-initialised so the block starts as the identity.
        self.gamma = nn.Parameter(torch.zeros(1))

    @staticmethod
    def feature_map(x):
        """Non-negative kernel feature map ``elu(x) + 1``.

        Defined as a static method rather than a lambda attribute so the
        module stays picklable (``torch.save(model)``, multiprocessing).
        """
        return F.elu(x) + 1.0

    def forward(self, x):
        """Apply linear attention to ``x`` of shape (B, C, H, W); returns the same shape."""
        B, C, H, W = x.shape
        N = H * W
        q = self.feature_map(self.query(x)).reshape(B, self.key_channels, N)  # (B, Dk, N)
        k = self.feature_map(self.key(x)).reshape(B, self.key_channels, N)    # (B, Dk, N)
        v = self.value(x).reshape(B, self.value_channels, N)                  # (B, Dv, N)

        # Aggregate keys with values once: sum_n phi(k_n) v_n^T -> (B, Dk, Dv).
        kv_context = torch.bmm(k, v.transpose(1, 2))
        # Normaliser sum_n phi(k_n) -> (B, Dk, 1).
        k_sum = k.sum(dim=-1, keepdim=True)

        q_t = q.transpose(1, 2)                     # (B, N, Dk)
        numerator = torch.bmm(q_t, kv_context)      # (B, N, Dv)
        # Clamp guards the division. The feature map is positive, so this
        # only matters if every term underflows.
        denominator = torch.bmm(q_t, k_sum).clamp(min=1e-6)  # (B, N, 1)

        attended = (numerator / denominator).transpose(1, 2)  # (B, Dv, N)
        out = self.out_proj(attended.reshape(B, self.value_channels, H, W))
        return self.gamma * out + x
# --- Keep original SelfAttention for comparison if needed ---
class SelfAttention(nn.Module):
    """Standard (quadratic) SAGAN-style self-attention over spatial positions.

    Computes ``softmax(Q K^T * scale) V`` over the N = H * W positions and adds
    it residually as ``gamma * attn + x``. ``gamma`` starts at 0, so the block
    begins as the identity.

    Args:
        channels: input/output channel count C. Query/key width is
            ``max(1, C // 8)``.
    """

    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        # Floor at 1. For channels < 8, C // 8 == 0 would make the scale
        # 0 ** -0.5 raise ZeroDivisionError.
        key_channels = max(1, channels // 8)
        self.query = nn.Conv2d(channels, key_channels, kernel_size=1)
        self.key = nn.Conv2d(channels, key_channels, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # zero-init residual weight
        self.scale = key_channels ** -0.5

    def forward(self, x):
        """Apply attention to ``x`` of shape (B, C, H, W); returns the same shape."""
        B, C, H, W = x.shape
        N = H * W
        proj_query = self.query(x).view(B, -1, N).permute(0, 2, 1)  # (B, N, Ck)
        proj_key = self.key(x).view(B, -1, N)                       # (B, Ck, N)
        energy = torch.bmm(proj_query, proj_key) * self.scale       # (B, N, N)
        attention = F.softmax(energy, dim=-1)                       # rows sum to 1 over keys
        proj_value = self.value(x).view(B, -1, N)                   # (B, C, N)
        out = torch.bmm(proj_value, attention.transpose(-1, -2))    # (B, C, N)
        out = out.view(B, C, H, W)
        return self.gamma * out + x
# --- Modify UnetSkipConnectionBlock to use attention_type ---
class UnetSkipConnectionBlock(nn.Module):
    """One recursive level of the style-modulated U-Net generator.

    Per level:
      - Down path: ``[LeakyReLU] -> stride-2 Conv2d -> [attention] -> [StyleModulation]``.
      - The result goes to ``submodule`` (the next level down).
      - Up path: the submodule's concatenated output goes through
        ``ReLU -> ConvTranspose2d -> norm/Tanh -> [Dropout]``.
      - Intermediate levels then apply ``ResSkip``.
      - Finally the result is concatenated with this level's input ``x``.

    Every ``forward`` call returns ``(output, bottleneck_features_flat)``, where
    the second item is the flattened innermost features.

    Args:
        outer_nc: channels produced by this block's up path.
        inner_nc: channels produced by the down convolution.
        input_nc: channels of the input ``x`` (defaults to ``outer_nc``).
        submodule: the next-lower block (``None`` for the innermost block).
        norm_layer: normalisation class applied after the up convolution.
        layer: level index. Attention is built only when it is 4 or 6.
        embedding_dim: style-vector length for ``StyleModulation``.
        use_dropout: append ``Dropout(0.3)`` to the up path (never on the
            outermost block).
        self_attention: enable attention at the supported levels.
        blur: accepted for interface compatibility; not used by this block.
        outermost / innermost: position of this block in the U-Net.
        attention_type: 'linear', 'standard' or 'none'.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False,
                 attention_type='linear'):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        self.embedding_dim = embedding_dim
        # Convs get no bias only when a BatchNorm (which has its own shift) follows.
        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc

        # --- Down path pieces (normalisation happens inside StyleModulation) ---
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        nn.init.kaiming_normal_(downconv.weight, nonlinearity='leaky_relu')
        # In-place: this also overwrites the caller's ``x``, which forward()
        # later concatenates as the skip connection (as in the pix2pix U-Net).
        downrelu = nn.LeakyReLU(0.2, True)

        # --- Up path pieces ---
        uprelu = nn.ReLU(inplace=True)
        upnorm = norm_layer(outer_nc)  # normalises the upconv output

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=True)
            nn.init.kaiming_normal_(upconv.weight)
            # The raw image enters the first conv directly (no activation before it).
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
            self.attn_block = None
            self.style_mod = None
            self.res_skip = None
        elif innermost:
            # Nothing below to concatenate, so upconv only sees inner_nc channels.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            self.style_mod = StyleModulation(inner_nc, embedding_dim)
            self.res_skip = None
        else:  # intermediate blocks
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            self.style_mod = StyleModulation(inner_nc, embedding_dim)
            # Applied after upsampling, before the skip concatenation.
            self.res_skip = ResSkip(outer_nc)

        # Never append dropout to the outermost block, where it would act on
        # the final Tanh output image.
        if use_dropout and not outermost:
            self.up.add_module("dropout", nn.Dropout(0.3))

        self.submodule = submodule

        # --- Attention block (only at the selected levels) ---
        if self_attention and layer in [4, 6]:
            if attention_type == 'linear':
                print(f"INFO: Using Linear Attention in U-Net layer {layer} (inner_nc={inner_nc})")
                self.attn_block = LinearAttention(inner_nc)
            elif attention_type == 'standard':
                print(f"INFO: Using Standard Self-Attention in U-Net layer {layer} (inner_nc={inner_nc})")
                self.attn_block = SelfAttention(inner_nc)
            elif attention_type == 'none':
                print(f"INFO: No Attention in U-Net layer {layer}")
                self.attn_block = None
            else:
                print(f"WARNING: Invalid attention_type '{attention_type}' in U-Net layer {layer}. No attention applied.")
                self.attn_block = None
        else:
            self.attn_block = None

    def _process_submodule(self, encoded, style):
        """Run the next-lower level with the style.

        Returns:
            ``(sub_output, bottleneck_features_flat)``. If there is no
            submodule, returns ``encoded`` itself and its flattened copy.
        """
        if self.submodule:
            return self.submodule(encoded, style)
        bottleneck_features = encoded.view(encoded.shape[0], -1)
        return encoded, bottleneck_features

    def _interpolate_if_needed(self, decoded, x):
        """Bilinearly resize ``decoded`` to ``x``'s spatial size when they differ."""
        if decoded.shape[2:] != x.shape[2:]:
            return F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
        return decoded

    def forward(self, x, style):
        """Run this level (and, recursively, all levels below it).

        Args:
            x: input feature map / image for this level.
            style: style embedding. ``None`` skips style modulation.

        Returns:
            For the outermost block: ``(image, bottleneck_features_flat)``.
            For other blocks: ``(cat([x, decoded], 1), bottleneck_features_flat)``.
        """
        encoded = self.down(x)
        if self.attn_block is not None:
            encoded = self.attn_block(encoded)
        if self.style_mod is not None and style is not None:
            encoded = self.style_mod(encoded, style)

        if self.innermost:
            # Snapshot the bottleneck BEFORE self.up. Its in-place ReLU would
            # otherwise zero the negative entries of a view taken here, so the
            # features returned for the consistency loss would not be the
            # modulated bottleneck.
            bottleneck_features_flat = encoded.reshape(x.shape[0], -1).clone()
            decoded = self._interpolate_if_needed(self.up(encoded), x)
            return torch.cat([x, decoded], 1), bottleneck_features_flat

        sub_output_cat, bottleneck_features_flat = self._process_submodule(encoded, style)
        decoded = self._interpolate_if_needed(self.up(sub_output_cat), x)
        if self.res_skip is not None:
            decoded = self.res_skip(decoded)

        if self.outermost:
            return decoded, bottleneck_features_flat
        return torch.cat([x, decoded], 1), bottleneck_features_flat
# --- Modify UNetGenerator to accept attention_type ---
class UNetGenerator(nn.Module):
    """Style-conditioned U-Net generator built from nested ``UnetSkipConnectionBlock``s.

    Args:
        input_nc / output_nc: image channels in / out.
        num_downs: number of stride-2 downsamplings (U-Net depth).
        ngf: base number of generator filters.
        embedding_num: number of style labels in the embedding table.
        embedding_dim: style embedding size.
        norm_layer, use_dropout, self_attention, blur, attention_type:
            forwarded to the blocks (``use_dropout`` only to the middle blocks).
    """

    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, embedding_num=2, embedding_dim=128,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False, self_attention=False, blur=False,
                 attention_type='linear'):
        super(UNetGenerator, self).__init__()
        print(f"--- Initializing UNetGenerator with {num_downs} downs, Attention: {self_attention}, Type: {attention_type} ---")
        # Innermost block.
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur, innermost=True,
                                             attention_type=attention_type)
        # Middle ngf*8 blocks (present only when num_downs > 5).
        for index in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, layer=index + 2, embedding_dim=embedding_dim,
                                                 use_dropout=use_dropout, self_attention=self_attention, blur=blur,
                                                 attention_type=attention_type)
        # Upper blocks, progressively narrower towards the output.
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs - 3, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur,
                                             attention_type=attention_type)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs - 2, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur,
                                             attention_type=attention_type)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs - 1, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur,
                                             attention_type=attention_type)
        # Outermost block.
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur, outermost=True,
                                             attention_type=attention_type)
        # Lookup table mapping style labels to style vectors.
        self.embedder = nn.Embedding(embedding_num, embedding_dim)
        print(f"Style Embedder: {embedding_num} styles, dim={embedding_dim}")

    def _prepare_style(self, style_or_label):
        """Turn integer style labels into embeddings; pass float tensors through.

        Args:
            style_or_label: tensor of label indices (any integer dtype, e.g.
                int32 from numpy), or a floating-point tensor that is treated
                as a ready-made style embedding.

        Returns:
            The style embedding tensor.

        Raises:
            ValueError: if the input is None or a label is outside
                ``[0, embedder.num_embeddings)``.
            TypeError: if the input is not a tensor.
        """
        if style_or_label is None:
            raise ValueError("Style/label cannot be None for StyleModulation U-Net.")
        if not isinstance(style_or_label, torch.Tensor):
            raise TypeError(f"Unsupported type for style_or_label: {type(style_or_label)}")
        if style_or_label.is_floating_point():
            return style_or_label  # already an embedding
        # nn.Embedding needs int64 indices; any other integer dtype is a label too.
        labels = style_or_label.long()
        num_embeddings = self.embedder.num_embeddings
        if labels.numel() > 0:
            if labels.max() >= num_embeddings:
                raise ValueError(f"Label index {labels.max()} is out of bounds for embedding_num={num_embeddings}")
            if labels.min() < 0:
                raise ValueError(f"Label index {labels.min()} is out of bounds for embedding_num={num_embeddings}")
        return self.embedder(labels)

    def forward(self, x, style_or_label):
        """Generate an image. Returns ``(fake_B, bottleneck_features_flat)``."""
        style = self._prepare_style(style_or_label)
        fake_B, encoded_bottleneck = self.model(x, style)
        return fake_B, encoded_bottleneck

    def encode(self, x, style_or_label):
        """Return only the flattened bottleneck features for ``x``.

        Note: this runs the full generator forward pass and discards the image.
        """
        style = self._prepare_style(style_or_label)
        _, encoded_bottleneck = self.model(x, style)
        return encoded_bottleneck
# --- Keep Discriminator, CategoryLoss, PerceptualLoss as before ---
# Make sure imports (like T) are present
import torchvision.transforms as T # Example: Ensure T is imported if used in Discriminator
class Discriminator(nn.Module):
    """Conditional convolutional discriminator with global-pooled output heads.

    The input is a (source, target) image pair concatenated on the channel
    axis. It goes through four spectral-normalised 5x5 convs (stride 2, 2, 2,
    then 1) to ndf*8 channels, is averaged to one vector per sample, and feeds
    two linear heads: a real/fake logit and per-style category logits.

    Args:
        input_nc: channels of ONE image (the input has ``input_nc * 2``).
        embedding_num: number of style categories for the category head.
        ndf: base number of discriminator filters.
        norm_layer: normalisation after every conv except the first.
        blur: apply a random Gaussian blur to the input before the convs.
    """

    def __init__(self, input_nc, embedding_num=2, ndf=64, norm_layer=nn.BatchNorm2d, blur=False):
        super(Discriminator, self).__init__()
        use_bias = norm_layer != nn.BatchNorm2d
        kernel, pad = 5, 2

        def sn_conv(in_ch, out_ch, stride, bias=True):
            # Spectral-normalised conv with the shared kernel/padding settings.
            return nn.utils.spectral_norm(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad, bias=bias))

        layers = [sn_conv(input_nc * 2, ndf, stride=2), nn.LeakyReLU(0.2, True)]
        # Channel widths ndf, 2*ndf, 4*ndf, 8*ndf (multiplier capped at 8).
        widths = [ndf * min(2 ** n, 8) for n in range(4)]
        for in_ch, out_ch, stride in zip(widths[:-1], widths[1:], (2, 2, 1)):
            layers.extend([sn_conv(in_ch, out_ch, stride, bias=use_bias),
                           norm_layer(out_ch),
                           nn.LeakyReLU(0.2, True)])
        self.model = nn.Sequential(*layers)

        # Pool to 1x1 so the linear heads work for any input resolution.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        head_width = widths[-1]
        self.binary = nn.Linear(head_width, 1)
        self.category = nn.Linear(head_width, embedding_num)

        self.blur = blur
        if blur:
            self.gaussian_blur = T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))

    def forward(self, input_concat):
        """Return ``(binary_logits (B, 1), category_logits (B, embedding_num))``."""
        x = self.gaussian_blur(input_concat) if self.blur else input_concat
        pooled = self.global_pool(self.model(x))
        flat = pooled.view(x.shape[0], -1)
        return self.binary(flat), self.category(flat)
class CategoryLoss(nn.Module):
    """Style-classification loss: BCE-with-logits against one-hot label targets.

    Args:
        category_num: number of style categories (size of the one-hot vectors).
    """

    def __init__(self, category_num=2):
        super(CategoryLoss, self).__init__()
        # Row i of the identity matrix is the one-hot target for label i.
        self.register_buffer('identity', torch.eye(category_num))
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, category_logits, labels):
        """Return the mean BCE between ``category_logits`` and one-hot(``labels``)."""
        one_hot_targets = self.identity[labels]
        one_hot_targets = one_hot_targets.to(category_logits.device)
        return self.loss(category_logits, one_hot_targets)
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG16 feature maps of two images.

    Compares features after conv1_2, conv2_2, conv3_3 and conv4_3 (slice
    boundaries 4, 9, 16, 23 of ``vgg16().features``) and sums the four L1
    terms.

    Args:
        normalize_input: map inputs to ImageNet statistics before VGG. If the
            inputs look like [-1, 1] data they are first rescaled to [0, 1].
    """

    def __init__(self, normalize_input=True):
        super(PerceptualLoss, self).__init__()
        vgg = models.vgg16(weights=VGG16_Weights.DEFAULT).features
        self.slice1 = nn.Sequential(*list(vgg[:4]))     # up to conv1_2
        self.slice2 = nn.Sequential(*list(vgg[4:9]))    # up to conv2_2
        self.slice3 = nn.Sequential(*list(vgg[9:16]))   # up to conv3_3
        self.slice4 = nn.Sequential(*list(vgg[16:23]))  # up to conv4_3
        for param in self.parameters():
            param.requires_grad = False
        self.normalize_input = normalize_input
        if normalize_input:
            # ImageNet mean and std.
            self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    @staticmethod
    def _looks_like_neg_one_range(*tensors):
        """Heuristic: True if any tensor has values clearly outside [0, 1]."""
        return any(bool(t.min() < -0.1 or t.max() > 1.1) for t in tensors)

    def _normalize(self, x, from_neg_one=None):
        """Map ``x`` to ImageNet statistics (no-op if normalisation is disabled).

        Args:
            x: 3-channel image batch.
            from_neg_one: whether ``x`` is in [-1, 1]. ``None`` decides from
                ``x`` alone.
        """
        if not self.normalize_input:
            return x
        if from_neg_one is None:
            from_neg_one = self._looks_like_neg_one_range(x)
        if from_neg_one:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        return (x - self.mean) / self.std

    def _target_device(self, fallback):
        """Device the comparison runs on: the module's device when known."""
        if self.normalize_input:
            return self.mean.device
        # Without buffers (normalize_input=False) use the VGG weights' device.
        param = next(self.parameters(), None)
        return param.device if param is not None else fallback.device

    def forward(self, x, y):
        """Return the summed feature-space L1 loss between ``x`` and ``y``."""
        # Grayscale -> 3 channels for VGG; decided per tensor.
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        if y.shape[1] == 1:
            y = y.repeat(1, 3, 1, 1)
        device = self._target_device(x)
        x = x.to(device)
        y = y.to(device)
        # Decide the input range once for the pair. Deciding separately could
        # rescale only one of the two images and compare mismatched ranges.
        from_neg_one = self._looks_like_neg_one_range(x, y)
        x = self._normalize(x, from_neg_one)
        y = self._normalize(y, from_neg_one)

        loss = 0
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4):
            x, y = stage(x), stage(y)
            loss = loss + F.l1_loss(x, y)
        return loss
# --- Modify Zi2ZiModel to accept attention_type and set embedding_num=2 ---
class Zi2ZiModel:
# Default embedding_num=2, attention_type='linear'
def __init__(self, input_nc=1, embedding_num=2, embedding_dim=128, ngf=64, ndf=64,
             Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100,
             schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
             self_attention=False, attention_type='linear',
             weight_decay = 1e-5, beta1=0.5, g_blur=False, d_blur=False, epoch=40,
             gradient_clip=0.5, norm_type="instance"):
    """Store the training configuration, then build networks, optimizers and losses.

    Component construction is delegated to ``setup()``. Two AMP gradient
    scalers are created afterwards and enabled only on a CUDA device.
    """
    header_lines = (
        f"--- Initializing Zi2ZiModel ---",
        f"Styles: {embedding_num}, Dim: {embedding_dim}",
        f"Generator Attention: {self_attention}, Type: {attention_type}",
        f"Norm Type: {norm_type}",
    )
    for line in header_lines:
        print(line)

    # Generator architecture / attention configuration.
    self.self_attention = self_attention
    self.attention_type = attention_type
    self.norm_type = norm_type
    # Dropout in the generator only while training.
    self.use_dropout = True if is_training else False

    # Loss weights.
    self.Lconst_penalty = Lconst_penalty
    self.Lcategory_penalty = Lcategory_penalty
    self.L1_penalty = L1_penalty
    self.perceptual_weight = 10.0
    self.gradient_penalty_weight = 10.0

    # Schedule / bookkeeping. ``schedule`` is stored but not read in the visible code.
    self.epoch = epoch
    self.schedule = schedule
    self.save_dir = save_dir

    # Device: CUDA only when GPU ids were given and CUDA is actually available.
    self.gpu_ids = gpu_ids
    use_cuda = self.gpu_ids and torch.cuda.is_available()
    self.device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {self.device}")

    # Network sizes.
    self.input_nc = input_nc
    self.embedding_dim = embedding_dim
    self.embedding_num = embedding_num
    self.ngf = ngf
    self.ndf = ndf

    # Optimiser settings.
    self.lr = lr
    self.beta1 = beta1
    self.weight_decay = weight_decay

    self.is_training = is_training
    self.g_blur = g_blur
    self.d_blur = d_blur
    self.gradient_clip = gradient_clip

    self.setup()

    # Mixed-precision gradient scalers (no-ops on CPU).
    amp_enabled = self.device.type == 'cuda'
    self.scaler_G = torch.cuda.amp.GradScaler(enabled=amp_enabled)
    self.scaler_D = torch.cuda.amp.GradScaler(enabled=amp_enabled)
def setup(self):
    """Build generator, discriminator, AdamW optimizers, cosine LR schedulers and losses."""
    print("Setting up networks, optimizers, and losses...")
    generator_norms = {
        "batch": (nn.BatchNorm2d, "Using BatchNorm for Generator U-Net."),
        "instance": (nn.InstanceNorm2d, "Using InstanceNorm for Generator U-Net."),
    }
    selected = generator_norms.get(self.norm_type.lower())
    if selected is None:
        raise ValueError(f"Unsupported norm_type for Generator: {self.norm_type}")
    norm_layer_g, norm_message = selected
    print(norm_message)

    # The discriminator always uses BatchNorm.
    norm_layer_d = nn.BatchNorm2d
    print("Using BatchNorm for Discriminator.")

    # Generator: single-channel in, single-channel out (output_nc == input_nc).
    self.netG = UNetGenerator(
        input_nc=self.input_nc,
        output_nc=self.input_nc,
        ngf=self.ngf,
        use_dropout=self.use_dropout,
        embedding_num=self.embedding_num,
        embedding_dim=self.embedding_dim,
        self_attention=self.self_attention,
        attention_type=self.attention_type,
        blur=self.g_blur,
        norm_layer=norm_layer_g,
    ).to(self.device)

    # Discriminator: sees the (A, B) pair concatenated on channels.
    self.netD = Discriminator(
        input_nc=self.input_nc,
        embedding_num=self.embedding_num,
        ndf=self.ndf,
        blur=self.d_blur,
        norm_layer=norm_layer_d,
    ).to(self.device)

    # init_net is intentionally not called here.
    print("Skipping init_net call (ensure it's defined and used if needed).")

    adamw_kwargs = dict(lr=self.lr, betas=(self.beta1, 0.999), weight_decay=self.weight_decay)
    self.optimizer_G = torch.optim.AdamW(self.netG.parameters(), **adamw_kwargs)
    self.optimizer_D = torch.optim.AdamW(self.netD.parameters(), **adamw_kwargs)
    print(f"Optimizers: AdamW (lr={self.lr}, beta1={self.beta1}, wd={self.weight_decay})")

    # Cosine annealing over ``epoch`` scheduler steps, down to 1% of the initial LR.
    eta_min = self.lr * 0.01
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR
    self.scheduler_G = cosine(self.optimizer_G, T_max=self.epoch, eta_min=eta_min)
    self.scheduler_D = cosine(self.optimizer_D, T_max=self.epoch, eta_min=eta_min)
    print(f"Schedulers: CosineAnnealingLR (T_max={self.epoch}, eta_min={eta_min})")

    self.criterion_L1 = nn.L1Loss().to(self.device)
    self.criterion_MSE = nn.MSELoss().to(self.device)            # consistency loss
    self.criterion_Category = CategoryLoss(self.embedding_num).to(self.device)
    self.criterion_Perceptual = PerceptualLoss().to(self.device)
    self.criterion_FeatureMatch = nn.L1Loss().to(self.device)    # feature matching
    print("Loss functions initialized.")

    self.set_train_eval_mode()
def set_train_eval_mode(self):
    """Put ``netG`` and ``netD`` in train mode if ``self.is_training`` else eval mode."""
    training = bool(self.is_training)
    self.netG.train(training)
    self.netD.train(training)
    print("Model set to TRAIN mode." if training else "Model set to EVAL mode.")
def set_input(self, data):
    """Move one batch to ``self.device`` and store it on the model.

    Args:
        data: mapping with keys 'label' (style labels), 'A' (source images)
            and 'B' (target images).
    """
    moved = [data[key].to(self.device) for key in ('label', 'A', 'B')]
    self.labels, self.real_A, self.real_B = moved
def forward(self):
    """Generate ``fake_B`` from ``real_A`` and collect bottleneck encodings.

    Sets ``self.fake_B`` and ``self.encoded_real_A`` from ``netG``, then
    ``self.encoded_fake_B`` by re-encoding the generated image with the same
    labels.
    """
    generated, real_A_bottleneck = self.netG(self.real_A, self.labels)
    self.fake_B = generated
    self.encoded_real_A = real_A_bottleneck
    self.encoded_fake_B = self.netG.encode(generated, self.labels)
def compute_feature_matching_loss(self, real_features_list, fake_features_list):
    """L1 feature-matching loss on the discriminator's pre-pooling features.

    Args:
        real_features_list: the real (A, B) pair concatenated on channels, as
            built by ``backward_G``. ``None`` rebuilds it from
            ``self.real_A`` / ``self.real_B``.
        fake_features_list: the (A, fake_B) pair. ``None`` rebuilds it from
            ``self.real_A`` / ``self.fake_B``.

    Returns:
        L1 distance between ``self.netD.model`` features of the fake and real
        pairs, multiplied by 10.
    """
    # Use the pairs the caller already built instead of silently ignoring them.
    real_AB = real_features_list if real_features_list is not None else torch.cat([self.real_A, self.real_B], 1)
    fake_AB = fake_features_list if fake_features_list is not None else torch.cat([self.real_A, self.fake_B], 1)
    # Real features are only a target, so no autograd graph is needed for them.
    with torch.no_grad():
        real_fm = self.netD.model(real_AB)
    fake_fm = self.netD.model(fake_AB)
    fm_loss = self.criterion_FeatureMatch(fake_fm, real_fm)
    return fm_loss * 10.0
def compute_gradient_penalty(self, real_samples, fake_samples):
    """WGAN-GP penalty: mean of (||dD/dx||_2 - 1)^2 at random real/fake interpolates.

    Args:
        real_samples: real discriminator inputs, shape (B, ...).
        fake_samples: fake discriminator inputs of the same shape.

    Returns:
        Scalar penalty tensor with a graph, so it can be backpropagated.
    """
    batch_size = real_samples.size(0)
    mix = torch.rand(batch_size, 1, 1, 1, device=self.device)
    blended = (mix * real_samples + (1 - mix) * fake_samples).requires_grad_(True)
    critic_scores = self.netD(blended)[0]  # only the real/fake logit is used
    seed_grad = torch.ones(critic_scores.size(), device=self.device, requires_grad=False)
    # create_graph=True so the penalty itself can be differentiated.
    (grads,) = torch.autograd.grad(
        outputs=critic_scores,
        inputs=blended,
        grad_outputs=seed_grad,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    per_sample = grads.view(grads.size(0), -1)
    return ((per_sample.norm(2, dim=1) - 1) ** 2).mean()
def backward_D(self):
    """Compute the discriminator loss and store it in ``self.loss_D``.

    Adversarial part: relativistic average standard GAN (RaSGAN). Real logits
    should exceed the mean fake logit, and fake logits should fall below the
    mean real logit. The weighted style-category loss (on real and fake pairs)
    and the weighted gradient penalty are added on top.

    Returns:
        ``(loss_D_adv, loss_D_category, gp)`` for logging.
    """
    real_AB = torch.cat([self.real_A, self.real_B], 1)
    # Detach so this loss does not backpropagate into the generator.
    fake_AB = torch.cat([self.real_A, self.fake_B.detach()], 1)

    real_D_logits, real_category_logits = self.netD(real_AB)
    fake_D_logits, fake_category_logits = self.netD(fake_AB)

    # RaSGAN: -E[log s(r - E f)] - E[log(1 - s(f - E r))].
    # The two terms must point in the same direction. The previous form,
    # logsigmoid(r - f) + logsigmoid(f - r), is symmetric in r and f: it is
    # minimised when r == f and never taught D to tell real from fake.
    loss_D_adv = -(F.logsigmoid(real_D_logits - fake_D_logits.mean()).mean() +
                   F.logsigmoid(real_D_logits.mean() - fake_D_logits).mean())

    # D should recognise the style of both real and fake pairs.
    loss_D_real_category = self.criterion_Category(real_category_logits, self.labels)
    loss_D_fake_category = self.criterion_Category(fake_category_logits, self.labels)
    loss_D_category = (loss_D_real_category + loss_D_fake_category) * 0.5

    gp = self.compute_gradient_penalty(real_AB, fake_AB)

    self.loss_D = loss_D_adv + \
                  loss_D_category * self.Lcategory_penalty + \
                  gp * self.gradient_penalty_weight
    return loss_D_adv, loss_D_category, gp
def backward_G(self):
    """Compute the generator loss and store it in ``self.loss_G``.

    Sum of:
      - the RaSGAN generator objective: fake logits should exceed the mean real
        logit, and real logits should fall below the mean fake logit;
      - the weighted style-category loss;
      - weighted L1 reconstruction;
      - weighted bottleneck consistency (MSE);
      - feature matching;
      - weighted perceptual loss.

    Returns:
        ``(loss_G_adv, loss_G_category, loss_G_L1, loss_G_const, loss_G_FM,
        loss_G_perceptual)`` for logging.
    """
    real_AB = torch.cat([self.real_A, self.real_B], 1).detach()
    fake_AB = torch.cat([self.real_A, self.fake_B], 1)
    real_D_logits, _ = self.netD(real_AB)
    fake_D_logits, fake_category_logits = self.netD(fake_AB)

    # RaSGAN generator loss: -E[log s(f - E r)] - E[log(1 - s(r - E f))].
    # A form symmetric in r and f (logsigmoid(f - r) + logsigmoid(r - f))
    # carries no adversarial signal, so the roles are made explicit here.
    loss_G_adv = -(F.logsigmoid(fake_D_logits - real_D_logits.mean()).mean() +
                   F.logsigmoid(fake_D_logits.mean() - real_D_logits).mean())

    # The generated image should be classified as the target style.
    loss_G_category = self.criterion_Category(fake_category_logits, self.labels)
    loss_G_L1 = self.criterion_L1(self.fake_B, self.real_B)
    # Bottleneck of the generated image should match that of the source.
    loss_G_const = self.criterion_MSE(self.encoded_fake_B, self.encoded_real_A.detach())
    loss_G_FM = self.compute_feature_matching_loss(real_AB, fake_AB)
    loss_G_perceptual = self.criterion_Perceptual(self.fake_B, self.real_B)

    self.loss_G = loss_G_adv + \
                  loss_G_category * self.Lcategory_penalty + \
                  loss_G_L1 * self.L1_penalty + \
                  loss_G_const * self.Lconst_penalty + \
                  loss_G_FM + \
                  loss_G_perceptual * self.perceptual_weight
    return loss_G_adv, loss_G_category, loss_G_L1, loss_G_const, loss_G_FM, loss_G_perceptual
def optimize_parameters(self, use_autocast=True):  # Enable autocast by default if cuda available
    """Calculate losses and gradients, then update netD and netG (in that order).

    Args:
        use_autocast: Request automatic mixed precision. Autocast is only
            enabled when ``self.device.type == 'cuda'``.

    Returns:
        dict[str, float]: Generator and discriminator loss components and
        totals, converted with ``.item()`` for logging.

    Raises:
        RuntimeError: If the discriminator or generator loss is NaN or infinite.
            The check runs before the backward pass and optimizer step, so a
            non-finite loss never reaches the weights. A disabled GradScaler's
            ``step`` would otherwise apply NaN/Inf gradients directly.
    """
    autocast_enabled = use_autocast and self.device.type == 'cuda'

    # === Update Discriminator ===
    self.set_requires_grad(self.netD, True)  # Enable grads for D
    self.optimizer_D.zero_grad()
    with torch.cuda.amp.autocast(enabled=autocast_enabled):
        # forward() runs inside each update, so every backward pass gets a fresh
        # graph built under the same autocast context.
        self.forward()
        loss_D_adv, loss_D_category, gp = self.backward_D()
    if not torch.isfinite(self.loss_D).all():
        print("ERROR: Discriminator loss is NaN/Inf. Stopping training.")
        raise RuntimeError("Discriminator loss is NaN/Inf")
    # Backward on the scaled loss, outside the autocast region.
    self.scaler_D.scale(self.loss_D).backward()
    # Unscale first so any (optional) clipping threshold applies to the true gradients.
    self.scaler_D.unscale_(self.optimizer_D)
    # Optional: clip gradients for D
    # torch.nn.utils.clip_grad_norm_(self.netD.parameters(), self.gradient_clip)
    self.scaler_D.step(self.optimizer_D)
    self.scaler_D.update()

    # === Update Generator ===
    self.set_requires_grad(self.netD, False)  # D is frozen while G is updated
    self.optimizer_G.zero_grad()
    with torch.cuda.amp.autocast(enabled=autocast_enabled):
        self.forward()
        (loss_G_adv, loss_G_category, loss_G_L1,
         loss_G_const, loss_G_FM, loss_G_perceptual) = self.backward_G()
    if not torch.isfinite(self.loss_G).all():
        print("ERROR: Generator loss is NaN/Inf. Stopping training.")
        raise RuntimeError("Generator loss is NaN/Inf")
    self.scaler_G.scale(self.loss_G).backward()
    self.scaler_G.unscale_(self.optimizer_G)
    torch.nn.utils.clip_grad_norm_(self.netG.parameters(), self.gradient_clip)
    self.scaler_G.step(self.optimizer_G)
    self.scaler_G.update()

    # Python floats for logging.
    return {
        'G_adv': loss_G_adv.item(),
        'G_category': loss_G_category.item(),
        'G_L1': loss_G_L1.item(),
        'G_const': loss_G_const.item(),
        'G_FM': loss_G_FM.item(),
        'G_perceptual': loss_G_perceptual.item(),
        'G_total': self.loss_G.item(),
        'D_adv': loss_D_adv.item(),
        'D_category': loss_D_category.item(),
        'D_gp': gp.item(),
        'D_total': self.loss_D.item()
    }
def set_requires_grad(self, nets, requires_grad=False):
    """Turn gradient tracking on or off for the parameters of one or more networks.

    Args:
        nets: A single network or a list of networks; ``None`` entries are skipped.
        requires_grad: Value written to every parameter's ``requires_grad`` flag.
            Passing ``False`` freezes a network so no gradients are computed for it.
    """
    targets = nets if isinstance(nets, list) else [nets]
    # Explicit None test: an empty nn.Sequential is falsy but must not be skipped.
    for module in filter(lambda candidate: candidate is not None, targets):
        for parameter in module.parameters():
            parameter.requires_grad = requires_grad
# Add methods for saving/loading models, inference etc. as needed
def save_networks(self, epoch_label):
    """Save the generator and discriminator weights to ``self.save_dir``.

    Writes ``<epoch_label>_net_G.pth`` and ``<epoch_label>_net_D.pth``, creating
    ``self.save_dir`` if needed. Tensors are copied to CPU so the files load on
    machines without a GPU. The live networks stay on their current device, so
    a failing ``torch.save`` cannot leave them stranded on the CPU.

    Args:
        epoch_label: Filename prefix, e.g. an epoch/step number or ``'latest'``.
    """
    # Local imports: os does not appear among this module's top-level imports.
    import os
    from collections import OrderedDict

    def _cpu_state_dict(net):
        # Copy tensors to CPU while keeping the state_dict's ``_metadata``
        # (per-module version info that ``load_state_dict`` consults).
        state = net.state_dict()
        cpu_state = OrderedDict((key, value.cpu()) for key, value in state.items())
        metadata = getattr(state, '_metadata', None)
        if metadata is not None:
            cpu_state._metadata = metadata
        return cpu_state

    os.makedirs(self.save_dir, exist_ok=True)
    for net_name, net in (('G', self.netG), ('D', self.netD)):
        save_path = os.path.join(self.save_dir, f'{epoch_label}_net_{net_name}.pth')
        torch.save(_cpu_state_dict(net), save_path)
    print(f"Saved networks for epoch {epoch_label} to {self.save_dir}")
def load_networks(self, epoch_label):
    """Load generator and discriminator weights from ``self.save_dir``.

    Looks for ``<epoch_label>_net_G.pth`` and ``<epoch_label>_net_D.pth``.
    If a file is missing, a warning is printed and that network keeps its
    current weights.

    Args:
        epoch_label: Filename prefix used when the checkpoint was saved.

    Returns:
        bool: True if both networks were loaded, False if either file was
        missing. Callers may ignore the result.
    """
    # Local import: os does not appear among this module's top-level imports.
    import os

    loaded_all = True
    for net_name, net, loaded_word, missing_word in (
        ('G', self.netG, 'generator', 'Generator'),
        ('D', self.netD, 'discriminator', 'Discriminator'),
    ):
        load_path = os.path.join(self.save_dir, f'{epoch_label}_net_{net_name}.pth')
        if not os.path.exists(load_path):
            print(f"WARNING: {missing_word} weights not found at {load_path}")
            loaded_all = False
            continue
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        state_dict = torch.load(load_path, map_location=self.device)
        net.load_state_dict(state_dict)
        print(f"Loaded {loaded_word} weights from {load_path}")
    return loaded_all
Key Changes and Considerations:

- `LinearAttention` class: Implemented a linear attention mechanism. It uses 1×1 convolutions for the Q, K, V projections and ELU+1 feature maps, which avoids computing the full N^2 attention matrix. Gamma is initialized to 0 for the residual connection.
- `UnetSkipConnectionBlock`:
  - Added an `attention_type` parameter to `__init__`.
  - Conditionally creates `self.attn_block` as either `LinearAttention`, `SelfAttention`, or `None`, based on the `self_attention` flag, the `layer` number, and `attention_type`.
  - Calls `self.attn_block(encoded)`, if it exists, right after the downsampling block (`self.down`) and before style modulation (`self.style_mod`).
  - Cleaned up norm layer placement slightly (removed the redundant `downnorm`, since `StyleModulation` already includes `InstanceNorm`).
- `UNetGenerator`:
  - Added `attention_type` to `__init__` and passed it down to `UnetSkipConnectionBlock`. It defaults to `'linear'`.
- `Discriminator`:
  - Changed `global_pool` to `AdaptiveAvgPool2d((1, 1))` for robustness to feature-map size changes before the linear layers, and adjusted the `final_features` calculation accordingly.
  - Corrected the `input_nc` usage in the first conv layer. It should be `input_nc * 2`, because the input is the concatenated A+B pair.
- `PerceptualLoss`: Added optional input normalization using ImageNet statistics, assuming input images are in `[0, 1]` or `[-1, 1]`.
- `Zi2ZiModel`:
  - Added an `attention_type` parameter.
  - Set `embedding_num=2` (make sure this matches your data labels 0 and 1).
  - Passed `attention_type` down to `UNetGenerator`.
  - Refactored `optimize_parameters` slightly for clarity, using the `torch.cuda.amp.autocast` context manager.
  - Refined loss function naming (`criterion_L1`, etc.).
  - Added basic `save_networks`/`load_networks` methods (these require `import os`).
  - Updated the loss calculation logic slightly (`backward_D`, `backward_G`) for clarity and consistency (e.g., detaching where needed).
  - Made `compute_feature_matching_loss` simpler (matching features from `netD.model`) and added a common weighting factor.
When you initialize `Zi2ZiModel`, you can now control whether attention is used (`self_attention=True`) and which type is used (`attention_type='linear'` or `'standard'`). For the requested efficient version, use `self_attention=True` and `attention_type='linear'`.
Python
import os
import math
import time
import random
import argparse # Need argparse to define args
import torch
from torch.utils.data import DataLoader
# Assume these are defined in other files relative to trainer.py
# from model import Zi2ZiModel # Import the updated model class
# from dataset import DatasetFromObj # Import your dataset class
# from utils import ensure_dir, setup_google_drive_service, clear_google_drive_trash # Import utility functions
# --- Mock/Placeholder definitions for missing components ---
# Replace these with your actual imports/definitions
class Zi2ZiModel:  # Placeholder based on previous definitions
    """Lightweight stand-in for the real Zi2ZiModel, used to smoke-test ``train``.

    - Accepts the real model's constructor arguments; most are only stored or ignored.
    - Exposes real Adam optimizers and CosineAnnealingLR schedulers, so the
      training loop's ``scheduler.step()`` / ``get_last_lr()`` calls work.
    - Returns random losses in the same dictionary format as the real
      ``optimize_parameters``.
    """

    def __init__(self, input_nc=1, embedding_num=2, embedding_dim=128, ngf=64, ndf=64,
                 Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100, perceptual_weight=10.0,
                 gradient_penalty_weight=10.0, schedule=10, lr=0.001, gpu_ids=None, save_dir='.',
                 is_training=True, self_attention=False, attention_type='linear',
                 weight_decay=1e-5, beta1=0.5, g_blur=False, d_blur=False, epoch=40,
                 gradient_clip=0.5, norm_type="instance"):
        self.device = torch.device("cuda" if gpu_ids and torch.cuda.is_available() else "cpu")
        self.save_dir = save_dir
        self.epoch = epoch
        self.lr = lr
        self.beta1 = beta1
        self.weight_decay = weight_decay
        self.embedding_num = embedding_num
        self.gradient_clip = gradient_clip
        self.Lconst_penalty = Lconst_penalty
        self.Lcategory_penalty = Lcategory_penalty
        self.L1_penalty = L1_penalty
        self.perceptual_weight = perceptual_weight
        self.gradient_penalty_weight = gradient_penalty_weight
        self.is_training = is_training
        print(f"Placeholder Zi2ZiModel initialized on {self.device} with attention={self_attention}, type={attention_type}, embed_num={embedding_num}")
        # The mock networks must own at least one parameter: torch.optim optimizers
        # raise ValueError("optimizer got an empty parameter list") for a bare nn.Module.
        self.netG = torch.nn.Linear(1, 1)
        self.netD = torch.nn.Linear(1, 1)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=lr)
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=lr)
        eta_min = lr * 0.01
        self.scheduler_G = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_G, T_max=epoch, eta_min=eta_min)
        self.scheduler_D = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_D, T_max=epoch, eta_min=eta_min)
        self.loss_D = torch.tensor(0.0)  # Mock loss attributes
        self.loss_G = torch.tensor(0.0)
        self.labels = self.real_A = self.real_B = None

    def print_networks(self, verbose):
        print(f"Mock print_networks({verbose})")

    def load_networks(self, epoch_label):
        print(f"Mock load_networks({epoch_label})")

    def save_networks(self, step_label):
        print(f"Mock save_networks({step_label})")

    def set_input(self, labels, real_A=None, real_B=None):
        """Store one batch.

        Accepts either ``set_input(labels, real_A, real_B)`` or a single dict
        with keys ``'label'``, ``'A'`` and ``'B'``.
        """
        if isinstance(labels, dict):
            labels, real_A, real_B = labels['label'], labels['A'], labels['B']
        self.labels, self.real_A, self.real_B = labels, real_A, real_B
        print("Mock set_input()")

    def optimize_parameters(self, use_autocast):
        """Return random losses in the real model's logging-dictionary format."""
        print(f"Mock optimize_parameters(use_autocast={use_autocast})")
        self.loss_D = torch.tensor(random.random())
        self.loss_G = torch.tensor(random.random() * 2)
        return {
            'G_adv': random.random(),
            'G_category': random.random(),
            'G_L1': random.random() * 10,
            'G_const': random.random(),
            'G_FM': random.random(),
            'G_perceptual': random.random(),
            'G_total': self.loss_G.item(),
            'D_adv': random.random(),
            'D_category': random.random(),
            'D_gp': random.random() * 0.1,
            'D_total': self.loss_D.item()
        }

    def setup(self):  # Mock method
        pass
class DatasetFromObj:  # Placeholder
    """Synthetic stand-in for the real pickled-object dataset.

    Holds 100 samples shaped ``(label, image_B, image_A)``:

    - ``label`` is a one-element tensor alternating 0/1.
    - Both images are ``torch.randn(input_nc, 64, 64)`` tensors.

    ``path`` is only recorded and printed, never read.
    """

    def __init__(self, path, input_nc=1):
        self.path = path
        self.input_nc = input_nc
        samples = []
        for idx in range(100):  # Mock 100 samples
            label = torch.tensor([idx % 2])          # label (0 or 1)
            target = torch.randn(input_nc, 64, 64)   # image B (target font)
            source = torch.randn(input_nc, 64, 64)   # image A (source font)
            samples.append((label, target, source))
        self.data = samples
        print(f"Mock DatasetFromObj initialized from {path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]
def ensure_dir(path):
    """Create *path* (and any missing parents) if it does not exist, then log the call."""
    os.makedirs(path, exist_ok=True)
    print(f"Mock ensure_dir({path})")


def setup_google_drive_service():
    """Mock Google Drive setup: logs the call and provides no service object."""
    print("Mock setup_google_drive_service()")
    return None


def clear_google_drive_trash(service):
    """Mock Google Drive trash cleanup: ignores *service* and only logs the call."""
    print("Mock clear_google_drive_trash()")
# --- End Mock Definitions ---
def _prune_old_checkpoints(checkpoint_dir, current_step, every, start_after):
    """Delete local G/D checkpoint files saved at steps before *current_step*.

    ``train`` saves only at steps that are multiples of *every* and
    ``>= start_after``, and steps count from 1. Candidate labels therefore start
    at the first such multiple, not at *start_after* itself; otherwise no label
    would match a saved file when *start_after* is not a multiple of *every*.
    """
    # Integer ceiling of max(start_after, 1) / every, scaled back to a step label.
    first_saved = -(-max(start_after, 1) // every) * every
    for step_label in range(first_saved, current_step, every):
        for net_type in ["D", "G"]:
            old_filepath = os.path.join(checkpoint_dir, f"{step_label}_net_{net_type}.pth")
            if os.path.isfile(old_filepath):
                try:
                    os.remove(old_filepath)
                    print(f" Removed old checkpoint: {old_filepath}")
                except OSError as e:
                    print(f" Error removing {old_filepath}: {e}")


def train(args):
    """Main training function: run the full training loop configured by *args*.

    Steps, in order:

    1. Seed the RNGs and build the model and dataloader.
    2. Optionally load weights from ``args.resume``.
    3. Train for ``args.epoch`` epochs, stepping both LR schedulers once per epoch.
    4. Save step-labelled checkpoints every ``args.checkpoint_steps`` steps, once
       ``args.checkpoint_steps_after`` is reached.
    5. Save a ``'latest'`` checkpoint at the end.
    """
    print("--- Starting Training ---")
    print(f"Arguments: {args}")

    # --- Initialization ---
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    if args.gpu_ids:
        torch.cuda.manual_seed_all(args.random_seed)  # Seed all GPUs if used

    # Directories
    data_dir = args.data_dir or os.path.join(args.experiment_dir, "data")
    checkpoint_dir = args.checkpoint_dir or os.path.join(args.experiment_dir, "checkpoint")
    ensure_dir(checkpoint_dir)
    print(f"Data Directory: {data_dir}")
    print(f"Checkpoint Directory: {checkpoint_dir}")

    # Google Drive (optional)
    drive_service = setup_google_drive_service() if args.checkpoint_only_last else None

    # --- Model Initialization ---
    model = Zi2ZiModel(
        input_nc=args.input_nc,
        embedding_num=args.embedding_num,  # Should be 2 for Noto->Zen task
        embedding_dim=args.embedding_dim,
        ngf=args.ngf,
        ndf=args.ndf,
        Lconst_penalty=args.Lconst_penalty,
        Lcategory_penalty=args.Lcategory_penalty,
        L1_penalty=args.L1_penalty,
        perceptual_weight=args.perceptual_weight,
        gradient_penalty_weight=args.gradient_penalty_weight,
        save_dir=checkpoint_dir,
        gpu_ids=args.gpu_ids,
        is_training=True,
        self_attention=args.self_attention,
        attention_type=args.attention_type,
        epoch=args.epoch,  # Total epochs, used as the scheduler period
        # g_blur=args.g_blur,  # Pass blur flags if used by model
        # d_blur=args.d_blur,
        lr=args.lr,
        beta1=args.beta1,
        weight_decay=args.weight_decay,
        gradient_clip=args.gradient_clip,
        norm_type=args.norm_type
    )
    # model.print_networks(True)  # Optional: Print network details

    # --- Resume Training ---
    start_epoch = 0
    global_steps = 0
    if args.resume:
        try:
            model.load_networks(args.resume)
            # Step/epoch counters restart at 0; optimizer and scheduler state are not restored.
            print(f"Resumed model from step/epoch: {args.resume}")
        except Exception as e:  # top-level boundary: log and continue from scratch
            print(f"Could not resume from {args.resume}. Starting from scratch. Error: {e}")

    # --- Data Loading ---
    train_dataset = DatasetFromObj(os.path.join(data_dir, 'train.obj'), input_nc=args.input_nc)
    dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                            num_workers=4, pin_memory=bool(args.gpu_ids))
    total_batches = len(dataloader)  # == ceil(len(dataset) / batch_size) since drop_last is False
    print(f"Dataset length: {len(train_dataset)}, Batches per epoch: {total_batches}")

    # --- Training Loop ---
    start_time = time.time()
    print(f"Starting training from epoch {start_epoch}...")
    for epoch in range(start_epoch, args.epoch):
        epoch_start_time = time.time()
        print(f"\n--- Epoch {epoch}/{args.epoch - 1} ---")
        for batch_id, batch_data in enumerate(dataloader):
            current_step_time = time.time()
            # DatasetFromObj yields (label, image_B, image_A). set_input takes
            # (labels, real_A=source, real_B=target) positionally.
            labels, image_B, image_A = batch_data
            model.set_input(labels, image_A, image_B)
            losses = model.optimize_parameters(args.use_autocast)  # dict of floats
            global_steps += 1

            # --- Logging ---
            if batch_id % 100 == 0:
                elapsed_batch_time = time.time() - current_step_time
                total_elapsed_time = time.time() - start_time
                print(
                    f"Epoch: [{epoch:3d}/{args.epoch-1}], Batch: [{batch_id:4d}/{total_batches-1}] "
                    f"Step: {global_steps} | Time/Batch: {elapsed_batch_time:.2f}s | Total Time: {total_elapsed_time:.0f}s\n"
                    f" Losses: D_Total={losses['D_total']:.4f}, G_Total={losses['G_total']:.4f}\n"
                    f" G: [Adv={losses['G_adv']:.4f}, Cat={losses['G_category']:.4f}, L1={losses['G_L1']:.4f}, "
                    f"Const={losses['G_const']:.4f}, FM={losses['G_FM']:.4f}, Perc={losses['G_perceptual']:.4f}]\n"
                    f" D: [Adv={losses['D_adv']:.4f}, Cat={losses['D_category']:.4f}, GP={losses['D_gp']:.4f}]"
                )

            # --- Checkpointing ---
            if global_steps % args.checkpoint_steps == 0:
                if global_steps >= args.checkpoint_steps_after:
                    print(f"\nSaving checkpoint at step {global_steps}...")
                    model.save_networks(global_steps)  # Save with step label
                    print("Checkpoint saved.")
                    # Only delete local copies when no Drive backup is in use.
                    if args.checkpoint_only_last and drive_service is None:
                        _prune_old_checkpoints(checkpoint_dir, global_steps,
                                               args.checkpoint_steps, args.checkpoint_steps_after)
                    # NOTE: GDrive cleanup might need refinement depending on API usage;
                    # clear_google_drive_trash(drive_service) may be too aggressive.
                else:
                    print(f"\nCheckpoint step {global_steps} reached, but saving starts after step {args.checkpoint_steps_after}.")

        # --- End of Epoch ---
        epoch_time = time.time() - epoch_start_time
        print(f"\n--- End of Epoch {epoch} --- Time: {epoch_time:.2f}s ---")
        # CosineAnnealingLR is stepped once per epoch (T_max == args.epoch).
        model.scheduler_G.step()
        model.scheduler_D.step()
        print(f"LR Scheduler stepped. Current LR G: {model.scheduler_G.get_last_lr()[0]:.6f}, LR D: {model.scheduler_D.get_last_lr()[0]:.6f}")
        # Optional per-epoch save:
        # if (epoch + 1) % args.save_epoch_freq == 0:
        #     model.save_networks(f'epoch_{epoch+1}')

    # --- End of Training ---
    print("\n--- Training Finished ---")
    print("Saving final model...")
    model.save_networks('latest')
    print("Final model saved.")
    total_training_time = time.time() - start_time
    print(f"Total Training Time: {total_training_time:.2f} seconds")
# --- Argparse Setup (Example) ---
if __name__ == '__main__':
    # sys.exit instead of the `exit` builtin: the latter is injected by the
    # `site` module and is not guaranteed to exist (e.g. `python -S`, frozen apps).
    import sys

    parser = argparse.ArgumentParser(description='Zi2Zi Training Script')
    # --- Paths ---
    parser.add_argument('--experiment_dir', required=True, help='Base directory for the experiment')
    parser.add_argument('--data_dir', default=None, help='Path to training data directory (default: <experiment_dir>/data)')
    parser.add_argument('--checkpoint_dir', default=None, help='Directory to save checkpoints (default: <experiment_dir>/checkpoint)')
    # --- Model ---
    parser.add_argument('--input_nc', type=int, default=1, help='Input image channels (e.g., 1 for grayscale)')
    parser.add_argument('--embedding_num', type=int, default=2, help='Number of distinct styles/fonts (MUST be 2 for Noto->Zen)')
    parser.add_argument('--embedding_dim', type=int, default=128, help='Dimension of style embedding vector')
    parser.add_argument('--ngf', type=int, default=64, help='Num generator filters in the first conv layer')
    parser.add_argument('--ndf', type=int, default=64, help='Num discriminator filters in the first conv layer')
    parser.add_argument('--norm_type', type=str, default='instance', choices=['instance', 'batch'], help='Normalization layer type for generator [instance | batch]')
    parser.add_argument('--self_attention', action='store_true', help='Enable attention mechanism in generator (at specific layers)')
    parser.add_argument('--attention_type', type=str, default='linear', choices=['linear', 'standard', 'none'], help='Type of attention if enabled [linear | standard | none]')
    # --- Training Params ---
    parser.add_argument('--epoch', type=int, default=100, help='Total number of training epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='Training batch size')
    parser.add_argument('--lr', type=float, default=0.0002, help='Initial learning rate for AdamW')
    parser.add_argument('--beta1', type=float, default=0.5, help='AdamW optimizer beta1 parameter')
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='AdamW optimizer weight decay')
    parser.add_argument('--use_autocast', action='store_true', help='Use Automatic Mixed Precision (AMP) for training')
    parser.add_argument('--gradient_clip', type=float, default=1.0, help='Max gradient norm for generator update')
    # --- Loss Weights ---
    parser.add_argument('--Lconst_penalty', type=float, default=15.0, help='Weight for embedding consistency loss')
    parser.add_argument('--Lcategory_penalty', type=float, default=1.0, help='Weight for category classification loss (D and G)')
    parser.add_argument('--L1_penalty', type=float, default=100.0, help='Weight for L1 pixel loss between fake_B and real_B')
    parser.add_argument('--perceptual_weight', type=float, default=10.0, help='Weight for VGG perceptual loss')
    parser.add_argument('--gradient_penalty_weight', type=float, default=10.0, help='Weight for WGAN-GP gradient penalty')
    # --- Checkpointing & Resume ---
    parser.add_argument('--resume', type=str, default=None, help='Checkpoint label (e.g., <step> or latest) to resume from')
    parser.add_argument('--checkpoint_steps', type=int, default=10000, help='Frequency (in steps) to save checkpoints')
    parser.add_argument('--checkpoint_steps_after', type=int, default=0, help='Start saving checkpoints only after this many steps')
    parser.add_argument('--checkpoint_only_last', action='store_true', help='If set, delete older checkpoints locally (use with caution or GDrive)')
    # --- Misc ---
    parser.add_argument('--random_seed', type=int, default=42, help='Seed for random number generators')
    parser.add_argument('--gpu_ids', type=str, default='0', help='Comma-separated list of GPU IDs to use (e.g., "0" or "0,1"). Use "" or "-1" for CPU.')
    # Blur flags, if the model implements them:
    # parser.add_argument('--g_blur', action='store_true', help='Apply blur in Generator (if implemented)')
    # parser.add_argument('--d_blur', action='store_true', help='Apply blur in Discriminator')
    args = parser.parse_args()

    # --- Process GPU IDs ---
    # "0, 1" is accepted as well as "0,1"; unparsable or empty input falls back to CPU.
    if args.gpu_ids and args.gpu_ids != '-1':
        tokens = [token.strip() for token in args.gpu_ids.split(',')]
        args.gpu_ids = [int(token) for token in tokens if token.isdigit()] or None
    else:
        args.gpu_ids = None  # Use CPU

    if args.gpu_ids and torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_ids[0])  # Set default device
        print(f"Using GPUs: {args.gpu_ids}")
    elif args.gpu_ids and not torch.cuda.is_available():
        print(f"WARNING: GPUs {args.gpu_ids} requested but CUDA not available! Using CPU.")
        args.gpu_ids = None
    else:
        print("Using CPU.")

    # --- Sanity Checks ---
    if args.embedding_num != 2:
        print(f"ERROR: This task (Noto Sans -> Zen Maru Gothic) requires --embedding_num=2. Found {args.embedding_num}.",
              file=sys.stderr)
        sys.exit(1)
    if args.checkpoint_steps <= 0:
        # train() computes global_steps % checkpoint_steps, which needs a positive value.
        print(f"ERROR: --checkpoint_steps must be a positive integer. Found {args.checkpoint_steps}.",
              file=sys.stderr)
        sys.exit(1)
    if not args.self_attention and args.attention_type != 'none':
        print(f"INFO: --attention_type is set to '{args.attention_type}' but --self_attention is not specified. Attention will NOT be used.")
        args.attention_type = 'none'  # Ensure consistency
    if args.self_attention and args.attention_type == 'none':
        print(f"WARNING: --self_attention is specified but --attention_type is 'none'. Attention will NOT be used.")
        args.self_attention = False  # Ensure consistency

    # --- Run Training ---
    train(args)
主要修改:

- 模型初始化(`Zi2ZiModel`):
  - 確認傳遞了所有必要的參數,特別是 `embedding_num`(應設為 2)、`attention_type`、`perceptual_weight`、`gradient_penalty_weight` 等。
  - 移除了看起來未使用的 `residual_block` 參數。
  - 明確設定 `is_training=True`。
- 損失處理:
  - `model.optimize_parameters()` 現在返回一個字典 `losses`。
  - 更新了 `print` 語句,使用 `losses` 字典中的鍵來顯示各個損失項(`losses['D_total']`、`losses['G_total']`、`losses['G_L1']` 等)。
- 學習率調度器(Scheduler):
  - 移除了對不存在的 `model.update_lr()` 的調用。
  - 在每個 epoch 結束時調用 `model.scheduler_G.step()` 和 `model.scheduler_D.step()` 來更新學習率(這是 `CosineAnnealingLR` 的正確用法)。
- 輸入設置(`set_input`):
  - 稍微修改了 `set_input` 的調用方式,改為傳遞一個字典 `model_input_data`。這需要對應修改 `Zi2ZiModel.set_input` 方法以接收字典;或者保持原樣,並確保 `batch[0]`、`batch[2]`、`batch[1]` 的順序正確對應 `labels`、`real_A`、`real_B`。
  - 注意:上面的 `Zi2ZiModel` 佔位符沒有修改 `set_input`,仍假設之前的 `set_input(labels, real_A, real_B)` 有效,且數據加載順序是 `labels=batch[0]`、`real_A=batch[2]`、`real_B=batch[1]`。如果你的 `DatasetFromObj` 返回順序不同,請調整這裡。
- 參數解析(`argparse`):
  - 添加了 `attention_type`、`perceptual_weight`、`gradient_penalty_weight` 等參數。
  - 為 `embedding_num` 設置了 `default=2`,並在運行前添加了檢查,確保其為 2。
  - 改進了 `gpu_ids` 的處理。
  - 添加了關於 `self_attention` 和 `attention_type` 組合使用的檢查/提示。
- 代碼結構和註釋:
  - 添加了更多註釋和打印語句,以提高可讀性。
  - 包含了基本的 `argparse` 設置和 `if __name__ == '__main__':` 結構。
  - 使用了佔位符(Mock/Placeholder)來表示未提供的類(`Zi2ZiModel`、`DatasetFromObj`、`utils`),你需要將它們替換為實際的導入。
  - 修正了 `Discriminator` 中第一個卷積層的輸入通道數。
  - 修正了 `PerceptualLoss` 中的 VGG 輸入歸一化邏輯。
  - 改進了 checkpoint 清理邏輯,使其更安全。
你需要將這個 `train` 函數與更新後的 `Zi2ZiModel`、數據集類和工具函數放在同一個項目結構中,並確保導入路徑正確。之後即可通過命令行傳遞參數來運行訓練,例如:
Bash
python trainer.py --experiment_dir ./experiments/noto_zen_linear \
--gpu_ids 0 \
--batch_size 8 \
--epoch 150 \
--lr 0.0002 \
--self_attention \
--attention_type linear \
--L1_penalty 100 \
--perceptual_weight 10 \
--Lconst_penalty 15 \
--use_autocast