CycleGAN: Principles and Implementation
ref:https://arxiv.org/abs/1703.10593v7
CycleGAN is mainly used for unsupervised domain-transfer tasks, i.e., translating images from one distribution to another (for example, horse to zebra, or oil painting to photo). After training, CycleGAN yields two models, one performing the A->B mapping and the other B->A.
CycleGAN's effectiveness comes mainly from its training scheme and the design of its loss functions. The idea of a cycle constraint appears in many model designs; its purpose is to exploit information from both the forward and the backward direction and thereby make the model more robust. Because CycleGAN has no paired data, a GAN loss alone can only ensure that the generated images follow the target distribution; it cannot ensure that their content and structure stay consistent with the source-domain input. The cycle idea therefore maps the translated result back to the source domain and requires the content to remain essentially unchanged, which yields outputs whose content matches the input.
The main losses and the training procedure are as follows. A GAN loss on the generated results in each domain ensures that the translated images match the target distribution. A cycle-consistency loss between F(G(x)) and x, and between G(F(y)) and y, ensures that the generated content stays consistent. In practice, an identity loss is usually added as well: an image from domain B is fed through the A->B generator, and the output is required to match the B-domain input (and symmetrically for domain A). The rationale is that the A-to-B generator should leave an input that is already in domain B unchanged, preserving both its content and style.
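To make the loss composition concrete, here is a minimal sketch of the A->B generator objective, assuming G_AB and G_BA are the two generators, D_B is the B-domain discriminator, and L1 is used for the cycle and identity terms; the names and weights are illustrative and are not taken from the reference code discussed below.

import torch
import torch.nn as nn

# Illustrative sketch: A->B generator objective combining adversarial,
# cycle-consistency and identity terms. G_AB, G_BA, D_B are assumed to be
# nn.Module instances; real_A, real_B are image batches from each domain.
l1 = nn.L1Loss()
adv_criterion = nn.BCEWithLogitsLoss()   # simple adversarial criterion for the sketch
lambda_cyc, lambda_idt = 10.0, 5.0       # typical relative weights

def generator_loss_AtoB(G_AB, G_BA, D_B, real_A, real_B):
    fake_B = G_AB(real_A)                               # translate A -> B
    score = D_B(fake_B)                                 # discriminator score of the fake
    adv = adv_criterion(score, torch.ones_like(score))  # fool D_B: label fakes as real
    cyc = l1(G_BA(fake_B), real_A)                      # A -> B -> A should recover A
    idt = l1(G_AB(real_B), real_B)                      # G_AB should not change B images
    return adv + lambda_cyc * cyc + lambda_idt * idt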
Let us walk through the concrete CycleGAN training procedure using a simple implementation as an example.
Reference code: https://github.com/jzsherlock4869/cyclegan-pytorch
First, define the forward passes for the two directions. Taking A->B as an example: the pass begins with an ordinary GAN computation through the generator and the discriminator; then real_B is fed through the A->B generator, whose output is later used for the identity loss; finally, fake_B is passed back through the B->A generator to compute the cycle loss. This implementation also computes a cycle loss in the B->A-then-A->B direction, but starting from identity_B instead of real_B (the two should be optimized towards the same image).
def forward_AtoB(self):
    # toggle requires_grad: only netG_AB is updated in this pass
    self.set_requires_grad(["netG_AB"], True)
    self.set_requires_grad(["netG_BA", "netD_B", "netD_A"], False)
    # discriminator forward
    self.fake_B = self.netG_AB(self.real_A)
    self.score_fake_B = self.netD_B(self.fake_B)
    # identity forward
    self.identity_B = self.netG_AB(self.real_B)
    # cycle forwards and backwards
    self.cycle_A = self.netG_BA(self.fake_B)
    self.cycle_B = self.netG_AB(self.netG_BA(self.identity_B))
def forward_BtoA(self):
    # toggle requires_grad: only netG_BA is updated in this pass
    self.set_requires_grad(["netG_BA"], True)
    self.set_requires_grad(["netG_AB", "netD_A", "netD_B"], False)
    # discriminator forward
    self.fake_A = self.netG_BA(self.real_B)
    self.score_fake_A = self.netD_A(self.fake_A)
    # identity forward
    self.identity_A = self.netG_BA(self.real_A)
    # cycle forwards and backwards
    self.cycle_B = self.netG_AB(self.fake_A)
    self.cycle_A = self.netG_BA(self.netG_AB(self.identity_A))
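The set_requires_grad helper used above is not shown in the snippet; a plausible minimal version, assumed here rather than copied from the repository, simply toggles requires_grad on every parameter of the named sub-networks:

def set_requires_grad(self, net_names, requires_grad):
    # Freeze or unfreeze all parameters of the listed sub-networks
    # (e.g. "netG_AB", "netD_B"), which are assumed to be attributes of self.
    for name in net_names:
        net = getattr(self, name)
        for p in net.parameters():
            p.requires_grad = requires_grad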
Next comes the parameter update, which is split into three stages: updating the A->B generator, where gradients flow only through the forward network; updating the B->A generator, where gradients flow only through the reverse network; and finally updating the discriminators D_A and D_B.
def optimize_parameters(self, step):
    loss_dict = OrderedDict()
    # ===================================================== #
    # forward netG_AB and calc loss, while other nets frozen
    loss_G_AB = 0
    self.forward_AtoB()
    # adv. loss for netG_AB in B domain
    if self.losses.get("adv_B"):
        self.set_requires_grad(["netD_B"], False)
        gab_adv_loss = self.calculate_gan_loss_G(
            self.netD_B, self.losses["adv_B"], self.fake_B
        )
        loss_dict["adv_B"] = gab_adv_loss.item()
        loss_G_AB += self.loss_weights["adv_B"] * gab_adv_loss
    # identity loss for netG_AB(B) and B
    if self.losses.get("identity_B"):
        gab_idt_loss = self.losses["identity_B"](self.identity_B, self.real_B)
        loss_dict["identity_B"] = gab_idt_loss.item()
        loss_G_AB += self.loss_weights["identity_B"] * gab_idt_loss
    # cycle loss for netG_BA(netG_AB(A)) vs A, and netG_AB(netG_BA(B)) vs B
    if self.losses.get("cycle_AB"):
        gab_cycle_loss = self.losses["cycle_AB"](self.cycle_B, self.real_B) + \
            self.losses["cycle_AB"](self.cycle_A, self.real_A)
        loss_dict["cycle_AB"] = gab_cycle_loss.item()
        loss_G_AB += self.loss_weights["cycle_AB"] * gab_cycle_loss
    self.set_optimizer(names=["netG_AB"], operation="zero_grad")
    loss_G_AB.backward()
    self.clip_grad_norm(names=["netG_AB"], norm=self.max_grad_norm)
    self.set_optimizer(names=["netG_AB"], operation="step")
    # ===================================================== #
    # forward netG_BA and calc loss, while other nets frozen
    loss_G_BA = 0
    self.forward_BtoA()
    # adv. loss for netG_BA in A domain
    if self.losses.get("adv_A"):
        self.set_requires_grad(["netD_A"], False)
        gba_adv_loss = self.calculate_gan_loss_G(
            self.netD_A, self.losses["adv_A"], self.fake_A
        )
        loss_dict["adv_A"] = gba_adv_loss.item()
        loss_G_BA += self.loss_weights["adv_A"] * gba_adv_loss
    # identity loss for netG_BA(A) and A
    if self.losses.get("identity_A"):
        gba_idt_loss = self.losses["identity_A"](self.identity_A, self.real_A)
        loss_dict["identity_A"] = gba_idt_loss.item()
        loss_G_BA += self.loss_weights["identity_A"] * gba_idt_loss
    # cycle loss for netG_AB(netG_BA(B)) vs B, and netG_BA(netG_AB(A)) vs A
    if self.losses.get("cycle_BA"):
        gba_cycle_loss = self.losses["cycle_BA"](self.cycle_B, self.real_B) + \
            self.losses["cycle_BA"](self.cycle_A, self.real_A)
        loss_dict["cycle_BA"] = gba_cycle_loss.item()
        loss_G_BA += self.loss_weights["cycle_BA"] * gba_cycle_loss
    self.set_optimizer(names=["netG_BA"], operation="zero_grad")
    loss_G_BA.backward()
    self.clip_grad_norm(names=["netG_BA"], norm=self.max_grad_norm)
    self.set_optimizer(names=["netG_BA"], operation="step")
    ## update netD_A, netD_B
    if self.losses.get("adv_B"):
        if step % self.D_ratio == 0:
            self.set_requires_grad(["netD_B"], True)
            loss_D_B = self.calculate_gan_loss_D(
                self.netD_B, self.losses["adv_B"], self.real_B,
                self.fake_B_buffer.choose(self.fake_B.detach())
            )
            loss_dict["d_adv_B"] = loss_D_B.item()
            loss_D_B = self.loss_weights["adv_B"] * loss_D_B
            self.set_optimizer(names=["netD_B"], operation="zero_grad")
            loss_D_B.backward()
            self.clip_grad_norm(["netD_B"], norm=self.max_grad_norm)
            self.set_optimizer(names=["netD_B"], operation="step")
    if self.losses.get("adv_A"):
        if step % self.D_ratio == 0:
            self.set_requires_grad(["netD_A"], True)
            loss_D_A = self.calculate_gan_loss_D(
                self.netD_A, self.losses["adv_A"], self.real_A,
                self.fake_A_buffer.choose(self.fake_A.detach())
            )
            loss_dict["d_adv_A"] = loss_D_A.item()
            loss_D_A = self.loss_weights["adv_A"] * loss_D_A
            self.set_optimizer(names=["netD_A"], operation="zero_grad")
            loss_D_A.backward()
            self.clip_grad_norm(["netD_A"], norm=self.max_grad_norm)
            self.set_optimizer(names=["netD_A"], operation="step")
    self.log_dict = loss_dict
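The fake_A_buffer / fake_B_buffer objects used in the discriminator updates implement the image-pool trick from the original CycleGAN paper: the discriminators are trained on a mix of the newest fakes and fakes produced at earlier steps, which stabilizes adversarial training. Below is a minimal sketch of such a buffer with a choose interface matching the calls above; the internals are an assumption, not the repository's exact code.

import random
import torch

class ImageBuffer:
    # History pool of previously generated fake images (CycleGAN image-pool trick).
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def choose(self, images):
        # For each incoming fake, either return it directly or swap it with
        # a randomly selected older fake stored in the pool.
        out = []
        for img in images:
            img = img.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(img)
                out.append(img)
            elif random.random() < 0.5:
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())
                self.images[idx] = img
            else:
                out.append(img)
        return torch.cat(out, dim=0)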