
CycleGAN: Principles and Implementation

Reference: https://arxiv.org/abs/1703.10593v7

[figure: cycle_gan]

Background and Principles

CycleGAN is mainly used for domain-transfer tasks in the unsupervised setting, i.e., translating images from one distribution to another (for example, horses to zebras, or oil paintings to photos as style transfer). Training a CycleGAN yields two models, one mapping A->B and the other B->A.

CycleGAN's effectiveness comes mainly from its training scheme and loss design. The cycle-consistency idea appears in many model designs; its purpose is to exploit information from both directions and make the model more robust. Since CycleGAN has no paired data, a GAN loss alone can only guarantee that generated images follow the target distribution; it cannot guarantee that the content and structure stay consistent with the source-domain image. The cycle trick therefore maps the translated result back to the source domain and requires it to stay close to the original, which enforces content consistency and produces well-matched results.

The main losses and training procedure:

- GAN losses on the generated results in both domains, which ensure the translated images match the target distribution;
- cycle-consistency losses enforcing F(G(x)) ≈ x and G(F(y)) ≈ y, which keep the generated content consistent with the source;
- in practice, an identity loss is usually added as well: a B-domain image is fed through the G (A->B) generator and the output is required to match the input, and likewise for the A domain. The rationale is that a generator mapping A to B should leave an input that is already in domain B untouched, preserving both content and style.
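
Putting these together, the overall objective from the paper (writing G: A->B, F: B->A, with discriminators D_A and D_B) can be written as:

    \mathcal{L}(G, F, D_A, D_B) =
        \mathcal{L}_{\mathrm{GAN}}(G, D_B, A, B)
      + \mathcal{L}_{\mathrm{GAN}}(F, D_A, B, A)
      + \lambda \, \mathcal{L}_{\mathrm{cyc}}(G, F)
      + \lambda_{\mathrm{idt}} \, \mathcal{L}_{\mathrm{idt}}(G, F)

    \mathcal{L}_{\mathrm{cyc}}(G, F) =
        \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\lVert F(G(x)) - x \rVert_1\big]
      + \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\lVert G(F(y)) - y \rVert_1\big]

    \mathcal{L}_{\mathrm{idt}}(G, F) =
        \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\lVert G(y) - y \rVert_1\big]
      + \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\lVert F(x) - x \rVert_1\big]

In the paper, λ = 10, and when the identity term is used it is weighted at 0.5λ.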

Code Implementation

Using a simple implementation as an example, we walk through the concrete training procedure of CycleGAN.

Reference code: https://github.com/jzsherlock4869/cyclegan-pytorch

First, define the forward passes for the two directions. Each starts with an ordinary GAN generator/discriminator computation; then real_B is also fed through the AtoB generator for the later identity (idt) loss; finally, fake_B is passed back through the BtoA generator to compute the cycle loss. Note that a cycle loss is also computed here for the B->A->B direction, but starting from identity_B rather than real_B (the two should be optimized to become identical).

    
    def forward_AtoB(self):
        # freeze all networks except the generator being updated
        self.set_requires_grad(["netG_AB"], True)
        self.set_requires_grad(["netG_BA", "netD_B", "netD_A"], False)
        # generator forward and discriminator score
        self.fake_B = self.netG_AB(self.real_A)
        self.score_fake_B = self.netD_B(self.fake_B)
        # identity forward: netG_AB applied to an image already in domain B
        self.identity_B = self.netG_AB(self.real_B)
        # cycle forwards: A -> B -> A, and B -> A -> B (starting from identity_B)
        self.cycle_A = self.netG_BA(self.fake_B)
        self.cycle_B = self.netG_AB(self.netG_BA(self.identity_B))

    def forward_BtoA(self):
        # freeze all networks except the generator being updated
        self.set_requires_grad(["netG_BA"], True)
        self.set_requires_grad(["netG_AB", "netD_A", "netD_B"], False)
        # generator forward and discriminator score (netD_A judges A-domain fakes)
        self.fake_A = self.netG_BA(self.real_B)
        self.score_fake_A = self.netD_A(self.fake_A)
        # identity forward: netG_BA applied to an image already in domain A
        self.identity_A = self.netG_BA(self.real_A)
        # cycle forwards: B -> A -> B, and A -> B -> A (starting from identity_A)
        self.cycle_B = self.netG_AB(self.fake_A)
        self.cycle_A = self.netG_BA(self.netG_AB(self.identity_A))
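
Both forwards rely on a set_requires_grad helper to freeze every network except the one being trained, so that gradients only accumulate where intended. The repo's actual implementation may differ; a minimal sketch consistent with the usage above (names and signature inferred, not copied from the repo):

    def set_requires_grad(self, names, requires_grad):
        """Freeze or unfreeze the named sub-networks (hypothetical sketch).

        names: attribute names of the networks, e.g. ["netG_AB", "netD_B"]
        requires_grad: False freezes the parameters so no gradients are stored
        """
        for name in names:
            for param in getattr(self, name).parameters():
                param.requires_grad = requires_grad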

Next comes the parameter update, which proceeds in three stages: update the AtoB generator, differentiating only through the forward-direction network; update the BtoA generator, differentiating only through the reverse-direction network; and finally update the discriminators.

    def optimize_parameters(self, step):
        # OrderedDict comes from the standard library: from collections import OrderedDict
        loss_dict = OrderedDict()
        # ===================================================== #
        # forward netG_AB and calc loss, while other nets frozen
        loss_G_AB = 0
        self.forward_AtoB()
        # adv. loss for netG_AB in B domain
        if self.losses.get("adv_B"):
            self.set_requires_grad(["netD_B"], False)
            gab_adv_loss = self.calculate_gan_loss_G(
                self.netD_B, self.losses["adv_B"], self.fake_B
                )
            loss_dict["adv_B"] = gab_adv_loss.item()
            loss_G_AB += self.loss_weights["adv_B"] * gab_adv_loss
        
        # identity loss for netG_AB(B) and B
        if self.losses.get("identity_B"):
            gab_idt_loss = self.losses["identity_B"](self.identity_B, self.real_B)
            loss_dict["identity_B"] = gab_idt_loss.item()
            loss_G_AB += self.loss_weights["identity_B"] * gab_idt_loss

        # cycle losses: netG_BA(netG_AB(A)) vs A, and netG_AB(netG_BA(B)) vs B
        if self.losses.get("cycle_AB"):
            gab_cycle_loss = self.losses["cycle_AB"](self.cycle_B, self.real_B) + \
                                self.losses["cycle_AB"](self.cycle_A, self.real_A)
            loss_dict["cycle_AB"] = gab_cycle_loss.item()
            loss_G_AB += self.loss_weights["cycle_AB"] * gab_cycle_loss

        self.set_optimizer(names=["netG_AB"], operation="zero_grad")
        loss_G_AB.backward()
        self.clip_grad_norm(names=["netG_AB"], norm=self.max_grad_norm)
        self.set_optimizer(names=["netG_AB"], operation="step")

        # ===================================================== #
        # forward netG_BA and calc loss, while other nets frozen
        loss_G_BA = 0
        self.forward_BtoA()
        # adv. loss for netG_BA in A domain
        if self.losses.get("adv_A"):
            self.set_requires_grad(["netD_A"], False)
            gba_adv_loss = self.calculate_gan_loss_G(
                self.netD_A, self.losses["adv_A"], self.fake_A
                )
            loss_dict["adv_A"] = gba_adv_loss.item()
            loss_G_BA += self.loss_weights["adv_A"] * gba_adv_loss

        # identity loss for netG_BA(A) and A
        if self.losses.get("identity_A"):
            gba_idt_loss = self.losses["identity_A"](self.identity_A, self.real_A)
            loss_dict["identity_A"] = gba_idt_loss.item()
            loss_G_BA += self.loss_weights["identity_A"] * gba_idt_loss

        # cycle losses: netG_BA(netG_AB(A)) vs A, and netG_AB(netG_BA(B)) vs B
        if self.losses.get("cycle_BA"):
            gba_cycle_loss = self.losses["cycle_BA"](self.cycle_B, self.real_B) + \
                                self.losses["cycle_BA"](self.cycle_A, self.real_A)
            loss_dict["cycle_BA"] = gba_cycle_loss.item()
            loss_G_BA += self.loss_weights["cycle_BA"] * gba_cycle_loss

        self.set_optimizer(names=["netG_BA"], operation="zero_grad")
        loss_G_BA.backward()
        self.clip_grad_norm(names=["netG_BA"], norm=self.max_grad_norm)
        self.set_optimizer(names=["netG_BA"], operation="step")

        # update netD_A, netD_B
        if self.losses.get("adv_B"):
            if step % self.D_ratio == 0:
                self.set_requires_grad(["netD_B"], True)
                loss_D_B = self.calculate_gan_loss_D(
                    self.netD_B, self.losses["adv_B"], self.real_B,
                    self.fake_B_buffer.choose(self.fake_B.detach())
                )
                loss_dict["d_adv_B"] = loss_D_B.item()
                loss_D_B = self.loss_weights["adv_B"] * loss_D_B

                self.set_optimizer(names=["netD_B"], operation="zero_grad")
                loss_D_B.backward()
                self.clip_grad_norm(["netD_B"], norm=self.max_grad_norm)
                self.set_optimizer(names=["netD_B"], operation="step")

        if self.losses.get("adv_A"):
            if step % self.D_ratio == 0:
                self.set_requires_grad(["netD_A"], True)
                loss_D_A = self.calculate_gan_loss_D(
                    self.netD_A, self.losses["adv_A"], self.real_A,
                    self.fake_A_buffer.choose(self.fake_A.detach())
                )
                loss_dict["d_adv_A"] = loss_D_A.item()
                loss_D_A = self.loss_weights["adv_A"] * loss_D_A

                self.set_optimizer(names=["netD_A"], operation="zero_grad")
                loss_D_A.backward()
                self.clip_grad_norm(["netD_A"], norm=self.max_grad_norm)
                self.set_optimizer(names=["netD_A"], operation="step")

        self.log_dict = loss_dict
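
Two more pieces are referenced but not shown: the GAN-loss helpers (calculate_gan_loss_G / calculate_gan_loss_D) and the fake-image buffers (fake_A_buffer / fake_B_buffer, whose choose method feeds the discriminator a mix of fresh and historical fakes, following the 50-image history buffer the paper adopts from Shrivastava et al. to reduce oscillation). The sketch below is an assumption modeled on the usage above, not the repo's exact code: it assumes a GANLoss-style criterion taking a discriminator output and a real/fake flag, and writes the helpers as free functions for brevity (in the repo they are model methods):

    import random
    import torch

    class ImageBuffer:
        """History pool of generated images (sketch of the paper's 50-image buffer)."""

        def __init__(self, buffer_size=50):
            self.buffer_size = buffer_size
            self.images = []

        def choose(self, batch):
            # for each new fake: store it while the pool is still filling; afterwards,
            # with probability 0.5 swap it for a randomly chosen older fake
            out = []
            for img in batch:
                img = img.unsqueeze(0)
                if len(self.images) < self.buffer_size:
                    self.images.append(img)
                    out.append(img)
                elif random.random() < 0.5:
                    idx = random.randrange(self.buffer_size)
                    out.append(self.images[idx].clone())
                    self.images[idx] = img
                else:
                    out.append(img)
            return torch.cat(out, dim=0)

    def calculate_gan_loss_G(netD, criterion, fake):
        # the generator is rewarded when the discriminator labels its fakes "real"
        return criterion(netD(fake), target_is_real=True)

    def calculate_gan_loss_D(netD, criterion, real, fake):
        # the discriminator should label real images "real" and buffered fakes "fake";
        # the caller detaches `fake`, so no gradients reach the generator here
        loss_real = criterion(netD(real), target_is_real=True)
        loss_fake = criterion(netD(fake), target_is_real=False)
        return 0.5 * (loss_real + loss_fake)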