FiLMModulation 的2種不同寫法

下列沒顯示出來的版本:

class FiLMModulation(nn.Module):
    def __init__(self, feature_nc, style_dim):
        super(FiLMModulation, self).__init__()
        self.scale_fc = nn.Linear(style_dim, feature_nc)
        self.shift_fc = nn.Linear(style_dim, feature_nc)

    def forward(self, x, style):
        scale = self.scale_fc(style).view(style.size(0), -1, 1, 1)
        shift = self.shift_fc(style).view(style.size(0), -1, 1, 1)
        return x * scale + shift

兩個 FiLMModulation 版本功能是相同的:透過 style 向量對特徵圖進行通道尺度與偏移調控(scale + shift)。但設計上略有差異:


✅ 差異比較

特性版本 1 (scale_fc & shift_fc)版本 2 (Linear → chunk → gamma/beta)
結構兩個 nn.Linear一個 nn.Linear 層 + chunk
初始化預設初始化使用 kaiming_normal_
語意清晰度較直觀:一個 linear 負責 scale,一個負責 shift較壓縮:合併到單個層,較難讀但更精簡
參數量2 × (style_dim × C)1 × (style_dim × 2C) — 相同總參數量
靈活性可分別處理 scale_fcshift_fc少一點彈性,但較簡潔
效能差距幾乎無差異幾乎無差異

✅ 你該選哪個版本?

✔ 推薦版本(大多數情況)

class FiLMModulation(nn.Module):
    def __init__(self, in_channels, style_dim):
        super(FiLMModulation, self).__init__()
        self.film = nn.Linear(style_dim, in_channels * 2)
        nn.init.kaiming_normal_(self.film.weight, nonlinearity='linear')

    def forward(self, x, style):
        gamma_beta = self.film(style)  # (B, 2 * C)
        gamma, beta = gamma_beta.chunk(2, dim=1)  # (B, C), (B, C)
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)  # (B, C, 1, 1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        return gamma * x + beta

為什麼選這個版本?

  • 精簡,有助於大型架構中保持模組整潔
  • chunk 方法常見於 多頭參數 的設計(如 transformer, GAN)
  • 初始化控制 也更明確(尤其在大型模型中影響收斂)

✅ 如果你偏好可讀性(或要 debug 個別效果)

你可以選 scale_fc / shift_fc 的版本,它讓你更容易看出誰控制了什麼,對於 ablation 或實驗研究也許更直觀。


如果你有後續想要:

  • FiLM gate(例如 gating residuals)
  • 多層 FiLM 結構
  • 與 transformer block 或 attention 結合

我也可以幫你擴展。需要的話說一聲 👍

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *