Modifying zi2zi: using Feature-wise Linear Modulation (FiLM) so the style embedding modulates the feature maps of every UNet layer

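Your approach is exactly right, and you are very close to finishing the conversion! Let me organize the logic and clear up a few implementation details so that your FiLM setup works more smoothly.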


✅ Your goal

You want the style vector (style_vec) to modulate the feature maps of every UNet layer in the FiLM manner (scale + shift), that is:

out = gamma * feat + beta

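where gamma and beta are produced from style_vec by an MLP. Each has one value per channel, so the same scale and shift apply to every spatial position of that channel's feature map.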
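A minimal sketch of the formula above in PyTorch, showing only the reshaping and broadcasting. The tensor sizes are arbitrary, and gamma and beta are random here instead of coming from an MLP:

```python
import torch

B, C, H, W = 2, 4, 8, 8
feat = torch.randn(B, C, H, W)

# One (gamma, beta) pair per sample and per channel, e.g. produced by an MLP from style_vec
gamma = torch.randn(B, C)
beta = torch.randn(B, C)

# Reshape to (B, C, 1, 1) so each value is broadcast over the H x W feature map
out = gamma.view(B, C, 1, 1) * feat + beta.view(B, C, 1, 1)

# Every spatial position in channel c of sample b gets the same scale and shift
b, c = 1, 2
expected = gamma[b, c] * feat[b, c] + beta[b, c]
print(out.shape)                            # torch.Size([2, 4, 8, 8])
print(torch.allclose(out[b, c], expected))  # True
```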


🔧 Your design (almost done)

You already have a StyleModulation class. Here are a few suggested changes to make it more complete and stable:

🔄 Revised StyleModulation

import torch
import torch.nn as nn

class StyleModulation(nn.Module):
    def __init__(self, in_channels, style_dim):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channels, affine=False)

        self.style_fc = nn.Sequential(
            nn.Linear(style_dim, in_channels * 2),
            nn.ReLU(),
            nn.Linear(in_channels * 2, in_channels * 2)  # split into gamma and beta
        )

    def forward(self, x, style_vec):
        # B, C, H, W = x.shape
        x_norm = self.norm(x)

        # Get gamma and beta; style_params has shape (B, 2*C)
        style_params = self.style_fc(style_vec)
        gamma, beta = style_params.chunk(2, dim=1)

        # Reshape to (B, C, 1, 1) so they broadcast over x
        gamma = gamma.unsqueeze(2).unsqueeze(3)
        beta = beta.unsqueeze(2).unsqueeze(3)

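        # Apply FiLM. Note: gamma is centered around 0 at initialization, which scales down
        # the normalized features; a common variant is (1 + gamma) * x_norm + beta.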
        return gamma * x_norm + beta

🔗 How do I use this module in a UNet?

You can insert it inside a UNet encoder or decoder block:

self.style_mod = StyleModulation(in_channels, style_dim)

and pass the style vector in forward:

x = self.style_mod(x, style_vec)
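A minimal sketch of an encoder block with StyleModulation placed between the strided convolution and the activation. `DownBlockWithStyle` is a hypothetical name, and StyleModulation is repeated in compact form so the snippet runs on its own:

```python
import torch
import torch.nn as nn

class StyleModulation(nn.Module):
    # Same idea as above: InstanceNorm without affine params, then FiLM from style_vec
    def __init__(self, in_channels, style_dim):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channels, affine=False)
        self.style_fc = nn.Sequential(
            nn.Linear(style_dim, in_channels * 2),
            nn.ReLU(),
            nn.Linear(in_channels * 2, in_channels * 2),
        )

    def forward(self, x, style_vec):
        gamma, beta = self.style_fc(style_vec).chunk(2, dim=1)
        return gamma[:, :, None, None] * self.norm(x) + beta[:, :, None, None]

class DownBlockWithStyle(nn.Module):
    """Hypothetical encoder block: Conv (stride 2) -> StyleModulation -> LeakyReLU."""
    def __init__(self, in_ch, out_ch, style_dim):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.style_mod = StyleModulation(out_ch, style_dim)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x, style_vec):
        x = self.conv(x)
        x = self.style_mod(x, style_vec)  # the style enters here
        return self.act(x)

block = DownBlockWithStyle(in_ch=1, out_ch=64, style_dim=128)
x = torch.randn(2, 1, 64, 64)
style_vec = torch.randn(2, 128)
y = block(x, style_vec)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```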

💡 Tips

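  • Use InstanceNorm2d rather than BatchNorm: InstanceNorm normalizes each sample on its own, so one sample's style statistics do not leak into another's, which matters in style transfer.
  • style_dim is the length of your style vector, e.g. 128 or 256.
  • If you use residual blocks, you can place StyleModulation after the middle conv inside the residual branch (or use several of them).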
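A minimal sketch of the residual-block tip, with FiLM placed between the two convolutions of the residual branch. `FiLMResBlock` is a hypothetical name, and the style MLP is reduced to a single Linear layer for brevity:

```python
import torch
import torch.nn as nn

class FiLMResBlock(nn.Module):
    """Hypothetical residual block: conv -> FiLM -> ReLU -> conv, plus an identity skip."""
    def __init__(self, channels, style_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm = nn.InstanceNorm2d(channels, affine=False)
        self.style_fc = nn.Linear(style_dim, channels * 2)  # one layer for brevity
        self.act = nn.ReLU()

    def forward(self, x, style_vec):
        h = self.conv1(x)
        gamma, beta = self.style_fc(style_vec).chunk(2, dim=1)
        # FiLM in the middle of the residual branch
        h = gamma[:, :, None, None] * self.norm(h) + beta[:, :, None, None]
        h = self.conv2(self.act(h))
        return x + h  # identity skip: the unmodulated input is always passed through

block = FiLMResBlock(channels=32, style_dim=128)
x = torch.randn(2, 32, 16, 16)
out = block(x, torch.randn(2, 128))
print(out.shape)  # torch.Size([2, 32, 16, 16])
```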

🧪 Things to test

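  • Feed the same image with different style_vec values; the output style should change.
  • If style_vec is held fixed (e.g. all ones or all zeros), the output style should stay the same. A constant vector is not a neutral style, though: gamma and beta still come out of the learned MLP, so even an all-zero vector modulates the features.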
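A minimal sketch of these two checks on a single modulation layer rather than a full trained UNet (an assumption made to keep it self-contained); in practice you would run the same comparison on the generator's output images:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Compact stand-in for StyleModulation (InstanceNorm without affine, then FiLM)
class StyleModulation(nn.Module):
    def __init__(self, channels, style_dim):
        super().__init__()
        self.norm = nn.InstanceNorm2d(channels, affine=False)
        self.style_fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, style_vec):
        gamma, beta = self.style_fc(style_vec).chunk(2, dim=1)
        return gamma[:, :, None, None] * self.norm(x) + beta[:, :, None, None]

mod = StyleModulation(channels=8, style_dim=16)
x = torch.randn(1, 8, 32, 32)   # the "same image", here as a feature map
style_a = torch.randn(1, 16)
style_b = torch.randn(1, 16)

with torch.no_grad():
    out_a1 = mod(x, style_a)
    out_a2 = mod(x, style_a)
    out_b = mod(x, style_b)

# Same input and same style -> identical output
print(torch.equal(out_a1, out_a2))
# Same input, different style -> different output
print((out_a1 - out_b).abs().mean().item())
```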

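Your module is already in good shape. If there is a particular UNet layer you want to apply it to, I can show you how to insert StyleModulation there so the integration goes smoothly. Would you like me to build a UNet block as an example?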


To apply the style embedding to the feature maps of every UNet layer through FiLM (Feature-wise Linear Modulation), we will modify the following parts:


✅ Summary of changes

  1. Add a StyleModulation module
  2. Add style modulation to UnetSkipConnectionBlock
  3. Remove the style concatenation at the innermost layer
  4. Keep the rest of the architecture and the losses unchanged

🧩 Change 1: add the StyleModulation module

class StyleModulation(nn.Module):
    def __init__(self, channels, style_dim):
        super(StyleModulation, self).__init__()
        self.norm = nn.InstanceNorm2d(channels, affine=False)
        self.style_fc = nn.Sequential(
            nn.Linear(style_dim, channels * 2),
            nn.ReLU(),
            nn.Linear(channels * 2, channels * 2)
        )

    def forward(self, x, style):
        x = self.norm(x)
        style_params = self.style_fc(style)  # (B, 2C)
        gamma, beta = style_params.chunk(2, dim=1)  # each (B, C)
        gamma = gamma.unsqueeze(2).unsqueeze(3)  # (B, C, 1, 1)
        beta = beta.unsqueeze(2).unsqueeze(3)
        return gamma * x + beta

🧩 Change 2: update UnetSkipConnectionBlock

Only this class needs to change. Here is the complete code after the changes:

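# Assumes torch, nn, F (torch.nn.functional), StyleModulation, SelfAttention and ResSkip
# are already imported or defined in this module
class UnetSkipConnectionBlock(nn.Module):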
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()

        self.outermost = outermost
        self.innermost = innermost
        self.embedding_dim = embedding_dim
        use_bias = norm_layer != nn.BatchNorm2d

        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        nn.init.kaiming_normal_(downconv.weight, nonlinearity='leaky_relu')
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)

        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
            self.style_mod = None  # no style modulation in the outermost block
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
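            # Caution: StyleModulation begins with InstanceNorm2d, which raises an error if the
            # innermost feature map is 1x1 (e.g. a 256x256 input after 8 downsamplings)
            self.style_mod = StyleModulation(inner_nc, embedding_dim)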
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            self.style_mod = StyleModulation(inner_nc, embedding_dim)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.3))

        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None

    def _process_submodule(self, encoded, style):
        if self.submodule:
            return self.submodule(encoded, style)
        else:
            return encoded, None

    def _interpolate_if_needed(self, decoded, x):
        return F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False) if decoded.shape[2:] != x.shape[2:] else decoded

    def forward(self, x, style=None):
        encoded = self.down(x)

        if self.self_attn:
            encoded = self.self_attn(encoded)

        # Apply FiLM style modulation if available
        if self.style_mod is not None and style is not None:
            encoded = self.style_mod(encoded, style)

        if self.innermost:
            decoded = self.up(encoded)
            decoded = self._interpolate_if_needed(decoded, x)
            return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)

        sub_output, encoded_real_A = self._process_submodule(encoded, style)

        decoded = self.up(sub_output)
        decoded = self._interpolate_if_needed(decoded, x)

        if self.res_skip:
            decoded = self.res_skip(decoded)

        if self.outermost:
            return decoded, encoded_real_A
        else:
            return torch.cat([x, decoded], 1), encoded_real_A
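One pitfall worth checking: in the innermost block, StyleModulation normalizes the bottleneck with InstanceNorm2d. If the bottleneck is 1×1 (for example, a 256×256 input downsampled eight times), recent PyTorch versions raise a ValueError, because per-sample statistics over a single spatial element are undefined. The sketch below reproduces the error and shows one possible workaround, a hypothetical `use_norm` flag that skips the normalization at that layer:

```python
import torch
import torch.nn as nn

class StyleModulation(nn.Module):
    # use_norm is a hypothetical flag added for this sketch
    def __init__(self, channels, style_dim, use_norm=True):
        super().__init__()
        self.norm = nn.InstanceNorm2d(channels, affine=False) if use_norm else nn.Identity()
        self.style_fc = nn.Sequential(
            nn.Linear(style_dim, channels * 2),
            nn.ReLU(),
            nn.Linear(channels * 2, channels * 2),
        )

    def forward(self, x, style):
        gamma, beta = self.style_fc(style).chunk(2, dim=1)
        return gamma[:, :, None, None] * self.norm(x) + beta[:, :, None, None]

bottleneck = torch.randn(4, 512, 1, 1)  # e.g. innermost features for a 256x256 input
style = torch.randn(4, 128)

try:
    StyleModulation(512, 128)(bottleneck, style)
    raised = False
except ValueError as err:  # "Expected more than 1 spatial element ..."
    raised = True
    print("InstanceNorm failed:", err)

# Without the norm, FiLM is a pure per-channel affine map and works on 1x1 maps
out = StyleModulation(512, 128, use_norm=False)(bottleneck, style)
print(raised, out.shape)
```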

📦 No other modules need to change

  • UNetGenerator already passes style down to every block (each block forwards it to its submodule).
  • The embedding stays as it was:
self.embedder = nn.Embedding(embedding_num, embedding_dim)

✅ Testing notes

  • Since you only need two styles for now (Noto Sans / Zen Maru Gothic), just set embedding_num=2.
  • The style label is 0 or 1; it goes into embedder, becomes a style vector, and is passed to every layer for modulation.
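A minimal sketch of the label-to-style-vector step, assuming the batch labels arrive as an int64 tensor; which font gets 0 and which gets 1 is up to your data pipeline:

```python
import torch
import torch.nn as nn

embedding_num, embedding_dim = 2, 128    # two styles: Noto Sans, Zen Maru Gothic
embedder = nn.Embedding(embedding_num, embedding_dim)

labels = torch.tensor([0, 1, 1, 0])      # one style label per sample (dtype int64)
style = embedder(labels)                 # (B, embedding_dim): one style vector per sample
print(style.shape)                       # torch.Size([4, 128])

# Samples with the same label share the same learned style vector
print(torch.equal(style[1], style[2]))   # True
```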

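If you want some layers to skip FiLM (for example, using only Self-Attention at the bottom layer), you can control that based on layer. Should I add that condition for you, or does this already meet your needs?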


Gemini 2.5 Pro's reply:

Okay, let’s modify the U-Net generator to use FiLM (Feature-wise Linear Modulation) for style conditioning instead of concatenation at the bottleneck. This approach allows the style embedding to influence feature maps throughout the network.

1. Define the FiLM Layer

First, we need a small module that takes the style embedding and generates the gamma (scale) and beta (shift) parameters for FiLM.


import torch
import torch.nn as nn
import torch.nn.functional as F
# Assume other necessary imports like models, VGG16_Weights, T are present
from torchvision import models # Example import
from torchvision.models import VGG16_Weights # Example import
import torchvision.transforms as T # Example import

class FiLMGenerator(nn.Module):
    """Generates FiLM parameters (gamma and beta) from a style embedding."""
    def __init__(self, embedding_dim, feature_channels):
        super(FiLMGenerator, self).__init__()
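        # Linear layer to project embedding to gamma and beta dimensions
        self.layer = nn.Linear(embedding_dim, feature_channels * 2)
        # Initialize gamma close to 1 and beta close to 0 so FiLM starts near the identity:
        # small weights, gamma half of the bias = 1, beta half = 0
        # (an all-zero bias would start gamma near 0 and suppress the features)
        nn.init.normal_(self.layer.weight, 0.0, 0.02)
        with torch.no_grad():
            self.layer.bias[:feature_channels].fill_(1.0)
            self.layer.bias[feature_channels:].zero_()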

    def forward(self, style_embedding):
        # Generate gamma and beta
        params = self.layer(style_embedding)
        # Split into gamma and beta, reshape for broadcasting
        gamma, beta = params.chunk(2, dim=-1) # Split along the last dimension
        # Reshape to (Batch, Channels, 1, 1) for broadcasting over feature maps
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        return gamma, beta

class FiLMLayer(nn.Module):
    """Applies FiLM modulation to features."""
    def __init__(self, embedding_dim, feature_channels):
        super(FiLMLayer, self).__init__()
        self.film_generator = FiLMGenerator(embedding_dim, feature_channels)

    def forward(self, features, style_embedding):
        gamma, beta = self.film_generator(style_embedding)
        return gamma * features + beta

2. Modify UnetSkipConnectionBlock

We’ll integrate FiLMLayer into the downsampling and upsampling paths of each block.


class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128, # Added embedding_dim
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()

        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d # InstanceNorm typically doesn't use bias after conv

        if input_nc is None:
            input_nc = outer_nc

        # --- Downsampling Path ---
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        nn.init.kaiming_normal_(downconv.weight, nonlinearity='leaky_relu')
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc, affine=False) # Use affine=False if applying FiLM after norm
        self.film_down = FiLMLayer(embedding_dim, inner_nc) # FiLM after norm, before ReLU

        # --- Upsampling Path ---
        uprelu = nn.ReLU(True) # ReLU before FiLM
        upnorm = norm_layer(outer_nc, affine=False) # Use affine=False if applying FiLM after norm
        self.film_up = FiLMLayer(embedding_dim, outer_nc) # FiLM after norm, before dropout/final activation

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1) # Bias typically true here
            nn.init.kaiming_normal_(upconv.weight)
            # No norm/relu/film on down path of outermost
            self.down = nn.Sequential(downconv)
            # Apply ReLU, ConvTranspose, Tanh. No upnorm/film_up here.
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
            self.film_down = None # No FiLM needed in outermost down path
            self.film_up = None   # No FiLM needed in outermost up path (applied before Tanh if desired, but usually not)

        elif innermost:
            # No style concatenation in upconv anymore
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
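            # Down path: ReLU, Conv, FiLM. No norm here: with e.g. a 256x256 input and num_downs=8
            # the bottleneck is 1x1, and InstanceNorm2d raises an error on a single spatial element.
            self.down = nn.Sequential(downrelu, downconv) # FiLM applied in forward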
            # Up path: ReLU, ConvTranspose, Norm, FiLM
            self.up = nn.Sequential(uprelu, upconv, upnorm) # FiLM applied in forward
            # Note: innermost doesn't have submodule

        else: # Intermediate layers
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            # Down path: ReLU, Conv, Norm, FiLM
            self.down = nn.Sequential(downrelu, downconv, downnorm) # FiLM applied in forward
            # Up path: ReLU, ConvTranspose, Norm, FiLM
            self.up = nn.Sequential(uprelu, upconv, upnorm) # FiLM applied in forward
            if use_dropout:
                # Apply dropout *after* FiLM modulation on the up-path features
                # We need to insert it manually in forward or adjust self.up
                self.dropout = nn.Dropout(0.3)
            else:
                self.dropout = None

        self.submodule = submodule
        self.self_attention = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        # ResSkip is applied after upsampling and FiLM, before concatenation. Keep it.
        # But ensure its channel count matches outer_nc.
        self.res_skip = ResSkip(outer_nc) if not outermost else None # Used in innermost and intermediate blocks

    def _process_submodule(self, encoded, style):
        # Pass style embedding down recursively
        if self.submodule:
            return self.submodule(encoded, style)
        else: # Innermost block reached
            # For const_loss, return the modulated features before upsampling
            return encoded, encoded # Pass modulated encoded features as 'encoded_real_A' from bottom

    def _interpolate_if_needed(self, decoded, x):
        # This function remains the same
        if decoded.shape[2:] != x.shape[2:]:
             return F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
        return decoded

    def forward(self, x, style): # Style embedding is now mandatory (except maybe outermost?)
        # --- Down Path ---
        if self.outermost:
            encoded = self.down(x) # No ReLU/Norm/FiLM here
        else:
            # Apply down path blocks: LeakyReLU, Conv, Norm
            down_features = self.down(x)
            # Apply FiLM modulation
            if self.film_down:
                 modulated_down = self.film_down(down_features, style)
            else: # Should only happen if something is wrong, but defensively handle
                 modulated_down = down_features

            encoded = modulated_down # Features ready for Self-Attention or submodule

        # Apply Self-Attention if specified
        if self.self_attention:
            encoded = self.self_attention(encoded)

        # --- Process Submodule or Innermost Logic ---
        if self.innermost:
             # Style is already applied via FiLM in self.down
             # The 'encoded' here is the final modulated bottleneck feature map
             # For consistency loss, we'll use this 'encoded' state.
             encoded_for_const_loss = encoded.view(x.shape[0], -1) # Flatten for potential later comparison

             # Apply up path blocks: ReLU, ConvTranspose, Norm
             up_features = self.up(encoded)
             # Apply FiLM modulation
             if self.film_up:
                  modulated_up = self.film_up(up_features, style)
             else:
                  modulated_up = up_features # Should not happen in innermost

             decoded = modulated_up
             # Interpolate if needed, then apply ResSkip (it is defined for the innermost block too)
             decoded = self._interpolate_if_needed(decoded, x)
             if self.res_skip:
                  decoded = self.res_skip(decoded) # ResSkip now operates on outer_nc

             # Return structure: Concatenated output, Bottleneck features for const_loss
             return torch.cat([x, decoded], 1), encoded_for_const_loss

        else: # Not innermost
             # Recursively call submodule, passing modulated 'encoded' features and style
             sub_output, encoded_for_const_loss = self._process_submodule(encoded, style)

             # --- Up Path (for non-innermost layers) ---
             # Apply up path blocks: ReLU, ConvTranspose, Norm
             up_features = self.up(sub_output) # sub_output contains skip connection from below + upsampled features
             # Apply FiLM modulation
             if self.film_up:
                 modulated_up = self.film_up(up_features, style)
             else: # Outermost case
                 modulated_up = up_features

             # Apply dropout if configured (after FiLM)
             if hasattr(self, 'dropout') and self.dropout is not None:
                 modulated_up = self.dropout(modulated_up)

             decoded = modulated_up
             # Interpolate and apply ResSkip
             decoded = self._interpolate_if_needed(decoded, x)
             if self.res_skip:
                 decoded = self.res_skip(decoded) # ResSkip operates on outer_nc

             # --- Output ---
             if self.outermost:
                 # Final output image, and the bottleneck features from the bottom
                 return decoded, encoded_for_const_loss
             else:
                 # Concatenate skip connection (original x) with decoded features
                 return torch.cat([x, decoded], 1), encoded_for_const_loss

3. Modify UNetGenerator

Update the generator to handle the new block structure and style passing.


class UNetGenerator(nn.Module):
    # Set embedding_num=2 for the 2 styles Noto Sans -> Zen Maru Gothic
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, embedding_num=2, embedding_dim=128,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False, self_attention=False, blur=False):
        super(UNetGenerator, self).__init__()

        # --- Build the U-Net with FiLM ---
        # Innermost block
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim, # Pass embedding_dim
                                             self_attention=self_attention, blur=blur, innermost=True)

        # Middle layers (if num_downs > 5)
        for index in range(num_downs - 5): # Example: num_downs=8 -> range(3) -> indices 0, 1, 2 -> layers 2, 3, 4
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, layer=index+2, embedding_dim=embedding_dim, # Pass embedding_dim
                                                 use_dropout=use_dropout, self_attention=self_attention, blur=blur)

        # Upper layers (closer to output)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs-3, embedding_dim=embedding_dim, # Pass embedding_dim
                                             self_attention=self_attention, blur=blur) # layer index example: 8-3=5
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs-2, embedding_dim=embedding_dim, # Pass embedding_dim
                                             self_attention=self_attention, blur=blur) # layer index example: 8-2=6
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs-1, embedding_dim=embedding_dim, # Pass embedding_dim
                                             self_attention=self_attention, blur=blur) # layer index example: 8-1=7

        # Outermost block
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs, embedding_dim=embedding_dim, # Pass embedding_dim
                                             self_attention=self_attention, blur=blur, outermost=True) # layer index example: 8

        # Embedding layer for the 2 styles
        self.embedder = nn.Embedding(embedding_num, embedding_dim)

    def _prepare_style(self, style_or_label):
        # This function remains the same
        if style_or_label is not None and isinstance(style_or_label, torch.Tensor) and 'LongTensor' in style_or_label.type():
             # Ensure labels are within the embedding_num range (0 or 1)
             if style_or_label.max() >= self.embedder.num_embeddings:
                  raise ValueError(f"Label index {style_or_label.max()} is out of bounds for embedding_num={self.embedder.num_embeddings}")
             return self.embedder(style_or_label)
        else:
             # Allows passing pre-computed embeddings directly if needed
             return style_or_label

    def forward(self, x, style_or_label): # Style is now required
        if style_or_label is None:
            raise ValueError("Style embedding or label must be provided for FiLM U-Net.")
        style = self._prepare_style(style_or_label)
        # self.model now returns (output_image, bottleneck_features)
        fake_B, encoded_bottleneck = self.model(x, style)
        return fake_B, encoded_bottleneck # Return bottleneck features for const_loss

    def encode(self, x, style_or_label): # Style is now required
        if style_or_label is None:
            raise ValueError("Style embedding or label must be provided for FiLM U-Net encoding.")
        style = self._prepare_style(style_or_label)
        # Need to call the model and extract the second return value
        _, encoded_bottleneck = self.model(x, style)
        return encoded_bottleneck # Return the bottleneck features
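To see how the style embedding travels through the recursion, here is a heavily simplified, self-contained sketch. `FiLM` and `TinyBlock` are hypothetical classes; normalization, self-attention, ResSkip and dropout are omitted, and it uses a 32×32 input with three levels instead of eight. Each block modulates its own features and hands the same `style` tensor to its submodule, and the innermost block returns its flattened features for the consistency loss:

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    def __init__(self, style_dim, channels):
        super().__init__()
        self.fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, style):
        gamma, beta = self.fc(style).chunk(2, dim=-1)
        return gamma[..., None, None] * x + beta[..., None, None]

class TinyBlock(nn.Module):
    """Hypothetical simplified skip block: down -> FiLM -> submodule -> up -> FiLM -> concat."""
    def __init__(self, outer_nc, inner_nc, style_dim, submodule=None, outermost=False):
        super().__init__()
        self.submodule, self.outermost = submodule, outermost
        self.down = nn.Conv2d(outer_nc, inner_nc, 4, stride=2, padding=1)
        self.film_down = FiLM(style_dim, inner_nc)
        up_in = inner_nc if submodule is None else inner_nc * 2  # submodule returns a concat
        self.up = nn.ConvTranspose2d(up_in, outer_nc, 4, stride=2, padding=1)
        self.film_up = None if outermost else FiLM(style_dim, outer_nc)

    def forward(self, x, style):
        h = self.film_down(torch.relu(self.down(x)), style)
        if self.submodule is None:
            bottleneck = h.flatten(1)                  # features for a consistency loss
        else:
            h, bottleneck = self.submodule(h, style)   # style is passed down recursively
        y = self.up(h)
        if self.outermost:
            return torch.tanh(y), bottleneck
        return torch.cat([x, self.film_up(y, style)], dim=1), bottleneck

style_dim = 16
embedder = nn.Embedding(2, style_dim)
net = TinyBlock(1, 8, style_dim, outermost=True,
                submodule=TinyBlock(8, 16, style_dim,
                                    submodule=TinyBlock(16, 32, style_dim)))

x = torch.randn(2, 1, 32, 32)
labels = torch.tensor([0, 1])
fake, bottleneck = net(x, embedder(labels))
print(fake.shape, bottleneck.shape)  # torch.Size([2, 1, 32, 32]) torch.Size([2, 512])
```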

4. Modify Zi2ZiModel

Adjust the initialization and potentially the loss calculations if they relied heavily on the specific structure of the concatenated encoding.


class Zi2ZiModel:
    # Set embedding_num=2 by default if appropriate, or pass it during instantiation
    def __init__(self, input_nc=1, embedding_num=2, embedding_dim=128, ngf=64, ndf=64,
                 Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100,
                 schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
                 self_attention=False, # residual_block param seems unused? Removed for clarity.
                 weight_decay=1e-5, beta1=0.5, g_blur=False, d_blur=False, epoch=40,
                 gradient_clip=0.5, norm_type="instance"): # Removed residual_block param

        # ... (Keep most of the init code) ...
        self.embedding_num = embedding_num # Ensure this is 2 for your task
        # ...

        # Assume setup, set_input, compute_*, backward_D, backward_G, optimize_parameters remain largely the same
        # BUT VERIFY THE LOSSES:

    def setup(self):
        # ... (Setup code) ...
        # Make sure UNetGenerator is initialized with the correct embedding_num
        self.netG = UNetGenerator(
            input_nc=self.input_nc,
            output_nc=self.input_nc,
            ngf=self.ngf,
            use_dropout=self.use_dropout,
            embedding_num=self.embedding_num, # Should be 2
            embedding_dim=self.embedding_dim,
            self_attention=self.self_attention,
            blur=self.g_blur,
            norm_layer=norm_layer # Make sure norm_layer is defined correctly based on self.norm_type
        )
        # Discriminator setup remains the same, but ensure embedding_num matches
        self.netD = Discriminator(
            input_nc=2 * self.input_nc, # Still takes concatenated input+output
            embedding_num=self.embedding_num, # Should be 2
            ndf=self.ndf,
            blur=self.d_blur,
            norm_layer=nn.BatchNorm2d # Or InstanceNorm? Check original D design
        )
        # ... (Rest of setup: init_net, optimizers, schedulers, losses) ...

        # CategoryLoss needs embedding_num=2
        self.category_loss = CategoryLoss(self.embedding_num)
        # ... (Other losses) ...


    def forward(self):
        # The forward pass of UNetGenerator now requires the label/style
        self.fake_B, self.encoded_real_A = self.netG(self.real_A, self.labels)
        # The encode method also requires the label/style
        self.encoded_fake_B = self.netG.encode(self.fake_B, self.labels)
        # encoded_real_A and encoded_fake_B are now the flattened bottleneck features
        # after FiLM modulation but before the innermost upsampling block.

    def backward_G(self, no_target_source=False):
        # ... (Calculate fake_AB, real_AB, get D outputs) ...

        # Consistency Loss: Compare the FiLM-modulated bottleneck features
        # Ensure encoded_real_A and encoded_fake_B have compatible shapes for MSE
        # They should both be [Batch, FeatureDim] from the view operation in UnetSkipConnectionBlock innermost
        const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)

        # L1 Loss (Pixel-wise reconstruction) - Remains the same
        l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)

        # Run D once on the fake pair: it returns both the real/fake logits and the category logits
        fake_D_logits, fake_category_logits = self.netD(torch.cat([self.real_A, self.fake_B], 1))

        # Category Loss for Generator - Remains the same conceptually
        fake_category_loss = self.Lcategory_penalty * self.category_loss(fake_category_logits, self.labels)

        # Adversarial Loss for Generator - Remains the same conceptually
        # The relativistic form also needs D's logits on the real pair
        real_D_logits, _ = self.netD(torch.cat([self.real_A, self.real_B], 1))
        # Using the relativistic loss form from before:
        g_loss_adv = -torch.mean(F.logsigmoid(fake_D_logits - real_D_logits))
        # Or the symmetric version:
        # g_loss_adv = -0.5 * (torch.mean(F.logsigmoid(real_D_logits - fake_D_logits)) +
        #                    torch.mean(F.logsigmoid(fake_D_logits - real_D_logits)))


        # Feature Matching Loss - Remains the same conceptually
        fm_loss = self.compute_feature_matching_loss(real_AB, fake_AB) # real_AB and fake_AB need to be defined here

        # Perceptual Loss - Remains the same
        perceptual_loss = self.vgg_loss(self.fake_B, self.real_B)
        perceptual_weight = 10.0

        # Total Generator Loss
        self.g_loss = g_loss_adv + l1_loss + fake_category_loss + const_loss + fm_loss + perceptual_weight * perceptual_loss
        # Backpropagate the total generator loss (this is what backward_G is for)
        self.g_loss.backward()

        return const_loss, l1_loss, g_loss_adv, fm_loss, perceptual_loss

    # backward_D should be okay as D's input format hasn't changed (cat(A, B))
    # optimize_parameters should be okay, just ensure the returned loss values are correct scalars.
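The relativistic generator loss in backward_G pushes the fake pair's logit above the real pair's, and because -log σ(z) = softplus(-z) it can equivalently be written with softplus. A small sketch with random stand-in logits (in training these come from netD):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
fake_D_logits = torch.randn(8, 1)  # stand-ins for D(real_A, fake_B)
real_D_logits = torch.randn(8, 1)  # stand-ins for D(real_A, real_B)

# Form used in backward_G
g_loss_adv = -torch.mean(F.logsigmoid(fake_D_logits - real_D_logits))

# Equivalent softplus form: -log(sigmoid(z)) == softplus(-z)
g_loss_softplus = torch.mean(F.softplus(real_D_logits - fake_D_logits))

print(torch.allclose(g_loss_adv, g_loss_softplus))  # True
```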

Key Changes Summary:

  1. FiLMGenerator & FiLMLayer: Created helper modules for FiLM.
  2. UnetSkipConnectionBlock:
    • Added FiLMLayer instances (film_down, film_up).
    • Removed style concatenation at the bottleneck.
    • Applied FiLM modulation after normalization layers in both down and up paths.
    • Modified the forward method to pass the style embedding recursively and apply FiLM.
    • Changed the innermost block’s return value to provide FiLM-modulated bottleneck features for the consistency loss.
    • Set affine=False in norm_layer where FiLM is applied immediately after (FiLM provides its own learnable scale/shift).
  3. UNetGenerator:
    • Passed embedding_dim down to blocks.
    • Updated forward and encode methods to require style_or_label and return/use the bottleneck features correctly.
    • Set embedding_num=2 (assuming 0 for Noto Sans, 1 for Zen Maru Gothic, or vice-versa).
  4. Zi2ZiModel:
    • Ensured embedding_num=2 is used during initialization.
    • Verified that the const_loss calculation uses the new bottleneck features returned by the generator/encoder.
    • Made sure the generator (self.netG) and encoder (self.netG.encode) calls in forward receive the self.labels.
    • Checked that loss calculations (backward_G, backward_D) still make sense with the new generator structure.

Remember to set embedding_num=2 when creating the Zi2ZiModel instance for your specific two-font task. You’ll need to map “Noto Sans” and “Zen Maru Gothic” to integer labels 0 and 1 in your data loading pipeline.
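A minimal sketch of that mapping inside a PyTorch Dataset. `FONT_TO_LABEL`, `FontPairDataset` and the sample layout (source glyph, target glyph, target font name) are hypothetical; any fixed assignment of 0 and 1 works as long as it is used consistently:

```python
import torch
from torch.utils.data import DataLoader, Dataset

# Hypothetical mapping from font name to style label
FONT_TO_LABEL = {"Noto Sans": 0, "Zen Maru Gothic": 1}

class FontPairDataset(Dataset):
    """Hypothetical dataset: each sample is (real_A tensor, real_B tensor, font_name)."""
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        real_A, real_B, font_name = self.samples[idx]
        label = torch.tensor(FONT_TO_LABEL[font_name], dtype=torch.long)
        return real_A, real_B, label

samples = [(torch.zeros(1, 64, 64), torch.ones(1, 64, 64), "Zen Maru Gothic"),
           (torch.zeros(1, 64, 64), torch.zeros(1, 64, 64), "Noto Sans")]
loader = DataLoader(FontPairDataset(samples), batch_size=2)
real_A, real_B, labels = next(iter(loader))
print(labels)  # tensor([1, 0])
```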
