Your thinking here is exactly right, and you're very close to completing the conversion! Let me tidy up the logic and clarify a few implementation details so your FiLM application goes more smoothly.
✅ Your goal
You want the style vector (style_vec) to modulate the feature maps of every layer in the UNet via FiLM (scale + shift), i.e.:
out = gamma * feat + beta
where gamma and beta are produced from style_vec through an MLP.
🔧 Your design (very close to done)
You already have a StyleModulation class; here are a few suggested fixes to make it more complete and stable:
🔄 Revised StyleModulation
import torch
import torch.nn as nn

class StyleModulation(nn.Module):
    def __init__(self, in_channels, style_dim):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channels, affine=False)
        self.style_fc = nn.Sequential(
            nn.Linear(style_dim, in_channels * 2),
            nn.ReLU(),
            nn.Linear(in_channels * 2, in_channels * 2)  # split into gamma and beta
        )

    def forward(self, x, style_vec):
        # x: (B, C, H, W)
        x_norm = self.norm(x)
        # get gamma and beta, shape (B, 2*C)
        style_params = self.style_fc(style_vec)
        gamma, beta = style_params.chunk(2, dim=1)
        # reshape to (B, C, 1, 1) so they broadcast against x
        gamma = gamma.unsqueeze(2).unsqueeze(3)
        beta = beta.unsqueeze(2).unsqueeze(3)
        # apply FiLM
        return gamma * x_norm + beta
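To sanity-check the broadcasting, here is a minimal smoke test with illustrative sizes (batch 2, 64 channels, 32×32 features, style_dim=128); the values are placeholders, not from your code:

# Minimal smoke test for StyleModulation (illustrative shapes only)
mod = StyleModulation(in_channels=64, style_dim=128)
feat = torch.randn(2, 64, 32, 32)   # dummy feature map (B, C, H, W)
style = torch.randn(2, 128)         # dummy style vector (B, style_dim)
out = mod(feat, style)
assert out.shape == feat.shape      # FiLM preserves the spatial shape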
🔗 How do you use this module inside the UNet?
You can insert it in the middle of a UNet encoder or decoder block:
self.style_mod = StyleModulation(in_channels, style_dim)
and call it in forward with the style vector:
x = self.style_mod(x, style_vec)
💡 Quick reminders
- Use InstanceNorm2d rather than BatchNorm: in style-transfer tasks, keeping each sample's style independent of the batch matters more.
- style_dim is the length of your style vector, e.g. 128 or 256.
- If you use residual blocks, you can place StyleModulation after the middle conv inside the residual branch (or use several); see the sketch below.
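To illustrate that last tip, here is a minimal residual-block sketch (the block itself is hypothetical, not from your code) with StyleModulation applied after the middle conv:

class StyledResBlock(nn.Module):
    """Hypothetical residual block: style modulation after the middle conv."""
    def __init__(self, channels, style_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.style_mod = StyleModulation(channels, style_dim)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, style_vec):
        h = self.act(self.conv1(x))
        h = self.style_mod(h, style_vec)   # modulate the mid-block features
        h = self.conv2(h)
        return x + h                       # residual connection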
🧪 Things to test
- Feed the same image with different style_vec values; the output styles should visibly differ (see the snippet below).
- If style_vec is held constant (e.g. all ones or all zeros), you should see the style stay fixed or change only slightly.
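A rough sketch of the first check, where model stands in for your full network (placeholder names and sizes throughout):

# Hypothetical sanity check: same input image, two different style vectors
img = torch.randn(1, 1, 256, 256)   # placeholder input image
style_a = torch.randn(1, 128)
style_b = torch.randn(1, 128)
with torch.no_grad():
    out_a = model(img, style_a)     # 'model' is a placeholder for your network
    out_b = model(img, style_b)
print((out_a - out_b).abs().mean())  # should be clearly nonzero if FiLM has an effect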
Your module is already in great shape. If there's a particular UNet layer you want to apply it to, I can also demonstrate how to insert StyleModulation so the integration goes smoothly. Want me to put together a UNet block as an example?
To apply the style embedding to the feature maps of each UNet layer via FiLM (Feature-wise Linear Modulation), we will modify the following parts:
✅ Summary of changes
- Add a StyleModulation module
- Add style modulation to UnetSkipConnectionBlock
- Remove the style concatenation in the innermost layer
- Keep the rest of the architecture and the losses unchanged
🧩 Change 1: add the StyleModulation module
class StyleModulation(nn.Module):
    def __init__(self, channels, style_dim):
        super(StyleModulation, self).__init__()
        self.norm = nn.InstanceNorm2d(channels, affine=False)
        self.style_fc = nn.Sequential(
            nn.Linear(style_dim, channels * 2),
            nn.ReLU(),
            nn.Linear(channels * 2, channels * 2)
        )

    def forward(self, x, style):
        x = self.norm(x)
        style_params = self.style_fc(style)         # (B, 2C)
        gamma, beta = style_params.chunk(2, dim=1)  # each (B, C)
        gamma = gamma.unsqueeze(2).unsqueeze(3)     # (B, C, 1, 1)
        beta = beta.unsqueeze(2).unsqueeze(3)
        return gamma * x + beta
🧩 Change 2: adjust UnetSkipConnectionBlock
Only this class needs to change; here is the complete modified code:
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        self.embedding_dim = embedding_dim
        use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        nn.init.kaiming_normal_(downconv.weight, nonlinearity='leaky_relu')
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(inplace=False)
        upnorm = norm_layer(outer_nc)
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            self.down = nn.Sequential(downconv)
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
            self.style_mod = None  # outermost layer is not modulated
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            self.down = nn.Sequential(downrelu, downconv)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            self.style_mod = StyleModulation(inner_nc, embedding_dim)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            self.down = nn.Sequential(downrelu, downconv, downnorm)
            self.up = nn.Sequential(uprelu, upconv, upnorm)
            self.style_mod = StyleModulation(inner_nc, embedding_dim)
            if use_dropout:
                self.up.add_module("dropout", nn.Dropout(0.3))
        self.submodule = submodule
        self.self_attn = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        self.res_skip = ResSkip(outer_nc) if not outermost and not innermost else None

    def _process_submodule(self, encoded, style):
        if self.submodule:
            return self.submodule(encoded, style)
        else:
            return encoded, None

    def _interpolate_if_needed(self, decoded, x):
        return F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False) if decoded.shape[2:] != x.shape[2:] else decoded

    def forward(self, x, style=None):
        encoded = self.down(x)
        if self.self_attn:
            encoded = self.self_attn(encoded)
        # Apply FiLM style modulation if available
        if self.style_mod and style is not None:
            encoded = self.style_mod(encoded, style)
        if self.innermost:
            decoded = self.up(encoded)
            decoded = self._interpolate_if_needed(decoded, x)
            return torch.cat([x, decoded], 1), encoded.view(x.shape[0], -1)
        sub_output, encoded_real_A = self._process_submodule(encoded, style)
        decoded = self.up(sub_output)
        decoded = self._interpolate_if_needed(decoded, x)
        if self.res_skip:
            decoded = self.res_skip(decoded)
        if self.outermost:
            return decoded, encoded_real_A
        else:
            return torch.cat([x, decoded], 1), encoded_real_A
📦 No other modules need to change
UNetGenerator already passes style down to every block automatically. The embedding stays as it was:
self.embedder = nn.Embedding(embedding_num, embedding_dim)
✅ Testing reminders
- Since you only need two styles for now (Noto Sans / Zen Maru Gothic), just set embedding_num=2.
- Style labels are 0 or 1; fed through the embedder they become style vectors, which are passed into every layer for modulation (see the sketch below).
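A minimal sketch of that label-to-style-vector flow; the mapping, dimensions, and the generator call are illustrative:

# Hypothetical end-to-end flow: integer labels -> style vectors -> FiLM-conditioned UNet
embedder = nn.Embedding(2, 128)           # embedding_num=2, embedding_dim=128
labels = torch.tensor([0, 1])             # 0 = Noto Sans, 1 = Zen Maru Gothic (example mapping)
style = embedder(labels)                  # (2, 128) style vectors
# fake_B, encoded = netG(real_A, style)   # each block then applies StyleModulation with `style`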
If you want certain layers to skip FiLM (for example, having the bottom layer use only Self-Attention), you can also gate it on layer. Want me to add that conditional, or does this already fit your needs?
Gemini 2.5 Pro's reply:
Okay, let’s modify the U-Net generator to use FiLM (Feature-wise Linear Modulation) for style conditioning instead of concatenation at the bottleneck. This approach allows the style embedding to influence feature maps throughout the network.
1. Define the FiLM Layer
First, we need a small module that takes the style embedding and generates the gamma (scale) and beta (shift) parameters for FiLM.
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
# Assume other necessary imports like models, VGG16_Weights, T are present
from torchvision import models  # Example import
from torchvision.models import VGG16_Weights  # Example import
import torchvision.transforms as T  # Example import

class FiLMGenerator(nn.Module):
    """Generates FiLM parameters (gamma and beta) from a style embedding."""
    def __init__(self, embedding_dim, feature_channels):
        super(FiLMGenerator, self).__init__()
        # Linear layer to project embedding to gamma and beta dimensions
        self.layer = nn.Linear(embedding_dim, feature_channels * 2)
        # Initialize gamma close to 1 and beta close to 0 so FiLM starts near identity
        self.layer.weight.data.normal_(0, 0.02)             # Small initial weights
        self.layer.bias.data.zero_()                        # Zero initial bias (beta ~ 0)
        self.layer.bias.data[:feature_channels].fill_(1.0)  # Bias the gamma half toward 1

    def forward(self, style_embedding):
        # Generate gamma and beta
        params = self.layer(style_embedding)
        # Split into gamma and beta, reshape for broadcasting
        gamma, beta = params.chunk(2, dim=-1)  # Split along the last dimension
        # Reshape to (Batch, Channels, 1, 1) for broadcasting over feature maps
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        return gamma, beta

class FiLMLayer(nn.Module):
    """Applies FiLM modulation to features."""
    def __init__(self, embedding_dim, feature_channels):
        super(FiLMLayer, self).__init__()
        self.film_generator = FiLMGenerator(embedding_dim, feature_channels)

    def forward(self, features, style_embedding):
        gamma, beta = self.film_generator(style_embedding)
        return gamma * features + beta
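A quick shape check for these helpers, with illustrative dimensions only (embedding_dim=128, 64 channels):
Python
# Shape check for FiLMGenerator / FiLMLayer (illustrative values)
film = FiLMLayer(embedding_dim=128, feature_channels=64)
feats = torch.randn(2, 64, 16, 16)
emb = torch.randn(2, 128)
out = film(feats, emb)
assert out.shape == feats.shape  # gamma/beta broadcast as (B, C, 1, 1)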
2. Modify UnetSkipConnectionBlock
We'll integrate FiLMLayer into the downsampling and upsampling paths of each block.
Python
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 norm_layer=nn.InstanceNorm2d, layer=0, embedding_dim=128,  # Added embedding_dim
                 use_dropout=False, self_attention=False, blur=False, outermost=False, innermost=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        use_bias = norm_layer != nn.BatchNorm2d  # InstanceNorm typically doesn't use bias after conv
        if input_nc is None:
            input_nc = outer_nc
        # --- Downsampling Path ---
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        nn.init.kaiming_normal_(downconv.weight, nonlinearity='leaky_relu')
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc, affine=False)  # Use affine=False if applying FiLM after norm
        self.film_down = FiLMLayer(embedding_dim, inner_nc)  # FiLM after norm, before ReLU
        # --- Upsampling Path ---
        uprelu = nn.ReLU(True)  # ReLU before FiLM
        upnorm = norm_layer(outer_nc, affine=False)  # Use affine=False if applying FiLM after norm
        self.film_up = FiLMLayer(embedding_dim, outer_nc)  # FiLM after norm, before dropout/final activation
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)  # Bias typically true here
            nn.init.kaiming_normal_(upconv.weight)
            # No norm/relu/film on down path of outermost
            self.down = nn.Sequential(downconv)
            # Apply ReLU, ConvTranspose, Tanh. No upnorm/film_up here.
            self.up = nn.Sequential(uprelu, upconv, nn.Tanh())
            self.film_down = None  # No FiLM needed in outermost down path
            self.film_up = None  # No FiLM needed in outermost up path (applied before Tanh if desired, but usually not)
        elif innermost:
            # No style concatenation in upconv anymore
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            # Down path: ReLU, Conv, Norm, FiLM
            self.down = nn.Sequential(downrelu, downconv, downnorm)  # FiLM applied in forward
            # Up path: ReLU, ConvTranspose, Norm, FiLM
            self.up = nn.Sequential(uprelu, upconv, upnorm)  # FiLM applied in forward
            # Note: innermost doesn't have submodule
        else:  # Intermediate layers
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            nn.init.kaiming_normal_(upconv.weight)
            # Down path: ReLU, Conv, Norm, FiLM
            self.down = nn.Sequential(downrelu, downconv, downnorm)  # FiLM applied in forward
            # Up path: ReLU, ConvTranspose, Norm, FiLM
            self.up = nn.Sequential(uprelu, upconv, upnorm)  # FiLM applied in forward
            if use_dropout:
                # Apply dropout *after* FiLM modulation on the up-path features
                # We need to insert it manually in forward or adjust self.up
                self.dropout = nn.Dropout(0.3)
            else:
                self.dropout = None
        self.submodule = submodule
        self.self_attention = SelfAttention(inner_nc) if self_attention and layer in [4, 6] else None
        # ResSkip is applied after upsampling and FiLM, before concatenation. Keep it.
        # But ensure its channel count matches outer_nc.
        self.res_skip = ResSkip(outer_nc) if not outermost else None  # Innermost doesn't need it before cat

    def _process_submodule(self, encoded, style):
        # Pass style embedding down recursively
        if self.submodule:
            return self.submodule(encoded, style)
        else:  # Innermost block reached
            # For const_loss, return the modulated features before upsampling
            return encoded, encoded  # Pass modulated encoded features as 'encoded_real_A' from bottom

    def _interpolate_if_needed(self, decoded, x):
        # This function remains the same
        if decoded.shape[2:] != x.shape[2:]:
            return F.interpolate(decoded, size=x.shape[2:], mode='bilinear', align_corners=False)
        return decoded

    def forward(self, x, style):  # Style embedding is now mandatory (except maybe outermost?)
        # --- Down Path ---
        if self.outermost:
            encoded = self.down(x)  # No ReLU/Norm/FiLM here
        else:
            # Apply down path blocks: LeakyReLU, Conv, Norm
            down_features = self.down(x)
            # Apply FiLM modulation
            if self.film_down:
                modulated_down = self.film_down(down_features, style)
            else:  # Should only happen if something is wrong, but defensively handle
                modulated_down = down_features
            encoded = modulated_down  # Features ready for Self-Attention or submodule
        # Apply Self-Attention if specified
        if self.self_attention:
            encoded = self.self_attention(encoded)
        # --- Process Submodule or Innermost Logic ---
        if self.innermost:
            # Style is already applied via FiLM in self.down
            # The 'encoded' here is the final modulated bottleneck feature map
            # For consistency loss, we'll use this 'encoded' state.
            encoded_for_const_loss = encoded.view(x.shape[0], -1)  # Flatten for potential later comparison
            # Apply up path blocks: ReLU, ConvTranspose, Norm
            up_features = self.up(encoded)
            # Apply FiLM modulation
            if self.film_up:
                modulated_up = self.film_up(up_features, style)
            else:
                modulated_up = up_features  # Should not happen in innermost
            decoded = modulated_up
            # Interpolate and apply ResSkip (if applicable, though typically not in innermost)
            decoded = self._interpolate_if_needed(decoded, x)
            if self.res_skip:
                decoded = self.res_skip(decoded)  # ResSkip now operates on outer_nc
            # Return structure: Concatenated output, Bottleneck features for const_loss
            return torch.cat([x, decoded], 1), encoded_for_const_loss
        else:  # Not innermost
            # Recursively call submodule, passing modulated 'encoded' features and style
            sub_output, encoded_for_const_loss = self._process_submodule(encoded, style)
            # --- Up Path (for non-innermost layers) ---
            # Apply up path blocks: ReLU, ConvTranspose, Norm
            up_features = self.up(sub_output)  # sub_output contains skip connection from below + upsampled features
            # Apply FiLM modulation
            if self.film_up:
                modulated_up = self.film_up(up_features, style)
            else:  # Outermost case
                modulated_up = up_features
            # Apply dropout if configured (after FiLM)
            if hasattr(self, 'dropout') and self.dropout is not None:
                modulated_up = self.dropout(modulated_up)
            decoded = modulated_up
            # Interpolate and apply ResSkip
            decoded = self._interpolate_if_needed(decoded, x)
            if self.res_skip:
                decoded = self.res_skip(decoded)  # ResSkip operates on outer_nc
            # --- Output ---
            if self.outermost:
                # Final output image, and the bottleneck features from the bottom
                return decoded, encoded_for_const_loss
            else:
                # Concatenate skip connection (original x) with decoded features
                return torch.cat([x, decoded], 1), encoded_for_const_loss
3. Modify UNetGenerator
Update the generator to handle the new block structure and style passing.
Python
class UNetGenerator(nn.Module):
    # Set embedding_num=2 for the 2 styles Noto Sans -> Zen Maru Gothic
    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, embedding_num=2, embedding_dim=128,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False, self_attention=False, blur=False):
        super(UNetGenerator, self).__init__()
        # --- Build the U-Net with FiLM ---
        # Innermost block
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, layer=1, embedding_dim=embedding_dim,  # Pass embedding_dim
                                             self_attention=self_attention, blur=blur, innermost=True)
        # Middle layers (if num_downs > 5)
        for index in range(num_downs - 5):  # Example: num_downs=8 -> range(3) -> indices 0, 1, 2 -> layers 2, 3, 4
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, layer=index + 2, embedding_dim=embedding_dim,  # Pass embedding_dim
                                                 use_dropout=use_dropout, self_attention=self_attention, blur=blur)
        # Upper layers (closer to output)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs - 3, embedding_dim=embedding_dim,  # Pass embedding_dim
                                             self_attention=self_attention, blur=blur)  # layer index example: 8-3=5
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs - 2, embedding_dim=embedding_dim,  # Pass embedding_dim
                                             self_attention=self_attention, blur=blur)  # layer index example: 8-2=6
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs - 1, embedding_dim=embedding_dim,  # Pass embedding_dim
                                             self_attention=self_attention, blur=blur)  # layer index example: 8-1=7
        # Outermost block
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             norm_layer=norm_layer, layer=num_downs, embedding_dim=embedding_dim,  # Pass embedding_dim
                                             self_attention=self_attention, blur=blur, outermost=True)  # layer index example: 8
        # Embedding layer for the 2 styles
        self.embedder = nn.Embedding(embedding_num, embedding_dim)

    def _prepare_style(self, style_or_label):
        # This function remains the same
        if style_or_label is not None and isinstance(style_or_label, torch.Tensor) and 'LongTensor' in style_or_label.type():
            # Ensure labels are within the embedding_num range (0 or 1)
            if style_or_label.max() >= self.embedder.num_embeddings:
                raise ValueError(f"Label index {style_or_label.max()} is out of bounds for embedding_num={self.embedder.num_embeddings}")
            return self.embedder(style_or_label)
        else:
            # Allows passing pre-computed embeddings directly if needed
            return style_or_label

    def forward(self, x, style_or_label):  # Style is now required
        if style_or_label is None:
            raise ValueError("Style embedding or label must be provided for FiLM U-Net.")
        style = self._prepare_style(style_or_label)
        # self.model now returns (output_image, bottleneck_features)
        fake_B, encoded_bottleneck = self.model(x, style)
        return fake_B, encoded_bottleneck  # Return bottleneck features for const_loss

    def encode(self, x, style_or_label):  # Style is now required
        if style_or_label is None:
            raise ValueError("Style embedding or label must be provided for FiLM U-Net encoding.")
        style = self._prepare_style(style_or_label)
        # Need to call the model and extract the second return value
        _, encoded_bottleneck = self.model(x, style)
        return encoded_bottleneck  # Return the bottleneck features
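A quick smoke test for the generator, with illustrative sizes (256×256 single-channel glyphs); it assumes the SelfAttention and ResSkip modules from your codebase are defined:
Python
# Hypothetical forward pass through the FiLM-conditioned generator
netG = UNetGenerator(input_nc=1, output_nc=1, num_downs=8, ngf=64,
                     embedding_num=2, embedding_dim=128)
x = torch.randn(4, 1, 256, 256)     # dummy batch of source glyphs
labels = torch.randint(0, 2, (4,))  # style labels 0 or 1 (LongTensor, embedded internally)
fake_B, bottleneck = netG(x, labels)
print(fake_B.shape, bottleneck.shape)  # (4, 1, 256, 256) and (4, flattened_dim)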
4. Modify Zi2ZiModel
Adjust the initialization and potentially the loss calculations if they relied heavily on the specific structure of the concatenated encoding.
Python
class Zi2ZiModel:
    # Set embedding_num=2 by default if appropriate, or pass it during instantiation
    def __init__(self, input_nc=1, embedding_num=2, embedding_dim=128, ngf=64, ndf=64,
                 Lconst_penalty=10, Lcategory_penalty=1, L1_penalty=100,
                 schedule=10, lr=0.001, gpu_ids=None, save_dir='.', is_training=True,
                 self_attention=False,  # residual_block param seems unused? Removed for clarity.
                 weight_decay=1e-5, beta1=0.5, g_blur=False, d_blur=False, epoch=40,
                 gradient_clip=0.5, norm_type="instance"):  # Removed residual_block param
        # ... (Keep most of the init code) ...
        self.embedding_num = embedding_num  # Ensure this is 2 for your task
        # ...

    # Assume setup, set_input, compute_*, backward_D, backward_G, optimize_parameters remain largely the same
    # BUT VERIFY THE LOSSES:

    def setup(self):
        # ... (Setup code) ...
        # Make sure UNetGenerator is initialized with the correct embedding_num
        self.netG = UNetGenerator(
            input_nc=self.input_nc,
            output_nc=self.input_nc,
            ngf=self.ngf,
            use_dropout=self.use_dropout,
            embedding_num=self.embedding_num,  # Should be 2
            embedding_dim=self.embedding_dim,
            self_attention=self.self_attention,
            blur=self.g_blur,
            norm_layer=norm_layer  # Make sure norm_layer is defined correctly based on self.norm_type
        )
        # Discriminator setup remains the same, but ensure embedding_num matches
        self.netD = Discriminator(
            input_nc=2 * self.input_nc,  # Still takes concatenated input+output
            embedding_num=self.embedding_num,  # Should be 2
            ndf=self.ndf,
            blur=self.d_blur,
            norm_layer=nn.BatchNorm2d  # Or InstanceNorm? Check original D design
        )
        # ... (Rest of setup: init_net, optimizers, schedulers, losses) ...
        # CategoryLoss needs embedding_num=2
        self.category_loss = CategoryLoss(self.embedding_num)
        # ... (Other losses) ...

    def forward(self):
        # The forward pass of UNetGenerator now requires the label/style
        self.fake_B, self.encoded_real_A = self.netG(self.real_A, self.labels)
        # The encode method also requires the label/style
        self.encoded_fake_B = self.netG.encode(self.fake_B, self.labels)
        # encoded_real_A and encoded_fake_B are now the flattened bottleneck features
        # after FiLM modulation but before the innermost upsampling block.

    def backward_G(self, no_target_source=False):
        # ... (Calculate fake_AB, real_AB, get D outputs) ...
        # Consistency Loss: Compare the FiLM-modulated bottleneck features
        # Ensure encoded_real_A and encoded_fake_B have compatible shapes for MSE
        # They should both be [Batch, FeatureDim] from the view operation in UnetSkipConnectionBlock innermost
        const_loss = self.Lconst_penalty * self.mse(self.encoded_real_A, self.encoded_fake_B)
        # L1 Loss (Pixel-wise reconstruction) - Remains the same
        l1_loss = self.L1_penalty * self.l1_loss(self.fake_B, self.real_B)
        # Category Loss for Generator - Remains the same conceptually
        # Need fake_category_logits from D
        _, fake_category_logits = self.netD(torch.cat([self.real_A, self.fake_B], 1))  # Pass fake_B to D
        fake_category_loss = self.Lcategory_penalty * self.category_loss(fake_category_logits, self.labels)
        # Adversarial Loss for Generator - Remains the same conceptually
        # Need fake_D_logits and real_D_logits
        fake_D_logits, _ = self.netD(torch.cat([self.real_A, self.fake_B], 1))
        real_D_logits, _ = self.netD(torch.cat([self.real_A, self.real_B], 1))  # D needs real pair too
        # Using the relativistic loss form from before:
        g_loss_adv = -torch.mean(F.logsigmoid(fake_D_logits - real_D_logits))
        # Or the symmetric version:
        # g_loss_adv = -0.5 * (torch.mean(F.logsigmoid(real_D_logits - fake_D_logits)) +
        #                      torch.mean(F.logsigmoid(fake_D_logits - real_D_logits)))
        # Feature Matching Loss - Remains the same conceptually
        fm_loss = self.compute_feature_matching_loss(real_AB, fake_AB)  # real_AB and fake_AB need to be defined here
        # Perceptual Loss - Remains the same
        perceptual_loss = self.vgg_loss(self.fake_B, self.real_B)
        perceptual_weight = 10.0
        # Total Generator Loss
        self.g_loss = g_loss_adv + l1_loss + fake_category_loss + const_loss + fm_loss + perceptual_weight * perceptual_loss
        return const_loss, l1_loss, g_loss_adv, fm_loss, perceptual_loss

    # backward_D should be okay as D's input format hasn't changed (cat(A, B))
    # optimize_parameters should be okay, just ensure the returned loss values are correct scalars.
Key Changes Summary:
- FiLMGenerator & FiLMLayer: Created helper modules for FiLM.
- UnetSkipConnectionBlock:
  - Added FiLMLayer instances (film_down, film_up).
  - Removed style concatenation at the bottleneck.
  - Applied FiLM modulation after normalization layers in both down and up paths.
  - Modified the forward method to pass the style embedding recursively and apply FiLM.
  - Changed the innermost block's return value to provide FiLM-modulated bottleneck features for the consistency loss.
  - Set affine=False in norm_layer where FiLM is applied immediately after (FiLM provides its own learnable scale/shift).
- UNetGenerator:
  - Passed embedding_dim down to blocks.
  - Updated forward and encode methods to require style_or_label and return/use the bottleneck features correctly.
  - Set embedding_num=2 (assuming 0 for Noto Sans, 1 for Zen Maru Gothic, or vice-versa).
- Zi2ZiModel:
  - Ensured embedding_num=2 is used during initialization.
  - Verified that the const_loss calculation uses the new bottleneck features returned by the generator/encoder.
  - Made sure the generator (self.netG) and encoder (self.netG.encode) calls in forward receive the self.labels.
  - Checked that loss calculations (backward_G, backward_D) still make sense with the new generator structure.

Remember to set embedding_num=2 when creating the Zi2ZiModel instance for your specific two-font task. You'll need to map "Noto Sans" and "Zen Maru Gothic" to integer labels 0 and 1 in your data loading pipeline (see the sketch below).
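As a final sketch of that mapping, here is one hypothetical way to do it in the data pipeline; the names and dataset field are placeholders:
Python
# Hypothetical font-name -> label mapping for the data loading pipeline
FONT_TO_LABEL = {"Noto Sans": 0, "Zen Maru Gothic": 1}

def font_label(font_name: str) -> int:
    """Map a target font name to the integer label fed to nn.Embedding."""
    return FONT_TO_LABEL[font_name]

# e.g. in the Dataset's __getitem__:
# label = torch.tensor(font_label(sample["target_font"]), dtype=torch.long)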