diff --git a/RGAN.py b/RGAN.py
new file mode 100644
index 0000000..f7a1c2e
--- /dev/null
+++ b/RGAN.py
@@ -0,0 +1,190 @@
+import numpy as np
+import tensorflow as tf
+import pdb
+import random
+import json
+from scipy.stats import mode
+
+import data_utils
+import plotting
+import model
+import utils
+import eval
+import DR_discriminator
+
+from time import time
+from math import floor
+from mmd import rbf_mmd2, median_pairwise_distance, mix_rbf_mmd2_and_ratio
+
+begin = time()
+
+tf.logging.set_verbosity(tf.logging.ERROR)
+
+# --- get settings --- #
+# parse command line arguments, or use defaults
+parser = utils.rgan_options_parser()
+settings = vars(parser.parse_args())
+# if a settings file is specified, it overrides command line arguments/defaults
+if settings['settings_file']: settings = utils.load_settings_from_file(settings)
+
+# --- get data, split --- #
+# samples, pdf, labels = data_utils.get_data(settings)
+data_path = './experiments/data/' + settings['data_load_from'] + '.data.npy'
+print('Loading data from', data_path)
+settings["eval_an"] = False
+settings["eval_single"] = False
+samples, labels, index = data_utils.get_data(settings["data"], settings["seq_length"], settings["seq_step"],
+                                             settings["num_signals"], settings['sub_id'], settings["eval_single"],
+                                             settings["eval_an"], data_path)
+print('samples_size:', samples.shape)
+# -- number of variables -- #
+num_variables = samples.shape[2]
+print('num_variables:', num_variables)
+# --- save settings, data --- #
+print('Ready to run with settings:')
+for (k, v) in settings.items(): print(v, '\t', k)
+# add the settings to local environment
+# WARNING: at this point a lot of variables appear
+locals().update(settings)
+json.dump(settings, open('./experiments/settings/' + identifier + '.txt', 'w'), indent=0)
+
+# --- build model --- #
+# preparation: data placeholders and model parameters
+Z, X, T = model.create_placeholders(batch_size, seq_length, latent_dim, num_variables)
+discriminator_vars = ['hidden_units_d', 'seq_length', 'batch_size', 'batch_mean']
+discriminator_settings = dict((k, settings[k]) for k in discriminator_vars)
+generator_vars = ['hidden_units_g', 'seq_length', 'batch_size', 'learn_scale']
+generator_settings = dict((k, settings[k]) for k in generator_vars)
+generator_settings['num_signals'] = num_variables
+
+# model: GAN losses
+D_loss, G_loss = model.GAN_loss(Z, X, generator_settings, discriminator_settings)
+D_solver, G_solver, priv_accountant = model.GAN_solvers(D_loss, G_loss, learning_rate, batch_size,
+                                                        total_examples=samples.shape[0],
+                                                        l2norm_bound=l2norm_bound,
+                                                        batches_per_lot=batches_per_lot, sigma=dp_sigma, dp=dp)
+# model: generate samples for visualization
+G_sample = model.generator(Z, **generator_settings, reuse=True)
+
+
+# # --- evaluation settings --- #
+#
+# # frequency to do visualisations
+# num_samples = samples.shape[0]
+# vis_freq = max(6600 // num_samples, 1)
+# eval_freq = max(6600 // num_samples, 1)
+#
+# # get heuristic bandwidth for mmd kernel from evaluation samples
+# heuristic_sigma_training = median_pairwise_distance(samples)
+# best_mmd2_so_far = 1000
+#
+# # optimise sigma using that (that's t-hat)
+# batch_multiplier = 5000 // batch_size
+# eval_size = batch_multiplier * batch_size
+# eval_eval_size = int(0.2 * eval_size)
+# eval_real_PH = tf.placeholder(tf.float32, [eval_eval_size, seq_length, num_generated_features])
+# eval_sample_PH = tf.placeholder(tf.float32, [eval_eval_size, seq_length, num_generated_features])
+# n_sigmas = 2
+# sigma = tf.get_variable(name='sigma', shape=n_sigmas, initializer=tf.constant_initializer(
+#     value=np.power(heuristic_sigma_training, np.linspace(-1, 3, num=n_sigmas))))
+# mmd2, that = mix_rbf_mmd2_and_ratio(eval_real_PH, eval_sample_PH, sigma)
+# with tf.variable_scope("SIGMA_optimizer"):
+#     sigma_solver = tf.train.RMSPropOptimizer(learning_rate=0.05).minimize(-that, var_list=[sigma])
+#     # sigma_solver = tf.train.AdamOptimizer().minimize(-that, var_list=[sigma])
+#     # sigma_solver = tf.train.AdagradOptimizer(learning_rate=0.1).minimize(-that, var_list=[sigma])
+# sigma_opt_iter = 2000
+# sigma_opt_thresh = 0.001
+# sigma_opt_vars = [var for var in tf.global_variables() if 'SIGMA_optimizer' in var.name]
+
+
+# --- run the program --- #
+config = tf.ConfigProto()
+config.gpu_options.allow_growth = True
+sess = tf.Session(config=config)
+# sess = tf.Session()
+sess.run(tf.global_variables_initializer())
+
+# # -- plot the real samples -- #
+vis_real_indices = np.random.choice(len(samples), size=16)
+vis_real = np.float32(samples[vis_real_indices, :, :])
+plotting.save_plot_sample(vis_real, 0, identifier + '_real', n_samples=16, num_epochs=num_epochs)
+plotting.save_samples_real(vis_real, identifier)
+
+# --- train --- #
+train_vars = ['batch_size', 'D_rounds', 'G_rounds', 'use_time', 'seq_length', 'latent_dim']
+train_settings = dict((k, settings[k]) for k in train_vars)
+train_settings['num_signals'] = num_variables
+
+t0 = time()
+MMD = np.zeros([num_epochs, ])
+
+for epoch in range(num_epochs):
+# for epoch in range(1):
+    # -- train epoch -- #
+    D_loss_curr, G_loss_curr = model.train_epoch(epoch, samples, labels, sess, Z, X, D_loss, G_loss,
+                                                 D_solver, G_solver, **train_settings)
+
+    # # -- eval -- #
+    # # visualise plots of generated samples, with/without labels
+    # # choose which epoch to visualize
+    #
+    # # random input vectors for the latent space, as the inputs of generator
+    # vis_ZZ = model.sample_Z(batch_size, seq_length, latent_dim, use_time)
+    #
+    # # # -- generate samples -- #
+    # vis_sample = sess.run(G_sample, feed_dict={Z: vis_ZZ})
+    # # # -- visualize the generated samples -- #
+    # plotting.save_plot_sample(vis_sample, epoch, identifier, n_samples=16, num_epochs=None, ncol=4)
+    # # plotting.save_plot_sample(vis_sample, 0, identifier + '_real', n_samples=16, num_epochs=num_epochs)
+    # # # save the generated samples in case they might be useful for comparison
+    # plotting.save_samples(vis_sample, identifier, epoch)
+
+    # -- print -- #
+    print('epoch, D_loss_curr, G_loss_curr, seq_length')
+    print('%d\t%.4f\t%.4f\t%d' % (epoch, D_loss_curr, G_loss_curr, seq_length))
+
+    # # -- compute mmd2 and if available, prob density -- #
+    # if epoch % eval_freq == 0:
+    #     # how many samples to evaluate with?
+    #     eval_Z = model.sample_Z(eval_size, seq_length, latent_dim, use_time)
+    #     eval_sample = np.empty(shape=(eval_size, seq_length, num_signals))
+    #     for i in range(batch_multiplier):
+    #         eval_sample[i * batch_size:(i + 1) * batch_size, :, :] = sess.run(G_sample, feed_dict={Z: eval_Z[i * batch_size:(i + 1) * batch_size]})
+    #     eval_sample = np.float32(eval_sample)
+    #     eval_real = np.float32(samples['vali'][np.random.choice(len(samples['vali']), size=batch_multiplier * batch_size), :, :])
+    #
+    #     eval_eval_real = eval_real[:eval_eval_size]
+    #     eval_test_real = eval_real[eval_eval_size:]
+    #     eval_eval_sample = eval_sample[:eval_eval_size]
+    #     eval_test_sample = eval_sample[eval_eval_size:]
+    #
+    #     # MMD
+    #     # reset ADAM variables
+    #     sess.run(tf.initialize_variables(sigma_opt_vars))
+    #     sigma_iter = 0
+    #     that_change = sigma_opt_thresh * 2
+    #     old_that = 0
+    #     while that_change > sigma_opt_thresh and sigma_iter < sigma_opt_iter:
+    #         new_sigma, that_np, _ = sess.run([sigma, that, sigma_solver],
+    #                                          feed_dict={eval_real_PH: eval_eval_real, eval_sample_PH: eval_eval_sample})
+    #         that_change = np.abs(that_np - old_that)
+    #         old_that = that_np
+    #         sigma_iter += 1
+    #     opt_sigma = sess.run(sigma)
+    #     try:
+    #         mmd2, that_np = sess.run(mix_rbf_mmd2_and_ratio(eval_test_real, eval_test_sample, biased=False, sigmas=sigma))
+    #     except ValueError:
+    #         mmd2 = 'NA'
+    #         that = 'NA'
+    #
+    #     MMD[epoch, ] = mmd2
+
+    # -- save model parameters -- #
+    model.dump_parameters(sub_id + '_' + str(seq_length) + '_' + str(epoch), sess)
+
+np.save('./experiments/plots/gs/' + identifier + '_' + 'MMD.npy', MMD)
+
+end = time() - begin
+print('Training terminated | training time = %d s' % end)
\ No newline at end of file
diff --git a/tf_ops.py b/tf_ops.py
new file mode 100644
index 0000000..fdf13ff
--- /dev/null
+++ b/tf_ops.py
@@ -0,0 +1,21 @@
+### from https://github.com/eugenium/MMD/blob/master/tf_ops.py
+import tensorflow as tf
+
+
+def sq_sum(t, name=None):
+    "The squared Frobenius-type norm of a tensor, sum(t ** 2)."
+    with tf.name_scope(name, "SqSum", [t]):
+        t = tf.convert_to_tensor(t, name='t')
+        return 2 * tf.nn.l2_loss(t)
+
+
+def dot(x, y, name=None):
+    "The dot product of two vectors x and y."
+    with tf.name_scope(name, "Dot", [x, y]):
+        x = tf.convert_to_tensor(x, name='x')
+        y = tf.convert_to_tensor(y, name='y')
+
+        x.get_shape().assert_has_rank(1)
+        y.get_shape().assert_has_rank(1)
+
+        return tf.squeeze(tf.matmul(tf.expand_dims(x, 0), tf.expand_dims(y, 1)))
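Note (not part of the patch): a minimal usage sketch for the two tf_ops.py helpers above, assuming the TF1 graph/session API used throughout this patch; the example tensors are arbitrary.

import tensorflow as tf
from tf_ops import sq_sum, dot

x = tf.constant([1.0, 2.0, 3.0])
y = tf.constant([4.0, 5.0, 6.0])

with tf.Session() as sess:
    # sq_sum(t) evaluates to sum(t ** 2); dot(x, y) is the inner product of two rank-1 tensors
    print(sess.run(sq_sum(x)))  # 1 + 4 + 9 = 14.0
    print(sess.run(dot(x, y)))  # 4 + 10 + 18 = 32.0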
+ """ + parser = argparse.ArgumentParser(description='Train a GAN to generate sequential, real-valued data.') + # meta-option + parser.add_argument('--settings_file', help='json file of settings, overrides everything else', type=str, default='') + # options pertaining to data + parser.add_argument('--data', help='what kind of data to train with?', + default='gp_rbf', + choices=['gp_rbf', 'sine', 'mnist', 'load']) + # parser.add_argument('--num_samples', type=int, help='how many training examples \ + # to generate?', default=28*5*100) + # parser.add_argument('--num_samples_t', type=int, help='how many testing examples \ + # for anomaly detection?', default=28 * 5 * 100) + parser.add_argument('--seq_length', type=int, default=30) + parser.add_argument('--num_signals', type=int, default=1) + parser.add_argument('--normalise', type=bool, default=False, help='normalise the \ + training/vali/test data (during split)?') + # parser.add_argument('--AD', type=bool, default=False, help='should we conduct anomaly detection?') + + ### for gp_rbf + parser.add_argument('--scale', type=float, default=0.1) + ### for sin (should be using subparsers for this...) + parser.add_argument('--freq_low', type=float, default=1.0) + parser.add_argument('--freq_high', type=float, default=5.0) + parser.add_argument('--amplitude_low', type=float, default=0.1) + parser.add_argument('--amplitude_high', type=float, default=0.9) + ### for mnist + parser.add_argument('--multivariate_mnist', type=bool, default=False) + parser.add_argument('--full_mnist', type=bool, default=False) + ### for loading + parser.add_argument('--data_load_from', type=str, default='') + ### for eICU + parser.add_argument('--resample_rate_in_min', type=int, default=15) + # hyperparameters of the model + parser.add_argument('--hidden_units_g', type=int, default=100) + parser.add_argument('--hidden_units_d', type=int, default=100) + parser.add_argument('--hidden_units_e', type=int, default=100) + parser.add_argument('--kappa', type=float, help='weight between final output \ + and intermediate steps in discriminator cost (1 = all \ + intermediate', default=1) + parser.add_argument('--latent_dim', type=int, default=5, help='dimensionality \ + of the latent/noise space') + parser.add_argument('--weight', type=int, default=0.5, help='weight of score') + parser.add_argument('--degree', type=int, default=1, help='norm degree') + parser.add_argument('--batch_mean', type=bool, default=False, help='append the mean \ + of the batch to all variables for calculating discriminator loss') + parser.add_argument('--learn_scale', type=bool, default=False, help='make the \ + "scale" parameter at the output of the generator learnable (else fixed \ + to 1') + # options pertaining to training + parser.add_argument('--learning_rate', type=float, default=0.1) + parser.add_argument('--batch_size', type=int, default=28) + parser.add_argument('--num_epochs', type=int, default=100) + parser.add_argument('--D_rounds', type=int, default=5, help='number of rounds \ + of discriminator training') + parser.add_argument('--G_rounds', type=int, default=1, help='number of rounds \ + of generator training') + parser.add_argument('--E_rounds', type=int, default=1, help='number of rounds \ + of encoder training') + # parser.add_argument('--use_time', type=bool, default=False, help='enforce \ + # latent dimension 0 to correspond to time') + parser.add_argument('--shuffle', type=bool, default=True) + parser.add_argument('--eval_mul', type=bool, default=False) + parser.add_argument('--eval_an', 
+    parser.add_argument('--eval_single', type=bool, default=False)
+    parser.add_argument('--wrong_labels', type=bool, default=False,
+                        help='augment discriminator loss with real examples with wrong (~shuffled, sort of) labels')
+    # options pertaining to evaluation and exploration
+    parser.add_argument('--identifier', type=str, default='test',
+                        help='identifier string for output files')
+    parser.add_argument('--sub_id', type=str, default='test',
+                        help='identifier string for load parameters')
+    # options pertaining to differential privacy
+    parser.add_argument('--dp', type=bool, default=False,
+                        help='train discriminator with differentially private SGD?')
+    parser.add_argument('--l2norm_bound', type=float, default=1e-5,
+                        help='bound on norm of individual gradients for DP training')
+    parser.add_argument('--batches_per_lot', type=int, default=1,
+                        help='number of batches per lot (for DP)')
+    parser.add_argument('--dp_sigma', type=float, default=1e-5,
+                        help='sigma for noise added (for DP)')
+
+    return parser
+
+def load_settings_from_file(settings):
+    """
+    Handle loading settings from a JSON file, filling in missing settings from
+    the command line defaults, but otherwise overwriting them.
+    """
+    settings_path = './experiments/settings/' + settings['settings_file'] + '.txt'
+    print('Loading settings from', settings_path)
+    settings_loaded = json.load(open(settings_path, 'r'))
+    # check for settings missing in file
+    for key in settings.keys():
+        if key not in settings_loaded:
+            print(key, 'not found in loaded settings - adopting value from command line defaults:', settings[key])
+    # overwrite parsed/default settings with those read from file, allowing for
+    # (potentially new) default settings not present in file
+    settings.update(settings_loaded)
+    return settings
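Note (not part of the patch): a rough sketch of how the settings-file mechanism above is meant to be driven. The file name and values below are illustrative only; observe that RGAN.py also reads settings["seq_step"], which rgan_options_parser() does not define, so it has to come from the settings file.

import json

# hypothetical settings file ./experiments/settings/my_run.txt (keys mirror the parser options)
example_settings = {
    "identifier": "my_run",
    "sub_id": "my_run",
    "data": "load",
    "data_load_from": "my_dataset",  # RGAN.py then loads ./experiments/data/my_dataset.data.npy
    "seq_length": 30,
    "seq_step": 10,                  # used by data_utils.get_data, not defined in the parser
    "num_signals": 1,
    "batch_size": 28,
    "num_epochs": 100,
}
with open('./experiments/settings/my_run.txt', 'w') as f:
    json.dump(example_settings, f, indent=0)

# train with:
#   python RGAN.py --settings_file my_run
# keys missing from the file fall back to the argparse defaults in rgan_options_parser().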