Zi2ZiModel font style transfer: if the model only needs to learn Noto Sans to Zen Maru Gothic, is an embedding still needed?

No. If your goal is only to convert the Noto Sans font into Zen Maru Gothic, you do not need an embedding. The main purpose of the embedding is to handle conversion across multiple styles or categories; with only two fixed font styles, it can be removed and the model simplified.

Here are the suggested modifications for a single-style conversion from Noto Sans to Zen Maru Gothic:

1. Remove the embedding-related code:

  • Remove the embedding_num and embedding_dim parameters from the UNetGenerator and UnetSkipConnectionBlock classes.
  • Remove self.embedder from UNetGenerator.
  • Remove the embedding-related logic from the forward methods of UNetGenerator and UnetSkipConnectionBlock (a sketch of what this logic typically looks like follows this list).
  • Remove the embedding_num and embedding_dim parameters from the Zi2ZiModel class.
  • Remove the embedding-related code from the setup method of the Zi2ZiModel class.
  • Remove the labels input from the set_input method of the Zi2ZiModel class.
  • Remove the labels input from the forward method of the Zi2ZiModel class.
  • Remove the labels input from the backward_D method of the Zi2ZiModel class.
  • Remove the labels input from the backward_G method of the Zi2ZiModel class.
  • Remove the category_loss handling from the optimize_parameters method of the Zi2ZiModel class.
  • Remove the CategoryLoss class.
  • Remove all remaining code related to category_loss.
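
For reference, the embedding logic being removed typically looks like the following. This is a minimal sketch based on common zi2zi-style implementations; the names embedding_num, embedding_dim, and self.embedder come from the list above, but the StyleEmbedder wrapper itself is hypothetical.

Python

import torch
import torch.nn as nn

class StyleEmbedder(nn.Module):
    def __init__(self, embedding_num=40, embedding_dim=128):
        super(StyleEmbedder, self).__init__()
        # One learned style vector per font category.
        self.embedder = nn.Embedding(embedding_num, embedding_dim)

    def forward(self, bottleneck, labels):
        # Look up each sample's style vector, tile it over the bottleneck's
        # spatial dimensions, then concatenate along the channel axis.
        style = self.embedder(labels)                # (B, embedding_dim)
        style = style.view(style.shape[0], -1, 1, 1)
        style = style.expand(-1, -1, bottleneck.shape[2], bottleneck.shape[3])
        return torch.cat([bottleneck, style], dim=1)

With a single target style, this lookup always returns the same vector, so it adds parameters without adding information, which is why it can be dropped.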

2. Simplify the network structure:

  • Since the model no longer needs to handle multiple styles, you can simplify the generator and discriminator architectures.
  • For example, you can reduce the number of skip-connection blocks in the generator or the number of layers in the discriminator, as in the sketch after this list.
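
As a hedged sketch of the first option, here is a shallower variant of the UNetGenerator defined in the full code below (the UNetGeneratorShallow name is illustrative); it stacks one fewer ngf * 8 block, so a 256x256 input bottlenecks at 2x2 instead of 1x1:

Python

import torch.nn as nn

class UNetGeneratorShallow(nn.Module):
    # Assumes the UnetSkipConnectionBlock class from the full code below.
    def __init__(self, input_nc=1, output_nc=1, ngf=64, use_dropout=False):
        super(UNetGeneratorShallow, self).__init__()
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, innermost=True)
        for _ in range(2):  # one fewer intermediate block than the version below
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block, use_dropout=use_dropout)
        block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=block)
        block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=block)
        block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=block)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True)

    def forward(self, x):
        return self.model(x)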

3. Modify the loss functions:

  • Since the category loss is no longer needed, you can use only binary cross-entropy (BCE) loss and L1 loss, as sketched below.
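
A minimal sketch of the resulting generator objective: an adversarial BCE term plus a pixel-wise L1 term. The generator_loss helper is hypothetical; the L1_penalty weight of 100 matches the default used in the model code below.

Python

import torch
import torch.nn as nn

criterion_gan = nn.BCEWithLogitsLoss()
criterion_l1 = nn.L1Loss()

def generator_loss(pred_fake, fake_B, real_B, L1_penalty=100):
    # Adversarial term: push the discriminator's logits toward "real".
    loss_gan = criterion_gan(pred_fake, torch.ones_like(pred_fake))
    # Reconstruction term: keep the generated glyph close to the target.
    loss_l1 = criterion_l1(fake_B, real_B) * L1_penalty
    return loss_gan + loss_l1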

4. Modify the training loop:

  • Remove the category-related code from the training loop.
  • Simplify the loss computation accordingly; the training script at the end of this post shows the result.

Simplified code example:

Python

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os
import time

# ... (UNetGenerator, UnetSkipConnectionBlock, and Discriminator classes as defined in the full code below) ...

class Zi2ZiModel:
    def __init__(self, input_nc=1, output_nc=1, ngf=64, ndf=64,
                 L1_penalty=100, schedule=10, lr=0.001, gpu_ids=None,
                 save_dir='.', is_training=True, image_size=256, use_dropout=False):

        # ... (init function) ...

    def setup(self):
        self.netG = UNetGenerator(
            input_nc=self.input_nc,
            output_nc=self.output_nc,
            ngf=self.ngf,
            use_dropout=self.use_dropout
        )
        self.netD = Discriminator(
            input_nc=2 * self.input_nc,
            ndf=self.ndf,
            image_size=self.image_size
        )

        # ... (setup function) ...

    def set_input(self, real_A, real_B):
        # ... (set_input function) ...

    def forward(self):
        # UNetGenerator returns a single tensor; the encoded features used
        # by the old constant-consistency loss are gone along with the embedding.
        self.fake_B = self.netG(self.real_A)

    def backward_D(self):
        # ... (backward_D function) ...

    def backward_G(self):
        # ... (backward_G function) ...

    def optimize_parameters(self):
        # ... (optimize_parameters function) ...

    # ... (other functions) ...

# ... (training script) ...

With these modifications, the model focuses on the single-style conversion from Noto Sans to Zen Maru Gothic, with a simpler network structure and training process. The complete simplified code follows.


Python

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

class UNetGenerator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, ngf=64, use_dropout=False):
        super(UNetGenerator, self).__init__()
        # Build the U-Net from the innermost block outward; eight
        # downsampling steps bottleneck a 256x256 input to 1x1.
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, innermost=True)
        for _ in range(3):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True)

    def forward(self, x):
        return self.model(x)

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, innermost=False, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        uprelu = nn.ReLU(True)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downrelu, downconv]
            up = [uprelu, upconv]
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downrelu, downconv]
            up = [uprelu, upconv]
            if use_dropout:
                up.append(nn.Dropout(0.5))

        self.submodule = submodule
        self.down = nn.Sequential(*down)
        self.up = nn.Sequential(*up)

    def forward(self, x):
        # The innermost block has no submodule, so guard against None.
        out = self.down(x)
        if self.submodule is not None:
            out = self.submodule(out)
        out = self.up(out)
        if self.outermost:
            return out
        # U-Net skip connection: concatenate input and output channels.
        return torch.cat([x, out], 1)

class Discriminator(nn.Module):
    def __init__(self, input_nc=1, ndf=64, image_size=256):
        super(Discriminator, self).__init__()
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult = 1
        for _ in range(2):
            nf_mult_prev = nf_mult
            nf_mult = min(2 * nf_mult, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=5, stride=2, padding=2),
                nn.LeakyReLU(0.2, True)
            ]
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=5, stride=1, padding=2)]
        self.model = nn.Sequential(*sequence)
        # Three stride-2 convolutions reduce each spatial dim by 8, and the
        # final conv has one output channel, giving (image_size // 8)^2 features.
        self.binary = nn.Linear((image_size // 8) ** 2, 1)

    def forward(self, input):
        features = self.model(input)
        features = features.view(input.shape[0], -1)
        return self.binary(features)

class Zi2ZiModel:
    def __init__(self, input_nc=1, ngf=64, ndf=64, lr=0.001, gpu_ids=None, image_size=256, L1_penalty=100):
        self.gpu_ids = gpu_ids
        self.input_nc = input_nc
        self.ngf = ngf
        self.ndf = ndf
        self.lr = lr
        self.image_size = image_size
        self.L1_penalty = L1_penalty
        self.setup()

    def setup(self):
        self.netG = UNetGenerator(input_nc=self.input_nc, output_nc=self.input_nc, ngf=self.ngf)
        # The discriminator sees (source, target) pairs, hence 2 * input_nc channels.
        self.netD = Discriminator(input_nc=2 * self.input_nc, ndf=self.ndf, image_size=self.image_size)
        self.optimizer_G = optim.Adam(self.netG.parameters(), lr=self.lr, betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(self.netD.parameters(), lr=self.lr, betas=(0.5, 0.999))
        self.criterion = nn.BCEWithLogitsLoss()
        self.l1_loss = nn.L1Loss()

    def set_input(self, real_A, real_B):
        self.real_A, self.real_B = real_A, real_B

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def backward_D(self):
        real_AB = torch.cat([self.real_A, self.real_B], 1)
        fake_AB = torch.cat([self.real_A, self.fake_B.detach()], 1)
        # Run the discriminator once per pair instead of twice.
        pred_real = self.netD(real_AB)
        pred_fake = self.netD(fake_AB)
        loss_real = self.criterion(pred_real, torch.ones_like(pred_real))
        loss_fake = self.criterion(pred_fake, torch.zeros_like(pred_fake))
        loss_D = (loss_real + loss_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_G(self):
        fake_AB = torch.cat([self.real_A, self.fake_B], 1)
        pred_fake = self.netD(fake_AB)
        loss_G_GAN = self.criterion(pred_fake, torch.ones_like(pred_fake))
        # L1 reconstruction loss, as described in step 3 above.
        loss_G_L1 = self.l1_loss(self.fake_B, self.real_B) * self.L1_penalty
        loss_G = loss_G_GAN + loss_G_L1
        loss_G.backward()
        return loss_G

    def optimize_parameters(self):
        self.forward()
        self.optimizer_D.zero_grad()
        loss_D = self.backward_D()
        self.optimizer_D.step()
        self.optimizer_G.zero_grad()
        loss_G = self.backward_G()
        self.optimizer_G.step()
        return loss_D, loss_G

# Training script
model = Zi2ZiModel()
train_dataset = DataLoader([], batch_size=1)  # replace with an actual DataLoader
for epoch in range(10):  # number of training epochs
    for batch in train_dataset:
        real_A, real_B = batch  # assumes the DataLoader yields (real_A, real_B) pairs
        model.set_input(real_A, real_B)
        loss_D, loss_G = model.optimize_parameters()
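
The empty DataLoader above is only a placeholder. As a hypothetical sketch of what could feed it, the FontPairDataset below assumes pre-rendered, aligned grayscale glyph images with matching filenames under two directories, noto_dir (source font) and zenmaru_dir (target font); the class and directory names are illustrative, not part of the original code.

Python

import os
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class FontPairDataset(Dataset):
    def __init__(self, noto_dir, zenmaru_dir):
        self.noto_dir = noto_dir
        self.zenmaru_dir = zenmaru_dir
        self.files = sorted(os.listdir(noto_dir))
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # match the generator's Tanh range
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        name = self.files[idx]
        real_A = self.transform(Image.open(os.path.join(self.noto_dir, name)))
        real_B = self.transform(Image.open(os.path.join(self.zenmaru_dir, name)))
        return real_A, real_B

# Usage: train_dataset = DataLoader(FontPairDataset('noto', 'zenmaru'),
#                                   batch_size=1, shuffle=True)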
