Commit

Add files via upload

LiDan456 authored Jan 9, 2019
1 parent f8afe19 commit 7304d6c

Showing 3 changed files with 319 additions and 0 deletions.
190 changes: 190 additions & 0 deletions RGAN.py
@@ -0,0 +1,190 @@
import numpy as np
import tensorflow as tf
import pdb
import random
import json
from scipy.stats import mode

import data_utils
import plotting
import model
import utils
import eval
import DR_discriminator

from time import time
from math import floor
from mmd import rbf_mmd2, median_pairwise_distance, mix_rbf_mmd2_and_ratio

begin = time()

tf.logging.set_verbosity(tf.logging.ERROR)

# --- get settings --- #
# parse command line arguments, or use defaults
parser = utils.rgan_options_parser()
settings = vars(parser.parse_args())
# if a settings file is specified, it overrides command line arguments/defaults
if settings['settings_file']: settings = utils.load_settings_from_file(settings)

# --- get data, split --- #
# samples, pdf, labels = data_utils.get_data(settings)
data_path = './experiments/data/' + settings['data_load_from'] + '.data.npy'
print('Loading data from', data_path)
settings["eval_an"] = False
settings["eval_single"] = False
samples, labels, index = data_utils.get_data(settings["data"], settings["seq_length"], settings["seq_step"],
settings["num_signals"], settings['sub_id'], settings["eval_single"],
settings["eval_an"], data_path)
print('samples_size:', samples.shape)
# -- number of variables -- #
num_variables = samples.shape[2]
print('num_variables:', num_variables)
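# samples is expected to have shape (num_windows, seq_length, num_signals),
# so axis 2 gives the number of variables observed at each time step.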
# --- save settings, data --- #
print('Ready to run with settings:')
for (k, v) in settings.items(): print(k, '\t', v)
# add the settings to local environment
# WARNING: at this point a lot of variables appear
locals().update(settings)
json.dump(settings, open('./experiments/settings/' + identifier + '.txt', 'w'), indent=0)
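# For reference, the settings file is plain JSON whose keys mirror the parser
# options in utils.py (plus keys such as seq_step and use_time, which this
# script reads but the parser does not define, so they must come from the
# file); a hypothetical minimal example saved as
# ./experiments/settings/test.txt could look like:
# {
#     "identifier": "test",
#     "data": "load",
#     "data_load_from": "test",
#     "seq_length": 30,
#     "seq_step": 10,
#     "num_signals": 1,
#     "sub_id": "test",
#     "use_time": false,
#     "batch_size": 28,
#     "num_epochs": 100
# }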

# --- build model --- #
# preparation: data placeholders and model parameters
Z, X, T = model.create_placeholders(batch_size, seq_length, latent_dim, num_variables)
discriminator_vars = ['hidden_units_d', 'seq_length', 'batch_size', 'batch_mean']
discriminator_settings = dict((k, settings[k]) for k in discriminator_vars)
generator_vars = ['hidden_units_g', 'seq_length', 'batch_size', 'learn_scale']
generator_settings = dict((k, settings[k]) for k in generator_vars)
generator_settings['num_signals'] = num_variables

# model: GAN losses
D_loss, G_loss = model.GAN_loss(Z, X, generator_settings, discriminator_settings)
D_solver, G_solver, priv_accountant = model.GAN_solvers(D_loss, G_loss, learning_rate, batch_size,
total_examples=samples.shape[0],
l2norm_bound=l2norm_bound,
batches_per_lot=batches_per_lot, sigma=dp_sigma, dp=dp)
# model: generate samples for visualization
G_sample = model.generator(Z, **generator_settings, reuse=True)
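
# For orientation: in the standard GAN formulation, the two losses are
#   D_loss = -E[log D(x)] - E[log(1 - D(G(z)))]
#   G_loss = -E[log D(G(z))]
# with x a real sequence batch and z latent noise; the exact variant used
# here is whatever model.GAN_loss defines in model.py.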


# # --- evaluation settings--- #
#
# # frequency to do visualisations
# num_samples = samples.shape[0]
# vis_freq = max(6600 // num_samples, 1)
# eval_freq = max(6600// num_samples, 1)
#
# # get heuristic bandwidth for mmd kernel from evaluation samples
# heuristic_sigma_training = median_pairwise_distance(samples)
# best_mmd2_so_far = 1000
#
# # optimise sigma using that (that's t-hat)
# batch_multiplier = 5000 // batch_size
# eval_size = batch_multiplier * batch_size
# eval_eval_size = int(0.2 * eval_size)
# eval_real_PH = tf.placeholder(tf.float32, [eval_eval_size, seq_length, num_generated_features])
# eval_sample_PH = tf.placeholder(tf.float32, [eval_eval_size, seq_length, num_generated_features])
# n_sigmas = 2
# sigma = tf.get_variable(name='sigma', shape=n_sigmas, initializer=tf.constant_initializer(
# value=np.power(heuristic_sigma_training, np.linspace(-1, 3, num=n_sigmas))))
# mmd2, that = mix_rbf_mmd2_and_ratio(eval_real_PH, eval_sample_PH, sigma)
# with tf.variable_scope("SIGMA_optimizer"):
# sigma_solver = tf.train.RMSPropOptimizer(learning_rate=0.05).minimize(-that, var_list=[sigma])
# # sigma_solver = tf.train.AdamOptimizer().minimize(-that, var_list=[sigma])
# # sigma_solver = tf.train.AdagradOptimizer(learning_rate=0.1).minimize(-that, var_list=[sigma])
# sigma_opt_iter = 2000
# sigma_opt_thresh = 0.001
# sigma_opt_vars = [var for var in tf.global_variables() if 'SIGMA_optimizer' in var.name]
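
# For reference, the squared MMD that the commented-out block above estimates
# compares real and generated batches under an RBF kernel k:
#   MMD^2(P, Q) = E[k(x, x')] + E[k(y, y')] - 2 * E[k(x, y)]
# with x, x' drawn from the real distribution P and y, y' from the generator's
# distribution Q; the estimators (rbf_mmd2, mix_rbf_mmd2_and_ratio) come from
# mmd.py, imported at the top of this file.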


# --- run the program --- #
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# sess = tf.Session()
sess.run(tf.global_variables_initializer())

# -- plot the real samples -- #
vis_real_indices = np.random.choice(len(samples), size=16)
vis_real = np.float32(samples[vis_real_indices, :, :])
plotting.save_plot_sample(vis_real, 0, identifier + '_real', n_samples=16, num_epochs=num_epochs)
plotting.save_samples_real(vis_real, identifier)

# --- train --- #
train_vars = ['batch_size', 'D_rounds', 'G_rounds', 'use_time', 'seq_length', 'latent_dim']
train_settings = dict((k, settings[k]) for k in train_vars)
train_settings['num_signals'] = num_variables

t0 = time()
MMD = np.zeros([num_epochs, ])

for epoch in range(num_epochs):
# for epoch in range(1):
# -- train epoch -- #
D_loss_curr, G_loss_curr = model.train_epoch(epoch, samples, labels, sess, Z, X, D_loss, G_loss,
D_solver, G_solver, **train_settings)
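    # model.train_epoch performs one full pass over the training data;
    # D_rounds and G_rounds (see utils.py) presumably control how many
    # discriminator vs. generator updates are taken per batch inside it.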

# # -- eval -- #
# # visualise plots of generated samples, with/without labels
# # choose which epoch to visualize
#
# # random input vectors for the latent space, as the inputs of generator
# vis_ZZ = model.sample_Z(batch_size, seq_length, latent_dim, use_time)
#
# # # -- generate samples-- #
# vis_sample = sess.run(G_sample, feed_dict={Z: vis_ZZ})
# # # -- visualize the generated samples -- #
# plotting.save_plot_sample(vis_sample, epoch, identifier, n_samples=16, num_epochs=None, ncol=4)
# # plotting.save_plot_sample(vis_sample, 0, identifier + '_real', n_samples=16, num_epochs=num_epochs)
# # # save the generated samples in case they might be useful for comparison
# plotting.save_samples(vis_sample, identifier, epoch)

# -- print -- #
print('epoch, D_loss_curr, G_loss_curr, seq_length')
print('%d\t%.4f\t%.4f\t%d' % (epoch, D_loss_curr, G_loss_curr, seq_length))

# # -- compute mmd2 and if available, prob density -- #
# if epoch % eval_freq == 0:
# # how many samples to evaluate with?
# eval_Z = model.sample_Z(eval_size, seq_length, latent_dim, use_time)
# eval_sample = np.empty(shape=(eval_size, seq_length, num_signals))
# for i in range(batch_multiplier):
# eval_sample[i * batch_size:(i + 1) * batch_size, :, :] = sess.run(G_sample, feed_dict={ Z: eval_Z[i * batch_size:(i + 1) * batch_size]})
# eval_sample = np.float32(eval_sample)
# eval_real = np.float32(samples['vali'][np.random.choice(len(samples['vali']), size=batch_multiplier * batch_size), :, :])
#
# eval_eval_real = eval_real[:eval_eval_size]
# eval_test_real = eval_real[eval_eval_size:]
# eval_eval_sample = eval_sample[:eval_eval_size]
# eval_test_sample = eval_sample[eval_eval_size:]
#
# # MMD
# # reset ADAM variables
# sess.run(tf.initialize_variables(sigma_opt_vars))
# sigma_iter = 0
# that_change = sigma_opt_thresh * 2
# old_that = 0
# while that_change > sigma_opt_thresh and sigma_iter < sigma_opt_iter:
# new_sigma, that_np, _ = sess.run([sigma, that, sigma_solver],
# feed_dict={eval_real_PH: eval_eval_real, eval_sample_PH: eval_eval_sample})
# that_change = np.abs(that_np - old_that)
# old_that = that_np
# sigma_iter += 1
# opt_sigma = sess.run(sigma)
# try:
# mmd2, that_np = sess.run(mix_rbf_mmd2_and_ratio(eval_test_real, eval_test_sample, biased=False, sigmas=sigma))
# except ValueError:
# mmd2 = 'NA'
# that = 'NA'
#
# MMD[epoch, ] = mmd2

# -- save model parameters -- #
model.dump_parameters(sub_id + '_' + str(seq_length) + '_' + str(epoch), sess)

np.save('./experiments/plots/gs/' + identifier + '_' + 'MMD.npy', MMD)

end = time() - begin
print('Training terminated | training time = %ds' % end)

print("Training terminated | training time = %ds " % (time() - begin))
21 changes: 21 additions & 0 deletions tf_ops.py
@@ -0,0 +1,21 @@
### from https://github.com/eugenium/MMD/blob/master/tf_ops.py
import tensorflow as tf


def sq_sum(t, name=None):
    """The squared Frobenius-type norm of a tensor, sum(t ** 2)."""
    with tf.name_scope(name, "SqSum", [t]):
        t = tf.convert_to_tensor(t, name='t')
        # tf.nn.l2_loss computes sum(t ** 2) / 2, hence the factor of 2.
        return 2 * tf.nn.l2_loss(t)


def dot(x, y, name=None):
"The dot product of two vectors x and y."
with tf.name_scope(name, "Dot", [x, y]):
x = tf.convert_to_tensor(x, name='x')
y = tf.convert_to_tensor(y, name='y')

x.get_shape().assert_has_rank(1)
y.get_shape().assert_has_rank(1)

return tf.squeeze(tf.matmul(tf.expand_dims(x, 0), tf.expand_dims(y, 1)))
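
As a quick sanity check (a minimal sketch, assuming the TensorFlow 1.x API used
throughout this commit), the two helpers behave as follows:

import tensorflow as tf
from tf_ops import sq_sum, dot

with tf.Session() as sess:
    t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    x = tf.constant([1.0, 2.0, 3.0])
    y = tf.constant([4.0, 5.0, 6.0])
    print(sess.run(sq_sum(t)))  # 1 + 4 + 9 + 16 = 30.0
    print(sess.run(dot(x, y)))  # 1*4 + 2*5 + 3*6 = 32.0
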
108 changes: 108 additions & 0 deletions utils.py
@@ -0,0 +1,108 @@
#!/usr/bin/env ipython
# Utility functions that don't fit in other scripts
import argparse
import json

def rgan_options_parser():
"""
Define parser to parse options from command line, with defaults.
Refer to this function for definitions of various variables.
"""
parser = argparse.ArgumentParser(description='Train a GAN to generate sequential, real-valued data.')
# meta-option
parser.add_argument('--settings_file', help='json file of settings, overrides everything else', type=str, default='')
# options pertaining to data
parser.add_argument('--data', help='what kind of data to train with?',
default='gp_rbf',
choices=['gp_rbf', 'sine', 'mnist', 'load'])
# parser.add_argument('--num_samples', type=int, help='how many training examples \
# to generate?', default=28*5*100)
# parser.add_argument('--num_samples_t', type=int, help='how many testing examples \
# for anomaly detection?', default=28 * 5 * 100)
parser.add_argument('--seq_length', type=int, default=30)
parser.add_argument('--num_signals', type=int, default=1)
parser.add_argument('--normalise', type=bool, default=False, help='normalise the \
training/vali/test data (during split)?')
# parser.add_argument('--AD', type=bool, default=False, help='should we conduct anomaly detection?')

### for gp_rbf
parser.add_argument('--scale', type=float, default=0.1)
### for sin (should be using subparsers for this...)
parser.add_argument('--freq_low', type=float, default=1.0)
parser.add_argument('--freq_high', type=float, default=5.0)
parser.add_argument('--amplitude_low', type=float, default=0.1)
parser.add_argument('--amplitude_high', type=float, default=0.9)
### for mnist
parser.add_argument('--multivariate_mnist', type=bool, default=False)
parser.add_argument('--full_mnist', type=bool, default=False)
### for loading
parser.add_argument('--data_load_from', type=str, default='')
### for eICU
parser.add_argument('--resample_rate_in_min', type=int, default=15)
# hyperparameters of the model
parser.add_argument('--hidden_units_g', type=int, default=100)
parser.add_argument('--hidden_units_d', type=int, default=100)
parser.add_argument('--hidden_units_e', type=int, default=100)
    parser.add_argument('--kappa', type=float, help='weight between final output \
                        and intermediate steps in discriminator cost (1 = all \
                        intermediate)', default=1)
parser.add_argument('--latent_dim', type=int, default=5, help='dimensionality \
of the latent/noise space')
    parser.add_argument('--weight', type=float, default=0.5, help='weight of score')
parser.add_argument('--degree', type=int, default=1, help='norm degree')
parser.add_argument('--batch_mean', type=bool, default=False, help='append the mean \
of the batch to all variables for calculating discriminator loss')
    parser.add_argument('--learn_scale', type=bool, default=False, help='make the \
                        "scale" parameter at the output of the generator learnable (else fixed \
                        to 1)')
# options pertaining to training
parser.add_argument('--learning_rate', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=28)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--D_rounds', type=int, default=5, help='number of rounds \
of discriminator training')
parser.add_argument('--G_rounds', type=int, default=1, help='number of rounds \
of generator training')
parser.add_argument('--E_rounds', type=int, default=1, help='number of rounds \
of encoder training')
# parser.add_argument('--use_time', type=bool, default=False, help='enforce \
# latent dimension 0 to correspond to time')
parser.add_argument('--shuffle', type=bool, default=True)
parser.add_argument('--eval_mul', type=bool, default=False)
parser.add_argument('--eval_an', type=bool, default=False)
parser.add_argument('--eval_single', type=bool, default=False)
parser.add_argument('--wrong_labels', type=bool, default=False, help='augment \
discriminator loss with real examples with wrong (~shuffled, sort of) labels')
# options pertaining to evaluation and exploration
parser.add_argument('--identifier', type=str, default='test', help='identifier \
string for output files')
parser.add_argument('--sub_id', type=str, default='test', help='identifier \
string for load parameters')
# options pertaining to differential privacy
parser.add_argument('--dp', type=bool, default=False, help='train discriminator \
with differentially private SGD?')
parser.add_argument('--l2norm_bound', type=float, default=1e-5,
help='bound on norm of individual gradients for DP training')
parser.add_argument('--batches_per_lot', type=int, default=1,
help='number of batches per lot (for DP)')
parser.add_argument('--dp_sigma', type=float, default=1e-5,
help='sigma for noise added (for DP)')

return parser

def load_settings_from_file(settings):
"""
Handle loading settings from a JSON file, filling in missing settings from
the command line defaults, but otherwise overwriting them.
"""
settings_path = './experiments/settings/' + settings['settings_file'] + '.txt'
print('Loading settings from', settings_path)
settings_loaded = json.load(open(settings_path, 'r'))
# check for settings missing in file
for key in settings.keys():
        if key not in settings_loaded:
print(key, 'not found in loaded settings - adopting value from command line defaults: ', settings[key])
# overwrite parsed/default settings with those read from file, allowing for
# (potentially new) default settings not present in file
settings.update(settings_loaded)
return settings
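
One caveat worth noting: argparse's type=bool does not parse strings the way
these flags suggest, since bool('False') is True in Python, so e.g.
--shuffle False still yields True. A common workaround, shown here only as a
sketch rather than as a change to the file above, is a small converter passed
as type=str2bool:

import argparse

def str2bool(v):
    # Map common string spellings onto booleans for argparse flags.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % v)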
