This repository is a TensorFlow implementation of Martin Arjovsky's Wasserstein GAN, arXiv:1701.07875v3.
- tensorflow 1.9.0
- python 3.5.3
- numpy 1.14.2
- pillow 5.0.0
- scipy 0.19.0
- matplotlib 2.2.2
- Generator (DCGAN)
- Critic (DCGAN)
- MNIST
- CelebA
The MNIST dataset is downloaded automatically if it is not found in the dataset folder. For CelebA, download the dataset with the following command and copy the `celebA` folder into the corresponding location shown in the directory hierarchy below.
python download2.py celebA
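The "download only if missing" behavior described above can be sketched as a simple presence check. This is a minimal sketch, assuming a data root that holds one folder per dataset (`mnist`, `celebA`) as in the directory hierarchy; `dataset_is_present` and `needs_download` are hypothetical names, not functions from dataset.py.

```python
import os


def dataset_is_present(data_root, name):
    """True if data_root/name is a non-empty directory."""
    path = os.path.join(data_root, name)
    return os.path.isdir(path) and len(os.listdir(path)) > 0


def needs_download(data_root, name):
    """Decide what to do for a dataset that may be missing (hypothetical).

    Returns False if the data is already there, True if it should be
    fetched automatically (MNIST), and raises for CelebA, which has to be
    downloaded manually with download2.py.
    """
    if dataset_is_present(data_root, name):
        return False
    if name == "celebA":
        raise FileNotFoundError(
            "CelebA not found in {}; run `python download2.py celebA` and "
            "copy the celebA folder there.".format(os.path.join(data_root, name)))
    return True


print(needs_download("/nonexistent/Data", "mnist"))  # True
```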
.
│ WGAN
│ ├── src
│ │ ├── dataset.py
│ │ ├── download2.py
│ │ ├── main.py
│ │ ├── solver.py
│ │ ├── tensorflow_utils.py
│ │ ├── utils.py
│ │ └── wgan.py
│ Data
│ ├── celebA
│ └── mnist
src: source code of the WGAN
This implementation uses TensorFlow to train the WGAN. The generator and critic networks follow the DCGAN architecture described in Alec Radford's paper. Unlike a standard GAN discriminator, the WGAN critic does not use a sigmoid function in its last layer, and its cost function does not use a log-likelihood. RMSProp is used as the optimizer instead of Adam.
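To illustrate the differences described above, here is a minimal pure-Python sketch of the WGAN objectives on lists of raw critic scores, plus a plain RMSProp step. The function names (`critic_loss`, `generator_loss`, `rmsprop_step`, `clip_weights`) are hypothetical and are not taken from wgan.py or solver.py; the RMSProp `decay` and `eps` values are assumptions. `clip_weights` follows Algorithm 1 of the WGAN paper (clipping to [-c, c] with c = 0.01); this README does not state how this repository handles clipping.

```python
import math


def mean(xs):
    """Arithmetic mean of a non-empty list of critic scores."""
    return sum(xs) / len(xs)


def critic_loss(real_scores, fake_scores):
    """Critic minimizes E[D(G(z))] - E[D(x)].

    Scores are raw, unbounded critic outputs: no sigmoid, no log.
    """
    return mean(fake_scores) - mean(real_scores)


def generator_loss(fake_scores):
    """Generator minimizes -E[D(G(z))]."""
    return -mean(fake_scores)


def rmsprop_step(weights, grads, cache, lr=0.00005, decay=0.9, eps=1e-10):
    """One RMSProp update without momentum; returns (new_weights, new_cache).

    decay and eps are assumed values, not read from this repository.
    """
    new_cache = [decay * c + (1.0 - decay) * g * g for c, g in zip(cache, grads)]
    new_weights = [w - lr * g / math.sqrt(c + eps)
                   for w, g, c in zip(weights, grads, new_cache)]
    return new_weights, new_cache


def clip_weights(weights, c=0.01):
    """Weight clipping to [-c, c], as in Algorithm 1 of the WGAN paper."""
    return [max(-c, min(c, w)) for w in weights]


# Real samples score higher than fakes, so the critic loss is negative.
print(critic_loss([1.0, 3.0], [0.0, 1.0]))  # -1.5
print(generator_loss([0.0, 1.0]))           # -0.5
print(clip_weights([0.5, -0.5, 0.003]))     # [0.01, -0.01, 0.003]
```

Because the scores are never squashed by a sigmoid, the critic loss is an estimate of the (negated) gap between real and generated score means rather than a classification log-likelihood.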
Use `main.py` to train a WGAN network. Example usage:
python main.py --is_train=true --dataset=[celebA|mnist]
- `gpu_index`: GPU index, default: `0`
- `batch_size`: batch size for one feed-forward pass, default: `64`
- `dataset`: dataset name, one of [celebA|mnist], default: `celebA`
- `is_train`: training or inference mode, default: `False`
- `learning_rate`: initial learning rate, default: `0.00005`
- `num_critic`: number of critic iterations per generator iteration, default: `5`
- `z_dim`: dimension of the z vector, default: `100`
- `iters`: number of iterations, default: `100000`
- `print_freq`: frequency for printing the loss, default: `50`
- `save_freq`: frequency for saving the model, default: `10000`
- `sample_freq`: frequency for saving sample images, default: `200`
- `sample_size`: number of samples used to check generated image quality, default: `64`
- `load_model`: folder of the saved model you wish to test (e.g. `20180704-1736`), default: `None`
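The arguments above could be mirrored with `argparse` as follows. This is a hypothetical sketch: main.py may define its flags differently (for example with TensorFlow's flag utilities), and `str2bool`, `build_parser` and `actions_for_iter` are illustrative names. `actions_for_iter` shows one plausible reading of `num_critic`, `print_freq`, `sample_freq` and `save_freq`: by default, five critic updates per generator update, as in Algorithm 1 of the WGAN paper.

```python
import argparse


def str2bool(value):
    """Interpret 'true'/'false' strings such as --is_train=true."""
    return str(value).lower() in ("true", "1", "yes")


def build_parser():
    """Hypothetical argparse mirror of the documented flags and defaults."""
    p = argparse.ArgumentParser(description="WGAN flags (sketch)")
    p.add_argument("--gpu_index", type=int, default=0)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--dataset", choices=["celebA", "mnist"], default="celebA")
    p.add_argument("--is_train", type=str2bool, default=False)
    p.add_argument("--learning_rate", type=float, default=0.00005)
    p.add_argument("--num_critic", type=int, default=5)
    p.add_argument("--z_dim", type=int, default=100)
    p.add_argument("--iters", type=int, default=100000)
    p.add_argument("--print_freq", type=int, default=50)
    p.add_argument("--save_freq", type=int, default=10000)
    p.add_argument("--sample_freq", type=int, default=200)
    p.add_argument("--sample_size", type=int, default=64)
    p.add_argument("--load_model", type=str, default=None)
    return p


def actions_for_iter(it, args):
    """Hypothetical schedule for training iteration `it` (1-based):
    num_critic critic updates, one generator update, then periodic
    printing, sampling and saving."""
    actions = ["critic"] * args.num_critic + ["generator"]
    if it % args.print_freq == 0:
        actions.append("print")
    if it % args.sample_freq == 0:
        actions.append("sample")
    if it % args.save_freq == 0:
        actions.append("save")
    return actions


args = build_parser().parse_args(["--is_train=true", "--dataset=mnist"])
print(args.is_train, args.dataset, args.num_critic)  # True mnist 5
print(actions_for_iter(200, args))
```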
Use `main.py` to evaluate a WGAN network. Example usage:
python main.py --is_train=false --load_model=20180704-1746
where `20180704-1746` is replaced by the folder of the saved model you wish to test.
Please refer to the above arguments.
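The example `load_model` values (`20180704-1736`, `20180704-1746`) look like timestamps of the form year-month-day, then hour-minute. Assuming that is how model folders are named (this README does not say so explicitly), a sketch for generating and validating such names might look like this; `FOLDER_FORMAT`, `make_model_folder_name` and `parse_model_folder_name` are hypothetical.

```python
from datetime import datetime

# Assumed folder-name format, inferred only from the example values
# 20180704-1736 and 20180704-1746 (YYYYMMDD-HHMM).
FOLDER_FORMAT = "%Y%m%d-%H%M"


def make_model_folder_name(when=None):
    """Folder name for a run started at `when` (defaults to now)."""
    when = when or datetime.now()
    return when.strftime(FOLDER_FORMAT)


def parse_model_folder_name(name):
    """Return the datetime encoded in `name`; raises ValueError otherwise."""
    return datetime.strptime(name, FOLDER_FORMAT)


print(parse_model_folder_name("20180704-1746"))  # 2018-07-04 17:46:00
```

Validating the name up front gives a clear `ValueError` for a mistyped `--load_model` value instead of a later "checkpoint not found" failure.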
@misc{chengbinjin2018wgan,
author = {Cheng-Bin Jin},
title = {WGAN-tensorflow},
year = {2018},
howpublished = {\url{https://github.com/ChengBinJin/WGAN-TensorFlow}},
note = {commit xxxxxxx}
}
- This project borrowed some code from wiseodd
- Some readme formatting was borrowed from Logan Engstrom
Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or rather any use that is not academic research) (email: [email protected]). Free for research use, as long as proper attribution is given and this copyright notice is retained.