Implement Linear Attention Module as a replacement for the standard SelfAttention

I don't know why Gemini 2.5 Pro keeps answering in English even though I asked it in Chinese.

Okay, let’s replace the standard SelfAttention with a more efficient LinearAttention implementation. We’ll keep the StyleModulation (FiLM-like) block introduced previously.

1. Define the LinearAttention Module

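This version avoids the O(N^2) time and memory cost of standard self-attention, where N = H * W is the number of pixels. It uses the elu(x) + 1 feature map φ in place of the softmax kernel. Because φ is positive, the weights φ(q_i)·φ(k_j) stay positive. Because matrix products are associative, φ(K)^T V can be summed over all N positions once and reused for every query, so the N x N attention matrix is never formed.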
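Here is a minimal pure-Python sketch of why reassociating the product works. It is illustrative only: the helper names are made up and no PyTorch is needed. It computes the same elu(x) + 1 attention twice, once with the explicit N x N weight matrix and once as φ(Q)(φ(K)^T V), and the two results agree.

```python
import math


def phi(x):
    # elu(x) + 1 with alpha = 1: x + 1 for x > 0, exp(x) otherwise, so always positive
    return x + 1.0 if x > 0 else math.exp(x)


def matmul(a, b):
    # a: n x m, b: m x p, both given as lists of rows
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def quadratic_attention(q, k, v):
    """Explicit (N, N) weight matrix: O(N^2 * d) time and O(N^2) memory."""
    fq = [[phi(x) for x in row] for row in q]
    fk = [[phi(x) for x in row] for row in k]
    sim = matmul(fq, transpose(fk))                           # (N, N)
    return [[sum(w * v[j][c] for j, w in enumerate(row)) / sum(row)
             for c in range(len(v[0]))] for row in sim]


def linear_attention(q, k, v):
    """Same result via phi(Q) (phi(K)^T V): O(N * d_k * d_v), no (N, N) matrix."""
    fq = [[phi(x) for x in row] for row in q]
    fk = [[phi(x) for x in row] for row in k]
    kv = matmul(transpose(fk), v)                             # (d_k, d_v), summed over N once
    z = [sum(col) for col in zip(*fk)]                        # (d_k,), phi(K)^T 1
    num = matmul(fq, kv)                                      # (N, d_v)
    den = [sum(a * b for a, b in zip(row, z)) for row in fq]  # (N,)
    return [[x / d for x in row] for row, d in zip(num, den)]


# N = 3 positions, d_k = 2, d_v = 3
q = [[0.5, -1.0], [2.0, 0.1], [-0.3, 0.7]]
k = [[1.0, 0.2], [-0.5, 0.3], [0.0, -2.0]]
v = [[1.0, 0.0, 2.0], [0.5, 1.5, -1.0], [3.0, -1.0, 0.0]]
print(quadratic_attention(q, k, v))
print(linear_attention(q, k, v))  # same values, up to floating-point rounding
```

The LinearAttention module below maps onto this sketch as follows: kv_context plays the role of kv, z_norm_factor plays the role of z, and the per-position division happens in its normalization step.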

Python

import torch
import torch.nn as nn
import torch.nn.functional as F
# torchvision is needed by PerceptualLoss (VGG16) and by the Discriminator's optional blur;
# the VGG16_Weights enum requires torchvision 0.13 or newer
from torchvision import models
from torchvision.models import VGG16_Weights
import torchvision.transforms as T

# --- Keep StyleModulation and ResSkip as defined before ---
class StyleModulation(nn.Module):
    def __init__(self, channels, style_dim):
        super(StyleModulation, self).__init__()
        # InstanceNorm is often used before style modulation
        self.norm = nn.InstanceNorm2d(channels, affine=False)
        # A single linear layer maps the style embedding to per-channel gamma and beta
        self.style_fc = nn.Linear(style_dim, channels * 2)
        # Small random weights and zero bias, so the predicted gamma/beta start small
        nn.init.normal_(self.style_fc.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.style_fc.bias)


    def forward(self, x, style):
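        # InstanceNorm2d raises a ValueError on a 1x1 map. With 256x256 inputs and num_downs=8
        # the innermost feature map is 1x1, so normalization is skipped there (FiLM still applies)
        if x.shape[2] * x.shape[3] > 1:
            x = self.norm(x)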
        style_params = self.style_fc(style)  # (B, 2C)
        gamma, beta = style_params.chunk(2, dim=1)  # each (B, C)
        # Reshape gamma and beta for broadcasting: (B, C, 1, 1)
        gamma = gamma.unsqueeze(2).unsqueeze(3)
        beta = beta.unsqueeze(2).unsqueeze(3)
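        # Apply FiLM around the identity: (1 + gamma) * x + beta. With small gamma at init the
        # block starts as plain normalization instead of shrinking the features towards zero
        return (1 + gamma) * x + beta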

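# --- Illustration only: a hypothetical plain-Python helper, not used by the model ---
# StyleModulation normalizes each channel of each sample over its H*W positions
# (InstanceNorm2d, affine=False, eps=1e-5, biased variance). It then applies one
# (gamma, beta) pair per channel, predicted from the style embedding. This sketch does
# the same for a single sample given as a list of channels, each a flat list of pixels,
# using the plain FiLM form gamma * x_norm + beta.
import math


def instance_norm_film(channels, gammas, betas, eps=1e-5):
    out = []
    for pixels, gamma, beta in zip(channels, gammas, betas):
        mean = sum(pixels) / len(pixels)
        var = sum((p - mean) ** 2 for p in pixels) / len(pixels)  # biased variance
        out.append([gamma * (p - mean) / math.sqrt(var + eps) + beta for p in pixels])
    return out

# Example (2 channels, 2x2 map each):
#   instance_norm_film([[1.0, 2.0, 3.0, 4.0], [10.0, 10.0, 10.0, 12.0]], [2.0, 0.5], [0.0, -1.0])
# Each output channel has mean beta and standard deviation of about |gamma|. On a 1x1 map
# the normalized value is always 0, so only beta would survive. PyTorch's InstanceNorm2d
# rejects such input with a ValueError instead.
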
class ResSkip(nn.Module):
    def __init__(self, channels):
        super(ResSkip, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(8, channels)  # GroupNorm is less batch-dependent
        self.relu = nn.SiLU(inplace=True)      # SiLU activation

    def forward(self, x):
        identity = x
        out = self.conv(x)
        out = self.norm(out)
        out = self.relu(out)
        return identity + out # Residual connection

# --- Efficient Linear Attention ---
class LinearAttention(nn.Module):
    """
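    Linear (kernelized) self-attention.
    Replaces the softmax kernel exp(q . k / sqrt(d)) with phi(q) . phi(k), phi(x) = elu(x) + 1:
        out_i = phi(q_i)^T (sum_j phi(k_j) v_j^T) / (phi(q_i)^T sum_j phi(k_j))
    Complexity: O(N * d_k * d_v) with N = H*W, d_k = key_channels, d_v = channels,
    versus O(N^2 * (d_k + d_v)) time and O(N^2) memory for softmax attention.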
    """
    def __init__(self, channels, key_channels=None):
        super(LinearAttention, self).__init__()
        # Default key_channels to channels // 8 like original SelfAttention if None
        self.key_channels = key_channels if key_channels is not None else channels // 8
        self.value_channels = channels # Keep value channels same as input/output

        # Use Conv1x1 for Q, K, V projections
        self.query = nn.Conv2d(channels, self.key_channels, kernel_size=1)
        self.key = nn.Conv2d(channels, self.key_channels, kernel_size=1)
        self.value = nn.Conv2d(channels, self.value_channels, kernel_size=1)

        # Output projection (here, it's identity since value_channels == channels)
        self.out_proj = nn.Identity()

        # Gamma for residual connection, initialize to zero to start with identity
        self.gamma = nn.Parameter(torch.zeros(1))

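    # Feature map for linearizing attention: elu(x) + 1 > 0 keeps every attention weight positive.
    # Defined as a method rather than a lambda attribute so the whole module can be pickled (torch.save(model)).
    @staticmethod
    def feature_map(x):
        return F.elu(x) + 1.0

    def forward(self, x):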
        B, C, H, W = x.shape
        N = H * W # Sequence length

        # 1. Project inputs Q, K, V
        q = self.query(x) # (B, key_channels, H, W)
        k = self.key(x)   # (B, key_channels, H, W)
        v = self.value(x) # (B, value_channels, H, W)

        # 2. Apply feature map and reshape for matrix multiplication
        # Reshape to (B, channels, N)
        q_mapped = self.feature_map(q).view(B, self.key_channels, N)
        k_mapped = self.feature_map(k).view(B, self.key_channels, N)
        v_reshaped = v.view(B, self.value_channels, N)

        # 3. Aggregate keys and values once, without an N x N matrix:
        #    kv_context = phi(K) V^T, summed over all N positions
        #    k_mapped: (B, key_c, N), v_reshaped.transpose(-1, -2): (B, N, value_c)
        kv_context = torch.bmm(k_mapped, v_reshaped.transpose(-1, -2))  # (B, key_c, value_c)

        # 4. Normalization vector z = phi(K) 1 (sum of phi(k_j) over the N positions)
        z_norm_factor = k_mapped.sum(dim=-1, keepdim=True)  # (B, key_c, 1)

        # 5. Numerator for every query position: phi(Q)^T kv_context
        qkv_aggregated = torch.bmm(q_mapped.transpose(-1, -2), kv_context)  # (B, N, value_c)

        # 6. Denominator for every query position: phi(Q)^T z
        qz_normalization = torch.bmm(q_mapped.transpose(-1, -2), z_norm_factor)  # (B, N, 1)

        # 7. Normalize (the clamp guards against division by ~0),
        #    then transpose back to (B, value_c, N)
        normalized_out = (qkv_aggregated / qz_normalization.clamp(min=1e-6)).transpose(-1, -2)


        # 8. Reshape back to image format (B, value_channels, H, W)
        out = normalized_out.view(B, self.value_channels, H, W)

        # 9. Optional output projection (currently Identity)
        out = self.out_proj(out) # (B, channels, H, W)

        # 10. Add residual connection weighted by gamma
        # Initialize gamma=0 -> starts as identity mapping (output = x)
        return self.gamma * out + x

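# --- Illustration only: a hypothetical helper, not used by the model ---
# Rough multiply-accumulate (MAC) counts for the attention core of SelfAttention and
# LinearAttention, per sample, with d_k = channels // 8 and d_v = channels as both
# classes use. The 1x1 Q/K/V projections (identical in both) and element-wise ops are
# left out.
def attention_macs(height, width, channels):
    n = height * width
    d_k, d_v = channels // 8, channels
    standard = n * n * d_k + n * n * d_v  # Q K^T, then the (N, N) weights applied to V
    linear = (n * d_k * d_v               # phi(K) V^T
              + n * d_k                   # phi(K) 1
              + n * d_k * d_v             # phi(Q)^T (phi(K) V^T)
              + n * d_k)                  # phi(Q)^T (phi(K) 1)
    return {"N": n, "standard": standard, "linear": linear, "attn_matrix_elems": n * n}

# Example: a 32x32 map with 256 channels gives N = 1024. Standard attention costs
# 301,989,888 MACs and builds a 1,048,576-element attention matrix per sample; linear
# attention costs 16,842,752 MACs (about 18x fewer). Doubling H and W multiplies the
# standard cost by 16 but the linear cost only by 4.
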

# --- Keep original SelfAttention for comparison if needed ---
class SelfAttention(nn.Module):
    # (Definition from previous code)
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1)) # Initialize gamma to 0
        self.scale = (channels // 8) ** -0.5

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W
        proj_query = self.query(x).view(B, -1, N).permute(0, 2, 1) # (B, N, C')
        proj_key = self.key(x).view(B, -1, N)                     # (B, C', N)
        energy = torch.bmm(proj_query, proj_key) * self.scale     # (B, N, N)
        attention = F.softmax(energy, dim=-1)                     # (B, N, N)
        proj_value = self.value(x).view(B, -1, N)                 # (B, C, N)

        # out[:, :, i] = sum_j proj_value[:, :, j] * attention[:, i, j]
        out = torch.bmm(proj_value, attention.transpose(-1, -2))  # (B, C, N)
        out = out.view(B, C, H, W)

        return self.gamma * out + x


# --- Modify UnetSkipConnectionBlock to use attention_type ---
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False,
                 attention_type='linear'): # Added attention_type ('standard', 'linear', or 'none')
        super(UnetSkipConnectionBlock, self).__init__()

        self.outermost = outermost
        self.innermost = innermost
        self.embedding_dim = embedding_dim
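        # BatchNorm2d has its own learnable shift, so a conv bias would be redundant;
        # nn.InstanceNorm2d defaults to affine=False, so keep the bias in that case
        use_bias = norm_layer != nn.BatchNorm2d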

        if input_nc is None:
            input_nc = outer_nc

        # --- Downsampling Layers ---
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        nn.init.kaiming_normal_(downconv.weight, nonlinearity='leaky_relu')
        downrelu = nn.LeakyReLU(0.2, True)
        # No separate down-path norm: StyleModulation applies InstanceNorm before FiLM

        # --- Upsampling Layers ---
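        # Not in-place: in the innermost block, self.up receives encoded_modulated, whose
        # flattened view is returned as the bottleneck features; an in-place ReLU would overwrite them
        uprelu = nn.ReLU(inplace=False)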
        upnorm = norm_layer(outer_nc) # Norm for the output of upconv

        # --- Block Structure Definition ---
        if outermost:
            # Final layer keeps its bias (no norm follows it)
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=True)
            nn.init.kaiming_normal_(upconv.weight)
            # Matches the original: the outermost down path is just the conv, with no activation before it
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
            self.attn_block = None
            self.style_mod = None
            self.res_skip = None
        elif innermost:
            # Upconv input is just inner_nc because no submodule output to concat
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            self.down = nn.Sequential(downrelu, downconv) # No norm here, style_mod has it
            self.up = nn.Sequential(uprelu, upconv, upnorm) # Add norm after upconv
            # Style modulation applied after down block + attention
            self.style_mod = StyleModulation(inner_nc, embedding_dim)
            self.res_skip = None # No skip connection applied before cat in innermost return
        else: # Intermediate blocks
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            self.down = nn.Sequential(downrelu, downconv) # No norm here, style_mod has it
            self.up = nn.Sequential(uprelu, upconv, upnorm) # Add norm after upconv
            # Style modulation applied after down block + attention
            self.style_mod = StyleModulation(inner_nc, embedding_dim)
            # ResSkip applied *after* upsampling, before final cat
            self.res_skip = ResSkip(outer_nc)
            if use_dropout:
                 # Add dropout module to self.up sequence AFTER upnorm
                 self.up.add_module("dropout", nn.Dropout(0.3))


        self.submodule = submodule

        # --- Attention Block Instantiation ---
        if self_attention and layer in [4, 6]: # Apply attention only at specified layers
            if attention_type == 'linear':
                print(f"INFO: Using Linear Attention in U-Net layer {layer} (inner_nc={inner_nc})")
                self.attn_block = LinearAttention(inner_nc)
            elif attention_type == 'standard':
                print(f"INFO: Using Standard Self-Attention in U-Net layer {layer} (inner_nc={inner_nc})")
                self.attn_block = SelfAttention(inner_nc)
            elif attention_type == 'none':
                print(f"INFO: No Attention in U-Net layer {layer}")
                self.attn_block = None
            else:
                print(f"WARNING: Invalid attention_type '{attention_type}' in U-Net layer {layer}. No attention applied.")
                self.attn_block = None
        else:
             self.attn_block = None # No attention if flag is False or not the specified layer


    def _process_submodule(self, encoded, style):
        # Pass the style embedding down recursively
        if self.submodule is not None:
            # Submodule forward returns (concatenated_output, bottleneck_features)
            return self.submodule(encoded, style)
        # Fallback for a block without a submodule. forward() never gets here, because the
        # innermost block returns before calling this method; it mirrors the caller's
        # expected (sub_output, bottleneck_features_flat) pair.
        return encoded, encoded.view(encoded.shape[0], -1)

    def _interpolate_if_needed(self, decoded, x):
        # Resize decoded to x's spatial size when they differ (e.g. odd input sizes)
        if decoded.shape[2:] != x.shape[2:]:
             return F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
        return decoded

    def forward(self, x, style): # Style embedding is now required for modulation
        # --- Down Path ---
        encoded_pre_attn = self.down(x)

        # --- Attention (Optional) ---
        if self.attn_block:
            encoded_post_attn = self.attn_block(encoded_pre_attn)
        else:
            encoded_post_attn = encoded_pre_attn

        # --- Style Modulation (Applies InstanceNorm first) ---
        if self.style_mod and style is not None:
            encoded_modulated = self.style_mod(encoded_post_attn, style)
        else:
            # Outermost block doesn't have style_mod
            # Or handle cases where style is None defensively (though should always be provided now)
            encoded_modulated = encoded_post_attn

        # --- Process Submodule or Innermost Logic ---
        if self.innermost:
             # encoded_modulated contains the final bottleneck features
             bottleneck_features_flat = encoded_modulated.view(x.shape[0], -1)

             # Apply up path blocks: ReLU, ConvTranspose, Norm, [Dropout]
             decoded = self.up(encoded_modulated)
             # Interpolate
             decoded = self._interpolate_if_needed(decoded, x)

             # Innermost returns concatenated input and decoded output for the layer above it,
             # plus the flattened bottleneck features for the consistency loss.
             return torch.cat([x, decoded], 1), bottleneck_features_flat

        else: # Intermediate or Outermost blocks
             # Recursively call submodule
             # Pass the modulated features down
             sub_output_cat, bottleneck_features_flat = self._process_submodule(encoded_modulated, style)

             # --- Up Path ---
             # Apply up path blocks: ReLU, ConvTranspose, Norm, [Dropout]
             # Input to self.up is the concatenated output from the layer below (sub_output_cat)
             decoded = self.up(sub_output_cat)

             # Interpolate
             decoded = self._interpolate_if_needed(decoded, x)

             # Apply ResSkip (if not outermost)
             if self.res_skip:
                 decoded = self.res_skip(decoded)

             # --- Output ---
             if self.outermost:
                 # Final output image and the bottleneck features from the bottom
                 return decoded, bottleneck_features_flat
             else:
                 # Concatenate skip connection (original x) with decoded features
                 # Return this for the layer above, along with bottleneck features
                 return torch.cat([x, decoded], 1), bottleneck_features_flat

# --- Modify UNetGenerator to accept attention_type ---
class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, embedding_num=2, embedding_dim=128,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False, self_attention=False, blur=False,
                 attention_type='linear'): # Added attention_type, default to linear
        super(UNetGenerator, self).__init__()
        print(f"--- Initializing UNetGenerator with {num_downs} downs, Attention: {self_attention}, Type: {attention_type} ---")

        # --- Build the U-Net with FiLM and selectable Attention ---
        # Innermost block
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur, innermost=True,
                                             attention_type=attention_type) # Pass params

        # Middle layers (if num_downs > 5)
        for index in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, layer=index+2, embedding_dim=embedding_dim,
                                                 use_dropout=use_dropout, self_attention=self_attention, blur=blur,
                                                 attention_type=attention_type) # Pass params

        # Upper layers (closer to output)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs-3, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur,
                                             attention_type=attention_type) # Pass params
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs-2, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur,
                                             attention_type=attention_type) # Pass params
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs-1, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur,
                                             attention_type=attention_type) # Pass params

        # Outermost block
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs, embedding_dim=embedding_dim,
                                             self_attention=self_attention, blur=blur, outermost=True,
                                             attention_type=attention_type) # Pass params

        # One learned embedding per style label (embedding_num of them)
        self.embedder = nn.Embedding(embedding_num, embedding_dim)
        print(f"Style Embedder: {embedding_num} styles, dim={embedding_dim}")

    def _prepare_style(self, style_or_label):
        """Map integer style labels to embeddings; float tensors are treated as precomputed embeddings."""
        if style_or_label is None:
            raise ValueError("Style/label cannot be None for StyleModulation U-Net.")
        if not isinstance(style_or_label, torch.Tensor):
            raise TypeError(f"Unsupported type for style_or_label: {type(style_or_label)}")
        if not torch.is_floating_point(style_or_label):
            # Integer labels (int32 or int64), assumed to be on the same device as the embedder
            labels = style_or_label.long()
            if labels.max() >= self.embedder.num_embeddings:
                raise ValueError(f"Label index {labels.max().item()} is out of bounds for embedding_num={self.embedder.num_embeddings}")
            return self.embedder(labels)
        return style_or_label  # already an embedding of shape (B, embedding_dim)


    def forward(self, x, style_or_label): # Style is required
        style = self._prepare_style(style_or_label)
        # self.model returns (output_image, bottleneck_features_flat)
        fake_B, encoded_bottleneck = self.model(x, style)
        return fake_B, encoded_bottleneck

    def encode(self, x, style_or_label): # Style is required
        style = self._prepare_style(style_or_label)
        # Call the model and extract the second return value (bottleneck features)
        _, encoded_bottleneck = self.model(x, style)
        return encoded_bottleneck

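# --- Illustration only: a hypothetical helper, not used by the model ---
# Lists the U-Net levels that receive attention. UNetGenerator numbers its blocks from the
# innermost (layer=1) to the outermost (layer=num_downs). Each down conv (k=4, s=2, p=1)
# halves the spatial size. UnetSkipConnectionBlock only attaches attention (on its
# inner_nc-channel down-path output) when layer is 4 or 6. Assumes num_downs >= 5.
def unet_attention_sites(image_size=256, num_downs=8, ngf=64, attention_layers=(4, 6)):
    # inner_nc of each block, outermost first, as built in UNetGenerator.__init__
    inner_channels = [ngf, ngf * 2, ngf * 4] + [ngf * 8] * (num_downs - 3)
    sites, size = [], image_size
    for i, channels in enumerate(inner_channels):
        size //= 2                      # output size of this block's down conv
        layer = num_downs - i
        if layer in attention_layers:
            sites.append((layer, size, size, channels))
    return sites

# unet_attention_sites() -> [(6, 32, 32, 256), (4, 8, 8, 512)]: attention runs on a 32x32x256
# map (N = 1024) and on an 8x8x512 map (N = 64). The innermost map is 256 / 2**8 = 1x1.
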
# --- Keep Discriminator, CategoryLoss, PerceptualLoss as before ---

class Discriminator(nn.Module):
    def __init__(self, input_nc, embedding_num=2, ndf=64, norm_layer=nn.BatchNorm2d, blur=False): # Default embedding_num=2
        super(Discriminator, self).__init__()

        use_bias = norm_layer != nn.BatchNorm2d
        kw = 5
        padw = 2

        sequence = [
            # Input is cat(real_A, fake_B) or cat(real_A, real_B), so input_nc * 2 channels
            nn.utils.spectral_norm(nn.Conv2d(input_nc * 2, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        # Typically 3 downsampling layers in PatchGAN-like D
        for n in range(1, 3): # n = 1, 2
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8) # nf_mult = 2, 4
            sequence += [
                nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
            # Output channels: ndf*2, ndf*4

        # Final convolution before output layers
        nf_mult_prev = nf_mult # 4
        nf_mult = min(2 ** 3, 8) # 8
        sequence += [
            nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        # Output channels: ndf*8

        self.model = nn.Sequential(*sequence)

        # Output layers: Use AdaptiveAvgPool to handle variable input sizes before FC layers
        # Output size from model depends on input image size and strides.
        # Example: 256 -> 128 -> 64 -> 32. Final feature map size = 32x32
        # If using AdaptiveAvgPool((1, 1)), final_features = ndf * nf_mult
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        final_features = ndf * nf_mult # ndf * 8

        self.binary = nn.Linear(final_features, 1)
        self.category = nn.Linear(final_features, embedding_num) # Use embedding_num here

        self.blur = blur
        if blur:
            # Ensure T is imported: import torchvision.transforms as T
            self.gaussian_blur = T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)) # Use range for sigma

    def forward(self, input_concat): # Input is already concatenated real_A + real_B/fake_B
        if self.blur:
            input_concat = self.gaussian_blur(input_concat)

        features = self.model(input_concat)
        pooled_features = self.global_pool(features)
        flat_features = pooled_features.view(input_concat.shape[0], -1) # Flatten

        binary_logits = self.binary(flat_features)
        category_logits = self.category(flat_features)
        return binary_logits, category_logits

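# --- Illustration only: hypothetical helpers, not used by the model ---
# Every Conv2d in Discriminator uses kernel 5 and padding 2, so each one maps a side of
# length s to (s + 2*2 - 5) // stride + 1. The three stride-2 convs and the final stride-1
# conv give 256 -> 128 -> 64 -> 32 -> 32. AdaptiveAvgPool2d((1, 1)) then turns any such map
# into ndf * 8 features, which is why the Linear heads do not depend on the image size.
def conv_out_size(size, kernel=5, stride=2, padding=2):
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_feature_size(image_size):
    size = image_size
    for stride in (2, 2, 2, 1):
        size = conv_out_size(size, stride=stride)
    return size

# discriminator_feature_size(256) -> 32; discriminator_feature_size(100) -> 13
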
class CategoryLoss(nn.Module):
    # (Definition from previous code)
    def __init__(self, category_num=2): # Default to 2
        super(CategoryLoss, self).__init__()
        # Use identity matrix for one-hot encoding target
        self.register_buffer('identity', torch.eye(category_num))
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, category_logits, labels):
        # Ensure labels are on the same device as identity matrix and logits
        target = self.identity[labels].to(category_logits.device)
        return self.loss(category_logits, target)

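# --- Illustration only: a hypothetical plain-Python reference, not used by the model ---
# CategoryLoss turns each label into a one-hot row (identity[labels]) and applies
# BCEWithLogitsLoss with its default reduction='mean'. That is the per-element binary
# cross-entropy averaged over batch and classes. This uses the numerically stable form
# max(x, 0) - x * t + log(1 + exp(-|x|)).
import math


def category_loss_reference(logits, labels):
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        for c, x in enumerate(row):
            t = 1.0 if c == label else 0.0   # one-hot target
            total += max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count

# category_loss_reference([[2.0, -1.0]], [0]) equals (log(1 + e^-2) + log(1 + e^-1)) / 2, about 0.2201
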

class PerceptualLoss(nn.Module):
    # (Definition from previous code - remember VGG input normalization)
    def __init__(self, normalize_input=True): # Add flag for normalization
        super(PerceptualLoss, self).__init__()
        vgg = models.vgg16(weights=VGG16_Weights.DEFAULT).features
        self.slice1 = nn.Sequential(*list(vgg[:4]))    # up to relu1_2
        self.slice2 = nn.Sequential(*list(vgg[4:9]))   # up to relu2_2
        self.slice3 = nn.Sequential(*list(vgg[9:16]))  # up to relu3_3
        self.slice4 = nn.Sequential(*list(vgg[16:23])) # up to relu4_3

        for param in self.parameters():
            param.requires_grad = False

        self.normalize_input = normalize_input
        if normalize_input:
            # ImageNet mean and std
            self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def _normalize(self, x, from_minus_one):
        # Optionally map [-1, 1] inputs (e.g. Tanh generator output) to [0, 1],
        # then normalize with ImageNet statistics
        if not self.normalize_input:
            return x
        if from_minus_one:
            x = (x + 1.0) / 2.0
        return (x - self.mean) / self.std

    def forward(self, x, y):
        """Input x, y assumed to be generator output / target image"""
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
            y = y.repeat(1, 3, 1, 1)

        # Decide the input range once for both tensors (same heuristic as before), so x and y
        # are always mapped the same way. A per-tensor check could treat a low-contrast fake_B
        # as [0, 1] while treating real_B as [-1, 1]. Inputs are expected on this module's device.
        from_minus_one = bool(x.min() < -0.1 or y.min() < -0.1 or x.max() > 1.1 or y.max() > 1.1)
        x = self._normalize(x, from_minus_one)
        y = self._normalize(y, from_minus_one)

        # Extract features
        fx1, fy1 = self.slice1(x), self.slice1(y)
        fx2, fy2 = self.slice2(fx1), self.slice2(fy1)
        fx3, fy3 = self.slice3(fx2), self.slice3(fy2)
        fx4, fy4 = self.slice4(fx3), self.slice4(fy3)

        # L1 loss on feature maps
        loss = (
            F.l1_loss(fx1, fy1) +
            F.l1_loss(fx2, fy2) +
            F.l1_loss(fx3, fy3) +
            F.l1_loss(fx4, fy4)
        )
        return loss

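# --- Illustration only: a hypothetical plain-Python helper, not used by the model ---
# torchvision's vgg16().features follows the VGG-16 layout below: each conv is followed by
# a ReLU and each "M" is a max-pool, 31 modules in total. Rebuilding the names shows that
# vgg[:4], vgg[4:9], vgg[9:16] and vgg[16:23] end at relu1_2, relu2_2, relu3_3 and relu4_3.
def vgg16_feature_names():
    cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
           512, 512, 512, "M", 512, 512, 512, "M"]
    names, block, conv = [], 1, 1
    for v in cfg:
        if v == "M":
            names.append(f"pool{block}")
            block, conv = block + 1, 1
        else:
            names += [f"conv{block}_{conv}", f"relu{block}_{conv}"]
            conv += 1
    return names

# [vgg16_feature_names()[end - 1] for end in (4, 9, 16, 23)]
#   -> ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']
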

# --- Modify Zi2ZiModel to accept attention_type and set embedding_num=2 ---
class Zi2ZiModel:
    # Default embedding_num=2, attention_type='linear'
    def __init__(self, input_nc=1, embedding_num=2, embedding_dim=128, ngf=64, ndf=64,
                 Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100,
                 schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
                 self_attention=False, attention_type='linear', # Added attention_type
                 weight_decay = 1e-5, beta1=0.5, g_blur=False, d_blur=False, epoch=40,
                 gradient_clip=0.5, norm_type="instance"):

        print(f"--- Initializing Zi2ZiModel ---")
        print(f"Styles: {embedding_num}, Dim: {embedding_dim}")
        print(f"Generator Attention: {self_attention}, Type: {attention_type}")
        print(f"Norm Type: {norm_type}")
        # Store attention config
        self.self_attention = self_attention
        self.attention_type = attention_type

        self.norm_type = norm_type
        if is_training:
            self.use_dropout = True
        else:
            self.use_dropout = False

        # Loss weights
        self.Lconst_penalty = Lconst_penalty
        self.Lcategory_penalty = Lcategory_penalty
        self.L1_penalty = L1_penalty
        self.perceptual_weight = 10.0 # Make perceptual weight configurable?
        self.gradient_penalty_weight = 10.0 # Make GP weight configurable?


        self.epoch = epoch
        self.schedule = schedule # Unused? Check usage.
        self.save_dir = save_dir
        self.gpu_ids = gpu_ids
        self.device = torch.device("cuda" if self.gpu_ids and torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.input_nc = input_nc
        self.embedding_dim = embedding_dim
        self.embedding_num = embedding_num # Should be 2
        self.ngf = ngf
        self.ndf = ndf
        self.lr = lr
        self.beta1 = beta1
        self.weight_decay = weight_decay
        self.is_training = is_training
        # self.residual_block param was unused, removed.
        self.g_blur = g_blur
        self.d_blur = d_blur
        self.gradient_clip = gradient_clip

        # Setup components (networks, optimizers, losses)
        self.setup()

        # Grad scalers for mixed precision
        self.scaler_G = torch.cuda.amp.GradScaler(enabled=self.device.type == 'cuda')
        self.scaler_D = torch.cuda.amp.GradScaler(enabled=self.device.type == 'cuda')


    def setup(self):
        """Initializes networks, optimizers, and loss functions."""
        print("Setting up networks, optimizers, and losses...")
        if self.norm_type.lower() == "batch":
            # Note: BatchNorm might interfere with style modulation / instance-level tasks
            norm_layer_g = nn.BatchNorm2d
            print("Using BatchNorm for Generator U-Net.")
        elif self.norm_type.lower() == "instance":
            norm_layer_g = nn.InstanceNorm2d
            print("Using InstanceNorm for Generator U-Net.")
        else:
            raise ValueError(f"Unsupported norm_type for Generator: {self.norm_type}")

        # Use BatchNorm for Discriminator as per original code? Or InstanceNorm?
        # Often BatchNorm is fine for Discriminator.
        norm_layer_d = nn.BatchNorm2d
        print("Using BatchNorm for Discriminator.")


        # --- Initialize Generator ---
        self.netG = UNetGenerator(
            input_nc=self.input_nc,
            output_nc=self.input_nc, # Output is single channel like input
            ngf=self.ngf,
            use_dropout=self.use_dropout,
            embedding_num=self.embedding_num,
            embedding_dim=self.embedding_dim,
            self_attention=self.self_attention, # Pass flag
            attention_type=self.attention_type, # Pass type
            blur=self.g_blur,
            norm_layer=norm_layer_g # Pass chosen norm layer
        ).to(self.device)

        # --- Initialize Discriminator ---
        self.netD = Discriminator(
            input_nc=self.input_nc, # D sees concatenated input (A+B = 1+1=2 channels)
            embedding_num=self.embedding_num,
            ndf=self.ndf,
            blur=self.d_blur,
            norm_layer=norm_layer_d # Use specified norm for D
        ).to(self.device)

        # Weight Initialization (Assuming init_net exists and works)
        # Make sure init_net is defined somewhere
        # init_net(self.netG, gpu_ids=self.gpu_ids)
        # init_net(self.netD, gpu_ids=self.gpu_ids)
        print("Skipping init_net call (ensure it's defined and used if needed).")


        # --- Optimizers (AdamW is a good choice) ---
        self.optimizer_G = torch.optim.AdamW(self.netG.parameters(), lr=self.lr, betas=(self.beta1, 0.999), weight_decay=self.weight_decay)
        self.optimizer_D = torch.optim.AdamW(self.netD.parameters(), lr=self.lr, betas=(self.beta1, 0.999), weight_decay=self.weight_decay)
        print(f"Optimizers: AdamW (lr={self.lr}, beta1={self.beta1}, wd={self.weight_decay})")

        # --- Schedulers (Cosine Annealing) ---
        eta_min = self.lr * 0.01 # Example: anneal down to 1% of initial LR
        self.scheduler_G = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_G, T_max=self.epoch, eta_min=eta_min)
        self.scheduler_D = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_D, T_max=self.epoch, eta_min=eta_min)
        print(f"Schedulers: CosineAnnealingLR (T_max={self.epoch}, eta_min={eta_min})")


        # --- Loss Functions ---
        self.criterion_L1 = nn.L1Loss().to(self.device)
        self.criterion_MSE = nn.MSELoss().to(self.device) # For const_loss
        self.criterion_Category = CategoryLoss(self.embedding_num).to(self.device)
        self.criterion_Perceptual = PerceptualLoss().to(self.device) # Assumes VGG normalization inside
        self.criterion_FeatureMatch = nn.L1Loss().to(self.device) # For Feature Matching Loss

        print("Loss functions initialized.")

        # Set training/eval mode
        self.set_train_eval_mode()

    def set_train_eval_mode(self):
        if self.is_training:
            self.netG.train()
            self.netD.train()
            print("Model set to TRAIN mode.")
        else:
            self.netG.eval()
            self.netD.eval()
            print("Model set to EVAL mode.")

    def set_input(self, data):
        """Unpack input data from the dataloader and perform necessary pre-processing."""
        # Assuming data is a dictionary or tuple: {'label': labels, 'A': real_A, 'B': real_B}
        self.labels = data['label'].to(self.device)
        self.real_A = data['A'].to(self.device) # Input font image
        self.real_B = data['B'].to(self.device) # Target font image

    def forward(self):
        """Run forward pass; called by both optimize_parameters and inference."""
        # Generate fake image and get bottleneck features for consistency loss
        self.fake_B, self.encoded_real_A = self.netG(self.real_A, self.labels)
        # Encode the generated fake image using the same style label
        self.encoded_fake_B = self.netG.encode(self.fake_B, self.labels)

    def compute_feature_matching_loss(self, real_AB, fake_AB):
        """L1 loss between the discriminator's intermediate features for real and fake pairs.

        Uses self.netD.model (the feature extractor before the pooling/FC layers) instead of hooks.
        Real features are detached so that only the generator is pushed toward them.
        """
        real_fm = self.netD.model(real_AB)
        fake_fm = self.netD.model(fake_AB)
        fm_loss = self.criterion_FeatureMatch(fake_fm, real_fm.detach())
        return fm_loss * 10.0 # Fixed FM weight (common practice); consider exposing it as a hyperparameter

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """Calculates the gradient penalty loss for WGAN-GP"""
        # Random weight term for interpolation between real and fake samples
        alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=self.device)
        # Interpolate between detached real and fake samples so the result is a fresh leaf tensor
        interpolates = (alpha * real_samples.detach() + (1 - alpha) * fake_samples.detach()).requires_grad_(True)

        d_interpolates, _ = self.netD(interpolates) # Get discriminator output for interpolates

        # Gradients of D's output w.r.t. the interpolated inputs.
        # Under fp16 autocast these gradients are not loss-scaled and may underflow; the PyTorch AMP
        # "Gradient penalty" example scales the output with the GradScaler first and rescales the result.
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True, # Keep the graph so the penalty itself can be backpropagated
            retain_graph=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1) # Flatten gradients
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty


    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Real samples (pair real source image with real target image)
        real_AB = torch.cat([self.real_A, self.real_B], 1)
        # Fake samples (pair real source image with generated fake image)
        # Detach fake_B to prevent gradients flowing back to the generator here
        fake_AB = torch.cat([self.real_A, self.fake_B.detach()], 1)

        # Get discriminator predictions
        real_D_logits, real_category_logits = self.netD(real_AB)
        fake_D_logits, fake_category_logits = self.netD(fake_AB)

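        # --- Adversarial Loss (Relativistic Average Standard GAN, RaSGAN) ---
        # D pushes real logits above the mean fake logit and fake logits below the mean real logit.
        # log(1 - sigmoid(z)) == logsigmoid(-z), so the second term is logsigmoid(mean_real - fake).
        loss_D_adv = -(F.logsigmoid(real_D_logits - fake_D_logits.mean()).mean() +
                       F.logsigmoid(real_D_logits.mean() - fake_D_logits).mean())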


        # --- Category Loss ---
        # D should correctly classify the style of both real and fake images
        loss_D_real_category = self.criterion_Category(real_category_logits, self.labels)
        loss_D_fake_category = self.criterion_Category(fake_category_logits, self.labels)
        loss_D_category = (loss_D_real_category + loss_D_fake_category) * 0.5 # Average category loss

        # --- Gradient Penalty ---
        # WGAN-GP was designed for Wasserstein losses; with RaSGAN it acts as an extra
        # regularizer on D, so consider lowering gradient_penalty_weight or removing it.
        gp = self.compute_gradient_penalty(real_AB, fake_AB)

        # --- Total Discriminator Loss ---
        self.loss_D = loss_D_adv + \
                      loss_D_category * self.Lcategory_penalty + \
                      gp * self.gradient_penalty_weight

        # Return individual components for logging
        return loss_D_adv, loss_D_category, gp

    def backward_G(self):
        """Calculate GAN loss for the generator"""
        # --- Adversarial Loss ---
        # Generator tries to fool the discriminator
        real_AB = torch.cat([self.real_A, self.real_B], 1).detach() # Detach real pair
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        real_D_logits, _ = self.netD(real_AB)
        fake_D_logits, fake_category_logits = self.netD(fake_AB)

        # Relativistic Average Standard GAN loss for G (roles of real and fake swapped):
        # G pushes fake logits above the mean real logit and real logits below the mean fake logit.
        loss_G_adv = -(F.logsigmoid(fake_D_logits - real_D_logits.mean()).mean() +
                       F.logsigmoid(fake_D_logits.mean() - real_D_logits).mean())


        # --- Category Loss ---
        # Generator should produce images of the target style
        loss_G_category = self.criterion_Category(fake_category_logits, self.labels)

        # --- L1 Reconstruction Loss ---
        loss_G_L1 = self.criterion_L1(self.fake_B, self.real_B)

        # --- Consistency Loss ---
        # Compare bottleneck features of real source and fake target
        loss_G_const = self.criterion_MSE(self.encoded_fake_B, self.encoded_real_A.detach()) # Detach real encoding

        # --- Feature Matching Loss ---
        # Match intermediate features in Discriminator
        loss_G_FM = self.compute_feature_matching_loss(real_AB, fake_AB)

        # --- Perceptual Loss ---
        loss_G_perceptual = self.criterion_Perceptual(self.fake_B, self.real_B)

        # --- Total Generator Loss ---
        self.loss_G = loss_G_adv + \
                      loss_G_category * self.Lcategory_penalty + \
                      loss_G_L1 * self.L1_penalty + \
                      loss_G_const * self.Lconst_penalty + \
                      loss_G_FM + \
                      loss_G_perceptual * self.perceptual_weight

        # Return individual components for logging
        return loss_G_adv, loss_G_category, loss_G_L1, loss_G_const, loss_G_FM, loss_G_perceptual


    def optimize_parameters(self, use_autocast=True): # Enable autocast by default if cuda available
        """Calculate losses, gradients, and update network weights."""
        autocast_enabled = use_autocast and self.device.type == 'cuda'

        # === Forward Pass ===
        # forward() runs inside each autocast block below, so fake_B and the encodings
        # are computed under the same precision context as the losses that use them.

        # === Update Discriminator ===
        # Assumes self.scaler_D and self.scaler_G (torch.cuda.amp.GradScaler) are created in __init__.
        self.set_requires_grad(self.netD, True) # Enable grads for D
        self.optimizer_D.zero_grad()

        with torch.cuda.amp.autocast(enabled=autocast_enabled):
            self.forward() # Forward pass needed to compute fake_B for the D loss
            # Calculate D loss components (also sets self.loss_D)
            loss_D_adv, loss_D_category, gp = self.backward_D()

        # Stop before touching the weights if the discriminator loss has diverged
        if torch.isnan(self.loss_D):
            print("ERROR: Discriminator loss is NaN. Stopping training.")
            raise RuntimeError("Discriminator loss is NaN")

        # Scale the loss for mixed precision, then backpropagate outside autocast
        self.scaler_D.scale(self.loss_D).backward()

        # Unscale gradients (required before any clipping) and step the D optimizer
        self.scaler_D.unscale_(self.optimizer_D)
        # Optional: Clip gradients for D
        # torch.nn.utils.clip_grad_norm_(self.netD.parameters(), self.gradient_clip)
        self.scaler_D.step(self.optimizer_D)
        self.scaler_D.update()


        # === Update Generator ===
        self.set_requires_grad(self.netD, False) # Freeze D; only G is updated here
        self.optimizer_G.zero_grad()

        with torch.cuda.amp.autocast(enabled=autocast_enabled):
            self.forward() # Recompute fake_B with a fresh graph for the G loss
            # Calculate G loss components (also sets self.loss_G)
            loss_G_adv, loss_G_category, loss_G_L1, loss_G_const, loss_G_FM, loss_G_perceptual = self.backward_G()

        # Stop before touching the weights if the generator loss has diverged
        if torch.isnan(self.loss_G):
            print("ERROR: Generator loss is NaN. Stopping training.")
            raise RuntimeError("Generator loss is NaN")

        # Scale the loss for mixed precision, then backpropagate outside autocast
        self.scaler_G.scale(self.loss_G).backward()

        # Unscale gradients so clipping sees the true gradient norm, then step the G optimizer
        self.scaler_G.unscale_(self.optimizer_G)
        torch.nn.utils.clip_grad_norm_(self.netG.parameters(), self.gradient_clip)
        self.scaler_G.step(self.optimizer_G)
        self.scaler_G.update()


        # Return scalar values of main losses for logging
        # Use .item() to get Python float from tensor
        return {
            'G_adv': loss_G_adv.item(),
            'G_category': loss_G_category.item(),
            'G_L1': loss_G_L1.item(),
            'G_const': loss_G_const.item(),
            'G_FM': loss_G_FM.item(),
            'G_perceptual': loss_G_perceptual.item(),
            'G_total': self.loss_G.item(),
            'D_adv': loss_D_adv.item(),
            'D_category': loss_D_category.item(),
            'D_gp': gp.item(),
            'D_total': self.loss_D.item()
        }

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all parameters in a network list to avoid unnecessary computations"""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    # Add methods for saving/loading models, inference etc. as needed
    def save_networks(self, epoch_label):
        """Saves generator and discriminator state dicts (requires `import os`)."""
        os.makedirs(self.save_dir, exist_ok=True)
        save_path_g = os.path.join(self.save_dir, f'{epoch_label}_net_G.pth')
        save_path_d = os.path.join(self.save_dir, f'{epoch_label}_net_D.pth')

        # Copy tensors to CPU for portable checkpoints without moving the live networks off their device
        torch.save({k: v.cpu() for k, v in self.netG.state_dict().items()}, save_path_g)
        torch.save({k: v.cpu() for k, v in self.netD.state_dict().items()}, save_path_d)

        print(f"Saved networks for epoch {epoch_label} to {self.save_dir}")

    def load_networks(self, epoch_label):
        """Loads generator and discriminator networks."""
        load_filename_g = f'{epoch_label}_net_G.pth'
        load_path_g = os.path.join(self.save_dir, load_filename_g)
        load_filename_d = f'{epoch_label}_net_D.pth'
        load_path_d = os.path.join(self.save_dir, load_filename_d)

        if os.path.exists(load_path_g):
            state_dict_g = torch.load(load_path_g, map_location=self.device)
            self.netG.load_state_dict(state_dict_g)
            print(f"Loaded generator weights from {load_path_g}")
        else:
            print(f"WARNING: Generator weights not found at {load_path_g}")

        if os.path.exists(load_path_d):
            state_dict_d = torch.load(load_path_d, map_location=self.device)
            self.netD.load_state_dict(state_dict_d)
            print(f"Loaded discriminator weights from {load_path_d}")
        else:
            print(f"WARNING: Discriminator weights not found at {load_path_d}")
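
The relativistic average adversarial terms used in backward_D and backward_G can be sanity-checked in isolation. Below is a minimal sketch, assuming D returns one raw logit per sample; the helper names `rasgan_d_loss` and `rasgan_g_loss` are hypothetical and not part of the model above. The G loss is the D loss with real and fake swapped, so the two objectives pull in opposite directions; a loss that is symmetric in (real - fake) would give both networks the same target and reward D for making real and fake logits equal.

```python
import torch
import torch.nn.functional as F


def rasgan_d_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: RaSGAN loss for the discriminator.

    D wants real logits above the mean fake logit and fake logits below
    the mean real logit. Uses log(1 - sigmoid(z)) == logsigmoid(-z).
    """
    real_vs_fake = real_logits - fake_logits.mean()
    fake_vs_real = fake_logits - real_logits.mean()
    return -(F.logsigmoid(real_vs_fake).mean() + F.logsigmoid(-fake_vs_real).mean())


def rasgan_g_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: same expression with the roles of real and fake swapped."""
    real_vs_fake = real_logits - fake_logits.mean()
    fake_vs_real = fake_logits - real_logits.mean()
    return -(F.logsigmoid(fake_vs_real).mean() + F.logsigmoid(-real_vs_fake).mean())


if __name__ == "__main__":
    real = torch.tensor([3.0, 2.5, 3.5])    # D is confident these are real
    fake = torch.tensor([-3.0, -2.0, -4.0])  # ...and that these are fake
    print(f"D loss: {rasgan_d_loss(real, fake).item():.4f}")  # small: D is winning
    print(f"G loss: {rasgan_g_loss(real, fake).item():.4f}")  # large: G is losing
```

With equal logits both losses equal 2·log 2 (about 1.386); as D separates the two sets, its loss falls while the generator's rises.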

Key Changes and Considerations:

  1. LinearAttention Class: Implements linear attention with 1×1 convolutions for the Q, K, V projections and an ELU+1 feature map. Because it never builds the full N×N attention matrix (N = H×W), its cost grows linearly rather than quadratically with the number of spatial positions. Gamma is initialized to 0, so the residual branch starts out as an identity mapping.
  2. UnetSkipConnectionBlock:
    • Added attention_type parameter to __init__.
    • Conditionally creates self.attn_block as either LinearAttention, SelfAttention, or None based on self_attention flag, layer number, and attention_type.
    • Calls self.attn_block(encoded) if it exists, right after the downsampling block (self.down) and before style modulation (self.style_mod).
    • Cleaned up norm layer placement slightly (removed redundant downnorm as StyleModulation includes InstanceNorm).
  3. UNetGenerator:
    • Added attention_type to __init__ and passed it down to UnetSkipConnectionBlock. Defaults to 'linear'.
  4. Discriminator:
    • Changed global_pool to AdaptiveAvgPool2d((1, 1)) for robustness to feature map size changes before the linear layers. Adjusted final_features calculation accordingly.
    • Corrected input_nc usage in the first conv layer (it should be input_nc * 2 because the input is concatenated A+B).
  5. PerceptualLoss: Added optional input normalization using ImageNet stats, assuming input images are in [0, 1] or [-1, 1].
  6. Zi2ZiModel:
    • Added attention_type parameter.
    • Set embedding_num=2 (make sure this is correct for your data labels 0 and 1).
    • Passed attention_type down to UNetGenerator.
    • Refactored optimize_parameters slightly for clarity, using torch.cuda.amp.autocast context manager.
    • Refined loss function naming (criterion_L1, etc.).
    • Added basic save_networks/load_networks methods (requires import os).
    • Updated loss calculation logic slightly (backward_D, backward_G) for clarity and consistency (e.g., detaching where needed).
    • Made compute_feature_matching_loss simpler (matching features from netD.model) and added a common weighting factor.
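
Item 5 mentions normalizing inputs for the perceptual loss. The PerceptualLoss class is not shown in this section, so here is a hedged sketch of that preprocessing step only, assuming single-channel generator outputs in [-1, 1] (tanh) and a VGG pretrained on ImageNet, which expects 3-channel inputs normalized with the standard ImageNet mean and standard deviation. The function name `to_vgg_input` and its `from_tanh_range` flag are illustrative assumptions.

```python
import torch

# Standard ImageNet statistics used by torchvision's pretrained VGG weights.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def to_vgg_input(x: torch.Tensor, from_tanh_range: bool = True) -> torch.Tensor:
    """Hypothetical helper: map a (B, C, H, W) batch to what a pretrained VGG expects.

    Assumes x is in [-1, 1] when from_tanh_range=True, otherwise in [0, 1].
    Single-channel glyph images are repeated to 3 channels.
    """
    if from_tanh_range:
        x = (x + 1.0) / 2.0       # [-1, 1] -> [0, 1]
    if x.size(1) == 1:
        x = x.repeat(1, 3, 1, 1)  # grayscale -> 3 identical channels
    mean = IMAGENET_MEAN.to(device=x.device, dtype=x.dtype)
    std = IMAGENET_STD.to(device=x.device, dtype=x.dtype)
    return (x - mean) / std
```

Repeating the grayscale channel is one simple way to feed glyphs to an RGB network; the normalization must match whatever range the generator's final activation actually produces.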

Now, when you initialize Zi2ZiModel, you can control whether attention is used (self_attention=True) and which type (attention_type='linear' or 'standard'). For the requested efficient version, use self_attention=True and attention_type='linear'.
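
The LinearAttention class itself is defined earlier in the answer and is not reproduced here. To show why the ELU+1 feature map avoids the N×N matrix, here is a minimal sketch under stated assumptions: the class name `LinearAttentionSketch`, the `reduction` factor for the Q/K width, and `eps` are illustrative choices, not the original implementation. Positions are the H×W pixels of the feature map.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearAttentionSketch(nn.Module):
    """Hypothetical linear attention over the H*W positions of a feature map.

    phi(x) = elu(x) + 1 keeps features positive, so
      out_i = phi(q_i) . sum_j phi(k_j) v_j^T / (phi(q_i) . sum_j phi(k_j))
    costs O(N * d * C) instead of O(N^2 * C) for N = H * W.
    """

    def __init__(self, channels: int, reduction: int = 8, eps: float = 1e-6):
        super().__init__()
        qk_dim = max(channels // reduction, 1)
        self.query = nn.Conv2d(channels, qk_dim, kernel_size=1)
        self.key = nn.Conv2d(channels, qk_dim, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)
        # Residual gate starts at 0, so the block is an identity at initialization.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        q = F.elu(self.query(x).flatten(2)) + 1  # (B, d, N), strictly positive
        k = F.elu(self.key(x).flatten(2)) + 1    # (B, d, N), strictly positive
        v = self.value(x).flatten(2)             # (B, C, N)

        kv = torch.einsum("bdn,bcn->bdc", k, v)    # (B, d, C): summed over all positions
        k_sum = k.sum(dim=2)                       # (B, d)
        num = torch.einsum("bdn,bdc->bcn", q, kv)  # (B, C, N)
        den = torch.einsum("bdn,bd->bn", q, k_sum) # (B, N): per-query normalizer
        out = num / (den.unsqueeze(1) + self.eps)
        return x + self.gamma * out.reshape(b, c, h, w)
```

The key-value summary `kv` has shape (B, d, C) regardless of image size, so memory stays linear in H×W. Under fp16 autocast it may be worth computing the normalizing denominator in float32, since it sums over every position.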


Python

import os
import math
import time
import random
import argparse # Need argparse to define args

import torch
from torch.utils.data import DataLoader
# Assume these are defined in other files relative to trainer.py
# from model import Zi2ZiModel # Import the updated model class
# from dataset import DatasetFromObj # Import your dataset class
# from utils import ensure_dir, setup_google_drive_service, clear_google_drive_trash # Import utility functions

# --- Mock/Placeholder definitions for missing components ---
# Replace these with your actual imports/definitions
class Zi2ZiModel: # Placeholder based on previous definitions
    def __init__(self, input_nc=1, embedding_num=2, embedding_dim=128, ngf=64, ndf=64,
                 Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100, perceptual_weight=10.0, gradient_penalty_weight=10.0,
                 schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
                 self_attention=False, attention_type='linear',
                 weight_decay = 1e-5, beta1=0.5, g_blur=False, d_blur=False, epoch=40,
                 gradient_clip=0.5, norm_type="instance"):
        self.device = torch.device("cuda" if gpu_ids and torch.cuda.is_available() else "cpu")
        self.save_dir = save_dir
        self.epoch = epoch
        self.lr=lr
        self.beta1=beta1
        self.weight_decay=weight_decay
        self.embedding_num=embedding_num
        self.gradient_clip=gradient_clip
        self.Lconst_penalty=Lconst_penalty
        self.Lcategory_penalty=Lcategory_penalty
        self.L1_penalty=L1_penalty
        self.perceptual_weight=perceptual_weight
        self.gradient_penalty_weight=gradient_penalty_weight
        self.is_training=is_training
        print(f"Placeholder Zi2ZiModel initialized on {self.device} with attention={self_attention}, type={attention_type}, embed_num={embedding_num}")
        # Mock networks with a single parameter each: optimizers reject an empty parameter list
        self.netG = torch.nn.Linear(1, 1) # Mock network
        self.netD = torch.nn.Linear(1, 1) # Mock network
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=lr)
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=lr)
        eta_min = lr * 0.01
        self.scheduler_G = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_G, T_max=epoch, eta_min=eta_min)
        self.scheduler_D = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_D, T_max=epoch, eta_min=eta_min)
        self.loss_D = torch.tensor(0.0) # Mock loss attributes
        self.loss_G = torch.tensor(0.0) # Mock loss attributes

    def print_networks(self, verbose): print(f"Mock print_networks({verbose})")
    def load_networks(self, epoch_label): print(f"Mock load_networks({epoch_label})")
    def save_networks(self, step_label):
        print(f"Mock save_networks({step_label})")
        # Mock saving files to test deletion logic
        # try:
        #     os.makedirs(self.save_dir, exist_ok=True)
        #     with open(os.path.join(self.save_dir, f"{step_label}_net_G.pth"), 'w') as f: f.write("mock G")
        #     with open(os.path.join(self.save_dir, f"{step_label}_net_D.pth"), 'w') as f: f.write("mock D")
        # except Exception as e:
        #     print(f"Mock save error: {e}")
        pass
    def set_input(self, data): print("Mock set_input()") # Same signature as the real model: dict with 'label', 'A', 'B'
    def optimize_parameters(self, use_autocast):
        print(f"Mock optimize_parameters(use_autocast={use_autocast})")
        # Return mock loss dictionary matching the latest model's return format
        self.loss_D = torch.tensor(random.random())
        self.loss_G = torch.tensor(random.random() * 2)
        return {
            'G_adv': torch.tensor(random.random()).item(), 'G_category': torch.tensor(random.random()).item(),
            'G_L1': torch.tensor(random.random()*10).item(), 'G_const': torch.tensor(random.random()).item(),
            'G_FM': torch.tensor(random.random()).item(), 'G_perceptual': torch.tensor(random.random()).item(),
            'G_total': self.loss_G.item(), 'D_adv': torch.tensor(random.random()).item(),
            'D_category': torch.tensor(random.random()).item(), 'D_gp': torch.tensor(random.random()*0.1).item(),
            'D_total': self.loss_D.item()
        }
    def setup(self): pass # Mock method


class DatasetFromObj: # Placeholder
    def __init__(self, path, input_nc=1):
        self.path = path
        self.input_nc = input_nc
        # Mock data
        self.data = [(torch.tensor(i % 2), # label (0 or 1), scalar so a batch of labels has shape (B,)
                      torch.randn(input_nc, 64, 64), # image B (target font)
                      torch.randn(input_nc, 64, 64)) # image A (source font)
                     for i in range(100)] # Mock 100 samples
        print(f"Mock DatasetFromObj initialized from {path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

def ensure_dir(path): os.makedirs(path, exist_ok=True); print(f"Mock ensure_dir({path})")
def setup_google_drive_service(): print("Mock setup_google_drive_service()"); return None
def clear_google_drive_trash(service): print("Mock clear_google_drive_trash()")
# --- End Mock Definitions ---


def train(args):
    """Main training function."""
    print("--- Starting Training ---")
    print(f"Arguments: {args}")

    # --- Initialization ---
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    if args.gpu_ids:
        torch.cuda.manual_seed_all(args.random_seed) # Seed all GPUs if used

    # Directories
    data_dir = args.data_dir or os.path.join(args.experiment_dir, "data")
    checkpoint_dir = args.checkpoint_dir or os.path.join(args.experiment_dir, "checkpoint")
    ensure_dir(checkpoint_dir) # Create checkpoint dir if it doesn't exist

    print(f"Data Directory: {data_dir}")
    print(f"Checkpoint Directory: {checkpoint_dir}")

    # Google Drive (Optional)
    drive_service = setup_google_drive_service() if args.checkpoint_only_last else None

    # --- Model Initialization ---
    # Instantiate the updated Zi2ZiModel with arguments from parser
    model = Zi2ZiModel(
        input_nc=args.input_nc,
        embedding_num=args.embedding_num, # Should be 2 for Noto->Zen task
        embedding_dim=args.embedding_dim,
        ngf=args.ngf,
        ndf=args.ndf,
        Lconst_penalty=args.Lconst_penalty,
        Lcategory_penalty=args.Lcategory_penalty,
        L1_penalty=args.L1_penalty,
        perceptual_weight=args.perceptual_weight, # Pass from args
        gradient_penalty_weight=args.gradient_penalty_weight, # Pass from args
        save_dir=checkpoint_dir,
        gpu_ids=args.gpu_ids,
        is_training=True, # Set explicitly for training
        self_attention=args.self_attention,   # Pass attention flag
        attention_type=args.attention_type, # Pass attention type
        epoch=args.epoch, # Pass total epochs for scheduler
        # g_blur=args.g_blur, # Pass blur flags if used by model
        # d_blur=args.d_blur,
        lr=args.lr,
        beta1=args.beta1, # Pass optimizer params
        weight_decay=args.weight_decay,
        gradient_clip=args.gradient_clip, # Pass gradient clip value
        norm_type=args.norm_type # Pass norm type
    )

    # model.print_networks(True) # Optional: Print network details

    # --- Resume Training ---
    start_epoch = 0
    global_steps = 0
    if args.resume:
        try:
            # Assuming load_networks takes an epoch or step label string
            model.load_networks(args.resume)
            # Optionally parse the epoch/step from the resume string to continue numbering
            # For simplicity, we'll restart step counting unless loading optimizer state too
            print(f"Resumed model from step/epoch: {args.resume}")
            # If loading optimizer/scheduler state, you'd need to load those too and potentially update start_epoch/global_steps
        except Exception as e:
            print(f"Could not resume from {args.resume}. Starting from scratch. Error: {e}")

    # --- Data Loading ---
    train_dataset = DatasetFromObj(os.path.join(data_dir, 'train.obj'), input_nc=args.input_nc)
    # num_workers > 0 loads batches in background processes; pin_memory speeds up host-to-GPU copies
    dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=bool(args.gpu_ids))
    total_batches = math.ceil(len(train_dataset) / args.batch_size)
    print(f"Dataset length: {len(train_dataset)}, Batches per epoch: {total_batches}")

    # --- Training Loop ---
    start_time = time.time()
    print(f"Starting training from epoch {start_epoch}...")

    for epoch in range(start_epoch, args.epoch):
        epoch_start_time = time.time()
        print(f"\n--- Epoch {epoch}/{args.epoch - 1} ---")

        for batch_id, batch_data in enumerate(dataloader):
            current_step_time = time.time()

            # Set model input (ensure correct order: labels, real_A, real_B)
            # Based on DatasetFromObj mock: batch = (label, image_B, image_A)
            labels, image_B, image_A = batch_data
            # Pass to model: labels, real_A (source), real_B (target)
            model_input_data = {'label': labels, 'A': image_A, 'B': image_B}
            model.set_input(model_input_data) # Pass the dictionary

            # Optimize parameters and get losses
            # optimize_parameters now returns a dictionary
            losses = model.optimize_parameters(args.use_autocast)
            global_steps += 1

            # --- Logging ---
            if batch_id % 100 == 0: # Log every 100 batches
                elapsed_batch_time = time.time() - current_step_time
                total_elapsed_time = time.time() - start_time
                # Use keys from the losses dictionary
                print(
                    f"Epoch: [{epoch:3d}/{args.epoch-1}], Batch: [{batch_id:4d}/{total_batches-1}] "
                    f"Step: {global_steps} | Time/Batch: {elapsed_batch_time:.2f}s | Total Time: {total_elapsed_time:.0f}s\n"
                    f"  Losses: D_Total={losses['D_total']:.4f}, G_Total={losses['G_total']:.4f}\n"
                    f"    G: [Adv={losses['G_adv']:.4f}, Cat={losses['G_category']:.4f}, L1={losses['G_L1']:.4f}, "
                    f"Const={losses['G_const']:.4f}, FM={losses['G_FM']:.4f}, Perc={losses['G_perceptual']:.4f}]\n"
                    f"    D: [Adv={losses['D_adv']:.4f}, Cat={losses['D_category']:.4f}, GP={losses['D_gp']:.4f}]"
                )

            # --- Checkpointing ---
            if global_steps % args.checkpoint_steps == 0:
                if global_steps >= args.checkpoint_steps_after:
                    print(f"\nSaving checkpoint at step {global_steps}...")
                    model.save_networks(global_steps) # Save with step label
                    print(f"Checkpoint saved.")

                    # --- Clean up old checkpoints (Optional: only keep last) ---
                    if args.checkpoint_only_last and drive_service is None: # Only delete local if not using Drive backup
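                        # Remove all earlier step checkpoints (saved labels are multiples of checkpoint_steps;
                        # labels that were never saved are skipped by the isfile check below)
                        current_step_label = global_steps
                        for step_label in range(args.checkpoint_steps, current_step_label, args.checkpoint_steps):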
                            for net_type in ["D", "G"]:
                                old_filepath = os.path.join(checkpoint_dir, f"{step_label}_net_{net_type}.pth")
                                if os.path.isfile(old_filepath):
                                    try:
                                        os.remove(old_filepath)
                                        print(f"  Removed old checkpoint: {old_filepath}")
                                    except OSError as e:
                                        print(f"  Error removing {old_filepath}: {e}")
                        # Note: GDrive cleanup logic might need refinement based on API usage
                        # clear_google_drive_trash(drive_service) # This might be too aggressive

                else:
                    print(f"\nCheckpoint step {global_steps} reached, but saving starts after step {args.checkpoint_steps_after}.")


        # --- End of Epoch ---
        epoch_time = time.time() - epoch_start_time
        print(f"\n--- End of Epoch {epoch} --- Time: {epoch_time:.2f}s ---")

        # Update Learning Rate Schedulers
        model.scheduler_G.step()
        model.scheduler_D.step()
        print(f"LR Scheduler stepped. Current LR G: {model.scheduler_G.get_last_lr()[0]:.6f}, LR D: {model.scheduler_D.get_last_lr()[0]:.6f}")

        # --- Save model at end of epoch? (Optional) ---
        # if (epoch + 1) % args.save_epoch_freq == 0: # Add arg save_epoch_freq if needed
        #     model.save_networks(f'epoch_{epoch+1}')


    # --- End of Training ---
    print("\n--- Training Finished ---")
    # Save the final model state
    print("Saving final model...")
    model.save_networks('latest') # Save with 'latest' label
    print("Final model saved.")

    total_training_time = time.time() - start_time
    print(f"Total Training Time: {total_training_time:.2f} seconds")

# --- Argparse Setup (Example) ---
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Zi2Zi Training Script')

    # --- Paths ---
    parser.add_argument('--experiment_dir', required=True, help='Base directory for the experiment')
    parser.add_argument('--data_dir', default=None, help='Path to training data directory (default: <experiment_dir>/data)')
    parser.add_argument('--checkpoint_dir', default=None, help='Directory to save checkpoints (default: <experiment_dir>/checkpoint)')

    # --- Model ---
    parser.add_argument('--input_nc', type=int, default=1, help='Input image channels (e.g., 1 for grayscale)')
    parser.add_argument('--embedding_num', type=int, default=2, help='Number of distinct styles/fonts (MUST be 2 for Noto->Zen)')
    parser.add_argument('--embedding_dim', type=int, default=128, help='Dimension of style embedding vector')
    parser.add_argument('--ngf', type=int, default=64, help='Num generator filters in the first conv layer')
    parser.add_argument('--ndf', type=int, default=64, help='Num discriminator filters in the first conv layer')
    parser.add_argument('--norm_type', type=str, default='instance', choices=['instance', 'batch'], help='Normalization layer type for generator [instance | batch]')
    parser.add_argument('--self_attention', action='store_true', help='Enable attention mechanism in generator (at specific layers)')
    parser.add_argument('--attention_type', type=str, default='linear', choices=['linear', 'standard', 'none'], help='Type of attention if enabled [linear | standard | none]')

    # --- Training Params ---
    parser.add_argument('--epoch', type=int, default=100, help='Total number of training epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='Training batch size')
    parser.add_argument('--lr', type=float, default=0.0002, help='Initial learning rate for AdamW')
    parser.add_argument('--beta1', type=float, default=0.5, help='AdamW optimizer beta1 parameter')
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='AdamW optimizer weight decay')
    parser.add_argument('--use_autocast', action='store_true', help='Use Automatic Mixed Precision (AMP) for training')
    parser.add_argument('--gradient_clip', type=float, default=1.0, help='Max gradient norm for generator update')
    # parser.add_argument('--schedule', type=int, default=50, help='Deprecated for CosineAnnealingLR')

    # --- Loss Weights ---
    parser.add_argument('--Lconst_penalty', type=float, default=15.0, help='Weight for embedding consistency loss')
    parser.add_argument('--Lcategory_penalty', type=float, default=1.0, help='Weight for category classification loss (D and G)')
    parser.add_argument('--L1_penalty', type=float, default=100.0, help='Weight for L1 pixel loss between fake_B and real_B')
    parser.add_argument('--perceptual_weight', type=float, default=10.0, help='Weight for VGG perceptual loss')
    parser.add_argument('--gradient_penalty_weight', type=float, default=10.0, help='Weight for WGAN-GP gradient penalty')

    # --- Checkpointing & Resume ---
    parser.add_argument('--resume', type=str, default=None, help='Checkpoint label (e.g., <step> or latest) to resume from')
    parser.add_argument('--checkpoint_steps', type=int, default=10000, help='Frequency (in steps) to save checkpoints')
    parser.add_argument('--checkpoint_steps_after', type=int, default=0, help='Start saving checkpoints only after this many steps')
    parser.add_argument('--checkpoint_only_last', action='store_true', help='If set, keep only the latest checkpoint and delete older ones locally (use with caution, e.g. with Google Drive)')

    # --- Misc ---
    parser.add_argument('--random_seed', type=int, default=42, help='Seed for random number generators')
    parser.add_argument('--gpu_ids', type=str, default='0', help='Comma-separated list of GPU IDs to use (e.g., "0" or "0,1"). Use "" or "-1" for CPU.')
    # parser.add_argument('--g_blur', action='store_true', help='Apply blur in Generator (if implemented)')
    # parser.add_argument('--d_blur', action='store_true', help='Apply blur in Discriminator')

    args = parser.parse_args()

    # --- Process GPU IDs ---
    if args.gpu_ids and args.gpu_ids != '-1':
        str_ids = [s.strip() for s in args.gpu_ids.split(',')]  # tolerate "0, 1"
        args.gpu_ids = [int(str_id) for str_id in str_ids if str_id.isdigit()]
        if not args.gpu_ids:
            args.gpu_ids = None  # Use CPU if parsing failed or the string was empty
    else:
        args.gpu_ids = None  # Use CPU

    if args.gpu_ids and torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_ids[0])  # Set default device
        print(f"Using GPUs: {args.gpu_ids}")
    elif args.gpu_ids and not torch.cuda.is_available():
        print(f"WARNING: GPUs {args.gpu_ids} requested but CUDA is not available! Using CPU.")
        args.gpu_ids = None
    else:
        print("Using CPU.")

    # --- Sanity Checks ---
    if args.embedding_num != 2:
        print(f"ERROR: This task (Noto Sans -> Zen Maru Gothic) requires --embedding_num=2. Found {args.embedding_num}.")
        raise SystemExit(1)  # Exit if embedding_num is wrong for the task
    if not args.self_attention and args.attention_type != 'none':
        print(f"INFO: --attention_type is set to '{args.attention_type}' but --self_attention is not specified. Attention will NOT be used.")
        args.attention_type = 'none'  # Ensure consistency
    if args.self_attention and args.attention_type == 'none':
        print("WARNING: --self_attention is specified but --attention_type is 'none'. Attention will NOT be used.")
        args.self_attention = False  # Ensure consistency

    # --- Run Training ---
    train(args)
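
The GPU-ID and attention-flag logic above can also be factored into two small pure functions, so it can be tested without a GPU. The sketch below assumes that refactoring. parse_gpu_ids and resolve_attention are hypothetical names and are not part of the original script. The functions mirror the script's rules: use the CPU for "" or "-1", and enable attention only when both flags agree. They also strip whitespace, so --gpu_ids "0, 1" is accepted.

```python
from typing import List, Optional, Tuple


def parse_gpu_ids(raw: Optional[str]) -> Optional[List[int]]:
    """Parse "0", "0,1" or "0, 1" into a list of ints; None means "use CPU"."""
    if not raw or raw.strip() == "-1":
        return None
    # strip() lets "0, 1" work; isdigit() alone would reject " 1"
    ids = [int(part.strip()) for part in raw.split(",") if part.strip().isdigit()]
    return ids or None


def resolve_attention(self_attention: bool, attention_type: str) -> Tuple[bool, str]:
    """Attention is on only if the flag is set AND the type is not 'none'."""
    if self_attention and attention_type != "none":
        return True, attention_type
    return False, "none"


print(parse_gpu_ids("0, 1"))               # [0, 1]
print(parse_gpu_ids("-1"))                 # None
print(resolve_attention(False, "linear"))  # (False, 'none')
print(resolve_attention(True, "linear"))   # (True, 'linear')
```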

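Main changes:

  1. Model initialization (Zi2ZiModel)
    • Confirmed that all required arguments are passed, in particular embedding_num (should be 2), attention_type, perceptual_weight, gradient_penalty_weight, etc.
    • Removed the residual_block argument, which appeared to be unused.
    • is_training=True is now set explicitly.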
  2. Loss handling
    • model.optimize_parameters() now returns a dictionary, losses.
    • Updated the print statements to display each loss term using the keys of the losses dictionary (losses['D_total'], losses['G_total'], losses['G_L1'], etc.).
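  3. Learning-rate scheduler
    • Removed the call to model.update_lr(), which does not exist.
    • At the end of each epoch, model.scheduler_G.step() and model.scheduler_D.step() are called to update the learning rate (the correct usage for CosineAnnealingLR when T_max is counted in epochs).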
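  4. Input setup (set_input)
    • set_input is now called slightly differently: it receives a dictionary, model_input_data. This requires Zi2ZiModel.set_input to be modified to accept a dictionary. Alternatively, keep the old call and make sure batch[0], batch[2], batch[1] map correctly to labels, real_A, real_B.
    • Note: the Zi2ZiModel placeholder above does not modify set_input. It assumes the previous set_input(labels, real_A, real_B) still works and that the data is loaded in the order labels=batch[0], real_A=batch[2], real_B=batch[1]. If your DatasetFromObj returns a different order, adjust this accordingly.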
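  5. Argument parsing (argparse)
    • Added attention_type, perceptual_weight, gradient_penalty_weight and other arguments.
    • embedding_num now defaults to 2, and a check before training ensures it is 2.
    • Improved the handling of gpu_ids.
    • Added checks and hints for how self_attention and attention_type are combined.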
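  6. Code structure and comments
    • Added more comments and print statements to improve readability.
    • Included a basic argparse setup and an if __name__ == '__main__': block.
    • Used placeholders (mocks) for the classes that were not provided (Zi2ZiModel, DatasetFromObj, utils); replace them with your actual imports.
    • Fixed the number of input channels of the first convolution layer in the Discriminator.
    • Fixed the VGG input normalization logic in PerceptualLoss.
    • Made the checkpoint cleanup logic safer.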
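
For the loss handling in item 2, here is a minimal sketch of turning the losses dictionary into one log line. It assumes the values are plain Python floats, for example already converted with .item(). format_losses is a hypothetical helper. Only the key names D_total, G_total and G_L1 come from the text; the example values are made up.

```python
from typing import Dict


def format_losses(epoch: int, step: int, losses: Dict[str, float]) -> str:
    """Build one log line from a dict of loss values.

    Keys are sorted so the layout stays stable if loss terms are added or removed.
    """
    parts = [f"{name}: {value:.4f}" for name, value in sorted(losses.items())]
    return f"Epoch [{epoch}] Step [{step}] " + ", ".join(parts)


# Made-up example values; real ones would come from model.optimize_parameters().
example = {"D_total": 0.6931, "G_total": 12.5, "G_L1": 0.0821}
print(format_losses(3, 1200, example))
```

Iterating over the dictionary means the print statement does not need to change when a new loss term is added.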
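
For the scheduler in item 3, the reason to call step() once per epoch becomes clear from the closed-form cosine annealing schedule given in the PyTorch CosineAnnealingLR documentation. This pure-Python sketch evaluates that formula. It assumes T_max equals the number of epochs and eta_min is 0, which may not match how the schedulers are actually constructed in Zi2ZiModel.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Learning rate after `epoch` scheduler steps, closed-form cosine schedule:
    eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2
    """
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


base_lr, t_max = 0.0002, 100  # --lr default; t_max assumed equal to the --epoch default
for epoch in (0, 25, 50, 75, 100):
    print(epoch, f"{cosine_annealing_lr(epoch, base_lr, t_max):.6f}")
```

Under these assumptions, the learning rate falls from 0.0002 to 0.0001 at the halfway point and reaches 0 at epoch 100. If step() were called after every batch instead, the same curve would be traversed in t_max batches rather than t_max epochs.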
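
For the set_input change in item 4, one way to build model_input_data from a batch is a small mapping function. This sketch assumes the index order stated in the text (labels=batch[0], real_B=batch[1], real_A=batch[2]). batch_to_model_input is a hypothetical name.

```python
from typing import Any, Dict, Sequence


def batch_to_model_input(batch: Sequence[Any]) -> Dict[str, Any]:
    """Map a batch tuple to named model inputs.

    Assumes labels=batch[0], real_B=batch[1], real_A=batch[2]; change the
    indices here if the dataset returns a different order.
    """
    if len(batch) < 3:
        raise ValueError(f"expected at least 3 items per batch, got {len(batch)}")
    return {"labels": batch[0], "real_A": batch[2], "real_B": batch[1]}


# Stand-in strings instead of tensors, just to show which slot goes where.
print(batch_to_model_input(("label", "B_image", "A_image")))
```

Suppose set_input keeps the old signature set_input(labels, real_A, real_B) and those are its actual parameter names. Then model.set_input(**model_input_data) would call it unchanged, so the dictionary form and the positional form can coexist.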
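
For the PerceptualLoss fix in item 6, a common way to prepare images for torchvision's pretrained VGG16 has three steps. First map the generator output to [0, 1]. Then repeat a single grayscale channel to three channels. Finally subtract the ImageNet mean and divide by the standard deviation. The sketch below shows that arithmetic for one pixel in plain Python. It assumes a tanh generator output in [-1, 1] and single-channel glyph images, neither of which is stated above. In the real loss, the same operations would be applied to whole tensors with broadcasting.

```python
from typing import List

# ImageNet statistics used by torchvision's pretrained VGG16 weights.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def to_vgg_input(pixel: float) -> List[float]:
    """Map one grayscale pixel in [-1, 1] (assumed tanh output) to 3 VGG-normalized values."""
    unit = (pixel + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    # The single gray value is repeated for R, G and B, then normalized per channel.
    return [(unit - m) / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)]


print([round(v, 3) for v in to_vgg_input(1.0)])   # white pixel
print([round(v, 3) for v in to_vgg_input(-1.0)])  # black pixel
```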
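
Item 6 also mentions safer checkpoint cleanup, which --checkpoint_only_last relies on. The actual cleanup code is not shown here, so the following is only one possible interpretation of "safer". It deletes files from older steps, touches only files that match a checkpoint naming pattern, and deletes nothing unless the newest checkpoint is already on disk. The file-name pattern <step>_net_<name>.pth and the functions checkpoint_step and prune_checkpoints are assumptions.

```python
import re
import tempfile
from pathlib import Path
from typing import List

# Hypothetical file naming: "<step>_net_<name>.pth", e.g. "20000_net_G.pth".
CKPT_PATTERN = re.compile(r"^(\d+)_net_\w+\.pth$")


def checkpoint_step(path: Path) -> int:
    """Return the step encoded in a checkpoint file name (pattern must match)."""
    return int(CKPT_PATTERN.match(path.name).group(1))


def prune_checkpoints(ckpt_dir: Path, keep_step: int) -> List[Path]:
    """Delete checkpoint files whose step differs from keep_step.

    Non-matching files are never touched, and nothing is deleted unless a
    file for keep_step exists, so a failed save cannot empty the directory.
    """
    ckpts = [p for p in ckpt_dir.iterdir() if p.is_file() and CKPT_PATTERN.match(p.name)]
    if not any(checkpoint_step(p) == keep_step for p in ckpts):
        return []
    removed = []
    for p in ckpts:
        if checkpoint_step(p) != keep_step:
            p.unlink()
            removed.append(p)
    return removed


with tempfile.TemporaryDirectory() as tmp:
    d = Path(tmp)
    for name in ("10000_net_G.pth", "10000_net_D.pth", "20000_net_G.pth", "notes.txt"):
        (d / name).touch()
    print(sorted(p.name for p in prune_checkpoints(d, keep_step=20000)))
    print(sorted(p.name for p in d.iterdir()))
```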

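Place this train function in the same project as your updated Zi2ZiModel, dataset class, and utility functions, and make sure the import paths are correct. You can then start training by passing arguments on the command line, for example: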

Bash

python trainer.py --experiment_dir ./experiments/noto_zen_linear \
                  --gpu_ids 0 \
                  --batch_size 8 \
                  --epoch 150 \
                  --lr 0.0002 \
                  --self_attention \
                  --attention_type linear \
                  --L1_penalty 100 \
                  --perceptual_weight 10 \
                  --Lconst_penalty 15 \
                  --use_autocast
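
To check how the command above is parsed without starting training, the flags can be fed to argparse directly. The sketch rebuilds only the flags used in that command, with the types and defaults from the argparse section. --experiment_dir is defined outside this excerpt, so its type and required=True are assumptions.

```python
import argparse
import shlex

# Reduced, hypothetical copy of the trainer's parser: only the flags used above.
parser = argparse.ArgumentParser()
parser.add_argument('--experiment_dir', type=str, required=True)  # assumed definition
parser.add_argument('--gpu_ids', type=str, default='0')
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--lr', type=float, default=0.0002)
parser.add_argument('--self_attention', action='store_true')
parser.add_argument('--attention_type', type=str, default='linear',
                    choices=['linear', 'standard', 'none'])
parser.add_argument('--L1_penalty', type=float, default=100.0)
parser.add_argument('--perceptual_weight', type=float, default=10.0)
parser.add_argument('--Lconst_penalty', type=float, default=15.0)
parser.add_argument('--use_autocast', action='store_true')

cmd = ("--experiment_dir ./experiments/noto_zen_linear --gpu_ids 0 --batch_size 8 "
       "--epoch 150 --lr 0.0002 --self_attention --attention_type linear "
       "--L1_penalty 100 --perceptual_weight 10 --Lconst_penalty 15 --use_autocast")
args = parser.parse_args(shlex.split(cmd))
print(args.batch_size, args.epoch, args.self_attention, args.use_autocast)
```

store_true flags such as --self_attention and --use_autocast take no value. They become True only when present, and gpu_ids stays the string "0" until the GPU-processing code converts it.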
