Is use_res_skip true or false for the model below?

After a code refactor, I forgot what this newly added option was set to for a model that was already in training, so I simply asked the AI:

Is use_res_skip true or false for the model below?

UNetGenerator(
  (embedder): Embedding(40, 128)
  (model): UnetSkipConnectionBlock(
    (submodule): UnetSkipConnectionBlock(
      (submodule): UnetSkipConnectionBlock(
        (submodule): UnetSkipConnectionBlock(
          (submodule): UnetSkipConnectionBlock(
            (att): SelfAttention(
              (query_conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
              (key_conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
              (value_conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
            )
            (submodule): UnetSkipConnectionBlock(
              (submodule): UnetSkipConnectionBlock(
                (att): SelfAttention(
                  (query_conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
                  (key_conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
                  (value_conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
                )
                (down): Sequential(
                  (0): LeakyReLU(negative_slope=0.2, inplace=True)
                  (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
                )
                (up): Sequential(
                  (0): ReLU(inplace=True)
                  (1): Conv2d(640, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                  (2): PixelShuffle(upscale_factor=2)
                  (3): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
                )
              )
              (down): Sequential(
                (0): LeakyReLU(negative_slope=0.2, inplace=True)
                (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
                (2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
              )
              (up): Sequential(
                (0): ReLU(inplace=True)
                (1): Conv2d(512, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (2): PixelShuffle(upscale_factor=2)
                (3): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
                (4): Dropout(p=0.5, inplace=False)
              )
            )
            (down): Sequential(
              (0): LeakyReLU(negative_slope=0.2, inplace=True)
              (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
              (2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            )
            (up): Sequential(
              (0): ReLU(inplace=True)
              (1): Conv2d(512, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (2): PixelShuffle(upscale_factor=2)
              (3): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
              (4): Dropout(p=0.5, inplace=False)
            )
          )
          (down): Sequential(
            (0): LeakyReLU(negative_slope=0.2, inplace=True)
            (1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
            (2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          )
          (up): Sequential(
            (0): ReLU(inplace=True)
            (1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): PixelShuffle(upscale_factor=2)
            (3): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          )
        )
        (down): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace=True)
          (1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        )
        (up): Sequential(
          (0): ReLU(inplace=True)
          (1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): PixelShuffle(upscale_factor=2)
          (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        )
      )
      (down): Sequential(
        (0): LeakyReLU(negative_slope=0.2, inplace=True)
        (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (up): Sequential(
        (0): ReLU(inplace=True)
        (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): PixelShuffle(upscale_factor=2)
        (3): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
    )
    (down): Sequential(
      (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (up): Sequential(
      (0): ReLU(inplace=True)
      (1): Conv2d(64, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): PixelShuffle(upscale_factor=2)
      (3): Tanh()
    )
  )
)
[Network G] Total number of parameters : 52.875 M
Discriminator(
  (model): Sequential(
    (0): Conv2d(2, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
  (binary): Linear(in_features=2304, out_features=1, bias=True)
  (catagory): Linear(in_features=2304, out_features=40, bias=True)
)
[Network D] Total number of parameters : 4.412 M

The AI replied:

Based on the model structure output you provided, use_res_skip is True.

The reasoning:

In UNetGenerator's UnetSkipConnectionBlock, use_res_skip determines how the submodule's output and the encoder feature (the skip connection) are combined:

  • True (addition): they are added directly, x + dec; the channel count stays equal to inner_nc. The up conv layer one level out (the block that wraps this one) therefore takes inner_nc input channels.
  • False (concatenation): they are concatenated, cat([x, dec]); the channel count becomes inner_nc * 2. The enclosing block's up conv layer therefore takes inner_nc * 2 input channels.
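
To make that channel arithmetic concrete, here is a minimal standalone sketch of the two combination modes. It only illustrates the behaviour described above and is not the project's actual UnetSkipConnectionBlock code:

import torch

# Toy illustration of the two skip-combination modes (an assumption based on the
# explanation above, not the real UnetSkipConnectionBlock implementation).
inner_nc = 512
x   = torch.randn(1, inner_nc, 16, 16)   # encoder feature at this level
dec = torch.randn(1, inner_nc, 16, 16)   # output returned by the submodule

# use_res_skip = True: addition, channel count stays at inner_nc
print((x + dec).shape)                   # torch.Size([1, 512, 16, 16])

# use_res_skip = False: concatenation, channel count doubles to inner_nc * 2
print(torch.cat([x, dec], dim=1).shape)  # torch.Size([1, 1024, 16, 16])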

Looking at a fragment of your model output (taking the second-to-last block as an example):

(submodule): UnetSkipConnectionBlock(
  ...
  (down): Sequential(
    ...
    (1): Conv2d(256, 512, ...)   <-- this level's encoder (outer_nc=256, inner_nc=512)
  )
  (up): Sequential(
    (0): ReLU(inplace=True)
    (1): Conv2d(512, 1024, ...)  <-- the key point: the input channel count here is 512
    ...
  )
)
  1. This block's inner_nc is 512 (seen from the down layer's output channels).
  2. This block's up layer takes 512 input channels.
  3. If use_res_skip were False (concatenation), the input here would have to be 512 * 2 = 1024.
  4. Since the input is 512 (equal to inner_nc), the feature returned by the inner block and the skip connection are combined by addition, and the channel count is not doubled.

So use_res_skip = True.
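
The channel math in that up block can also be checked directly with a few lines of PyTorch, re-creating just the quoted Sequential rather than the whole generator:

import torch
import torch.nn as nn

# Re-create the quoted up block: its Conv2d takes 512 input channels (== inner_nc,
# so the skip must be added rather than concatenated), and PixelShuffle(2) turns
# the 1024 conv output channels into 1024 / 2**2 = 256 channels (== outer_nc).
up = nn.Sequential(
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
    nn.PixelShuffle(upscale_factor=2),
    nn.InstanceNorm2d(256),
)
feat = torch.randn(1, 512, 16, 16)   # a 512-channel feature, as returned under addition
print(up(feat).shape)                # torch.Size([1, 256, 32, 32])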

Fantastic: ask the question, provide the data, and the answer comes back with zero thinking required.
