Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation
Authors: Junho Kim, Minjae Kim, Hyeonwoo Kang, Kwanghee Lee.
ICLR 2020
This folder provides a re-implementation of this paper in PyTorch, developed as part of the course METU CENG 796 - Deep Generative Models. The re-implementation is provided by:
Onur Can Üner
Sinan Gençoğlu
Please see the Jupyter notebook main.ipynb for a summary of the paper, implementation notes, and our experimental results.
You can download the dataset with the following bash command or from the link below. This creates the directory data/selfie2anime/.
sh scripts/download_data.sh
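To check that the download finished, you can use a small script like the one below to count the images in each split. It is a minimal sketch. It assumes the data unpacks into the four subfolders trainA, trainB, testA and testB used by the original selfie2anime release (A = selfies, B = anime faces). This README does not confirm that layout, so adjust the names if your copy differs. The helper count_images is hypothetical.

```python
from pathlib import Path

# Assumed split layout under data/selfie2anime/ (not confirmed by this README).
SPLITS = ("trainA", "trainB", "testA", "testB")
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images(root="data/selfie2anime"):
    """Return {split: number of image files}; a missing split maps to 0."""
    counts = {}
    for split in SPLITS:
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            counts[split] = 0
            continue
        # Count only regular files with an image extension (case-insensitive).
        counts[split] = sum(
            1 for p in split_dir.iterdir()
            if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


if __name__ == "__main__":
    for split, n in count_images().items():
        print(f"{split}: {n} images")
```

A split that reports 0 images usually means the download or extraction did not complete.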
After downloading the dataset, you can train your own model with the following command:
python main.py --config ./configs/selfie2anime.json
Many settings, such as the model architecture and the dataset paths, can be changed in the JSON config file.
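For quick experiments, you can derive a variant config from the shipped one instead of editing it by hand. The sketch below applies dotted-path overrides to the nested config dict and writes the result to a new JSON file. The helpers apply_overrides and write_variant are hypothetical. Of the config keys, only the arch and datamanager sections (type, args, root, dataset_dir) are known from the notebook code in this folder.

```python
import copy
import json


def apply_overrides(config, overrides):
    """Return a deep copy of `config` with dotted-path keys replaced.

    Example: {"datamanager.root": "/scratch/data"} sets
    config["datamanager"]["root"]. Every key must already exist, so a typo
    raises KeyError instead of silently adding a new field.
    """
    new_config = copy.deepcopy(config)
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = new_config
        for key in parents:
            node = node[key]  # KeyError if an intermediate section is missing
        if leaf not in node:
            raise KeyError(f"unknown config key: {dotted_key}")
        node[leaf] = value
    return new_config


def write_variant(src_path, dst_path, overrides):
    """Load a JSON config, apply the overrides, and save it under a new name."""
    with open(src_path) as f:
        config = json.load(f)
    with open(dst_path, "w") as f:
        json.dump(apply_overrides(config, overrides), f, indent=4)
```

For example, write_variant('configs/selfie2anime.json', 'configs/selfie2anime_local.json', {'datamanager.root': '/path/to/data'}) creates a copy that points at a different data root, and that file can then be passed to --config. Rejecting unknown keys is a deliberate choice: a misspelled option fails loudly instead of being ignored during training.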
To test the model with pre-trained weights, first download them:
sh scripts/download_checkpoint.sh
Import the necessary libraries:
import json
import numpy as np
import matplotlib.pyplot as plt
from easydict import EasyDict as edict
import torch
import torchvision
from torchvision.utils import make_grid
import lib.models as models
import lib.data as datasets
from lib.utils import code_backup, load_checkpoint
Load the config file and model.
with open('configs/selfie2anime.json') as f:
    config = edict(json.load(f))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
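# Load the model and its pre-trained weights
checkpoint_path = './saved/UGATIT_selfie2anime/05-30_00-05/checkpoints/checkpoint-epoch42.pth'
model = getattr(models, config.arch.type)(**config.arch.args)
checkpoint = load_checkpoint(checkpoint_path)
# Assumes the checkpoint dict stores the model weights under 'state_dict'
model.load_state_dict(checkpoint['state_dict'])
model = model.to(device)  # inputs are moved to `device` below, so the model must be too
del checkpoint  # free the memory held by the raw checkpoint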
Create the validation (test) set:
val_set = getattr(datasets, config.datamanager.type)(
config.datamanager.root, config.datamanager.dataset_dir,
You can generate sample translations (selfie to anime) with the G_AB generator:
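images = []
model.eval()
with torch.no_grad():  # inference only; no gradients needed
    for i, data in enumerate(val_set):
        if i == 10:  # use the first 10 test pairs
            break
        real_A = data['A'].to(device).unsqueeze(0)  # add a batch dimension
        real_B = data['B'].to(device).unsqueeze(0)
        _, fake_B = model(real_A, real_B)
        # Map both images from the normalized range back to displayable pixel values
        real_A = val_set.denormalize(real_A, device=device)
        fake_B = val_set.denormalize(fake_B, device=device)
        # Place the input and its translation side by side (along the width)
        images.append(torch.cat((real_A.detach().cpu(), fake_B.detach().cpu()), dim=-1))
# Stack the pairs vertically into a single (1, C, 10*H, 2*W) tensor
images = torch.cat(images, dim=-2)
images = make_grid(images, padding=100).numpy()  # -> one (C, H, W) image
plt.figure(figsize=(128, 64))
plt.imshow(np.transpose(images, (1, 2, 0)), interpolation='nearest')  # CHW -> HWC
plt.axis('off')
plt.show()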
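The plotting cell relies on val_set.denormalize to map images back to displayable pixel values. It also uses np.transpose to convert PyTorch's channel-first (C, H, W) layout into the (H, W, C) layout that matplotlib expects. The NumPy sketch below mirrors those two steps and the side-by-side, stacked layout of the comparison image.

It assumes the dataset normalizes each channel with mean 0.5 and std 0.5, so pixels lie in [-1, 1]. That setting is common in U-GAT-IT implementations, but this notebook does not confirm it. The names denormalize_np and build_comparison are hypothetical.

```python
import numpy as np

# Assumed per-channel normalization constants (mean 0.5, std 0.5).
MEAN = np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1)
STD = np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1)


def denormalize_np(x):
    """Invert (x - mean) / std for a (3, H, W) array and clip to [0, 1]."""
    return np.clip(x * STD + MEAN, 0.0, 1.0)


def build_comparison(pairs):
    """Turn (real, fake) pairs of (3, H, W) arrays into one (n*H, 2*W, 3) image.

    Each row shows real | fake side by side (concatenated along the width,
    axis -1), and rows are stacked along the height (axis -2), matching the
    two torch.cat calls in the notebook.
    """
    rows = [np.concatenate((denormalize_np(r), denormalize_np(f)), axis=-1)
            for r, f in pairs]
    grid = np.concatenate(rows, axis=-2)  # (3, n*H, 2*W)
    return np.transpose(grid, (1, 2, 0))  # channel-last for plt.imshow
```

Unlike make_grid, this sketch adds no padding. Clipping guards against generator outputs that fall slightly outside [-1, 1], which would otherwise trigger matplotlib's out-of-range clipping warning.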