Showing 3 changed files with 319 additions and 0 deletions.
@@ -0,0 +1,190 @@
import numpy as np
import tensorflow as tf
import pdb
import random
import json
from scipy.stats import mode

import data_utils
import plotting
import model
import utils
import eval
import DR_discriminator

from time import time
from math import floor
from mmd import rbf_mmd2, median_pairwise_distance, mix_rbf_mmd2_and_ratio

begin = time()

tf.logging.set_verbosity(tf.logging.ERROR)

# --- get settings --- #
# parse command line arguments, or use defaults
parser = utils.rgan_options_parser()
settings = vars(parser.parse_args())
# if a settings file is specified, it overrides command line arguments/defaults
if settings['settings_file']: settings = utils.load_settings_from_file(settings)

# --- get data, split --- #
# samples, pdf, labels = data_utils.get_data(settings)
data_path = './experiments/data/' + settings['data_load_from'] + '.data.npy'
print('Loading data from', data_path)
settings["eval_an"] = False
settings["eval_single"] = False
samples, labels, index = data_utils.get_data(settings["data"], settings["seq_length"], settings["seq_step"],
                                             settings["num_signals"], settings['sub_id'], settings["eval_single"],
                                             settings["eval_an"], data_path)
print('samples_size:', samples.shape)
# -- number of variables -- #
num_variables = samples.shape[2]
print('num_variables:', num_variables)
# --- save settings, data --- #
print('Ready to run with settings:')
for (k, v) in settings.items(): print(v, '\t', k)
# add the settings to the local environment
# WARNING: at this point a lot of variables appear
locals().update(settings)
json.dump(settings, open('./experiments/settings/' + identifier + '.txt', 'w'), indent=0)

# --- build model --- #
# preparation: data placeholders and model parameters
Z, X, T = model.create_placeholders(batch_size, seq_length, latent_dim, num_variables)
discriminator_vars = ['hidden_units_d', 'seq_length', 'batch_size', 'batch_mean']
discriminator_settings = dict((k, settings[k]) for k in discriminator_vars)
generator_vars = ['hidden_units_g', 'seq_length', 'batch_size', 'learn_scale']
generator_settings = dict((k, settings[k]) for k in generator_vars)
generator_settings['num_signals'] = num_variables

# model: GAN losses
D_loss, G_loss = model.GAN_loss(Z, X, generator_settings, discriminator_settings)
D_solver, G_solver, priv_accountant = model.GAN_solvers(D_loss, G_loss, learning_rate, batch_size,
                                                        total_examples=samples.shape[0],
                                                        l2norm_bound=l2norm_bound,
                                                        batches_per_lot=batches_per_lot, sigma=dp_sigma, dp=dp)
# model: generate samples for visualization
G_sample = model.generator(Z, **generator_settings, reuse=True)


# # --- evaluation settings --- #
#
# # frequency to do visualisations
# num_samples = samples.shape[0]
# vis_freq = max(6600 // num_samples, 1)
# eval_freq = max(6600 // num_samples, 1)
#
# # get heuristic bandwidth for mmd kernel from evaluation samples
# heuristic_sigma_training = median_pairwise_distance(samples)
# best_mmd2_so_far = 1000
#
# # optimise sigma using that (that's t-hat)
# batch_multiplier = 5000 // batch_size
# eval_size = batch_multiplier * batch_size
# eval_eval_size = int(0.2 * eval_size)
# eval_real_PH = tf.placeholder(tf.float32, [eval_eval_size, seq_length, num_generated_features])
# eval_sample_PH = tf.placeholder(tf.float32, [eval_eval_size, seq_length, num_generated_features])
# n_sigmas = 2
# sigma = tf.get_variable(name='sigma', shape=n_sigmas, initializer=tf.constant_initializer(
#     value=np.power(heuristic_sigma_training, np.linspace(-1, 3, num=n_sigmas))))
# mmd2, that = mix_rbf_mmd2_and_ratio(eval_real_PH, eval_sample_PH, sigma)
# with tf.variable_scope("SIGMA_optimizer"):
#     sigma_solver = tf.train.RMSPropOptimizer(learning_rate=0.05).minimize(-that, var_list=[sigma])
#     # sigma_solver = tf.train.AdamOptimizer().minimize(-that, var_list=[sigma])
#     # sigma_solver = tf.train.AdagradOptimizer(learning_rate=0.1).minimize(-that, var_list=[sigma])
# sigma_opt_iter = 2000
# sigma_opt_thresh = 0.001
# sigma_opt_vars = [var for var in tf.global_variables() if 'SIGMA_optimizer' in var.name]


# --- run the program --- #
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# sess = tf.Session()
sess.run(tf.global_variables_initializer())

# -- plot the real samples -- #
vis_real_indices = np.random.choice(len(samples), size=16)
vis_real = np.float32(samples[vis_real_indices, :, :])
plotting.save_plot_sample(vis_real, 0, identifier + '_real', n_samples=16, num_epochs=num_epochs)
plotting.save_samples_real(vis_real, identifier)

# --- train --- #
train_vars = ['batch_size', 'D_rounds', 'G_rounds', 'use_time', 'seq_length', 'latent_dim']
# NB: 'use_time' has no command-line default (its parser option is commented
# out in utils.py), so it must be supplied via the settings file.
train_settings = dict((k, settings[k]) for k in train_vars)
train_settings['num_signals'] = num_variables

t0 = time()
MMD = np.zeros([num_epochs, ])

for epoch in range(num_epochs):
# for epoch in range(1):
    # -- train epoch -- #
    D_loss_curr, G_loss_curr = model.train_epoch(epoch, samples, labels, sess, Z, X, D_loss, G_loss,
                                                 D_solver, G_solver, **train_settings)

    # # -- eval -- #
    # # visualise plots of generated samples, with/without labels
    # # choose which epoch to visualise
    #
    # # random input vectors for the latent space, as the inputs of the generator
    # vis_ZZ = model.sample_Z(batch_size, seq_length, latent_dim, use_time)
    #
    # # # -- generate samples -- #
    # vis_sample = sess.run(G_sample, feed_dict={Z: vis_ZZ})
    # # # -- visualise the generated samples -- #
    # plotting.save_plot_sample(vis_sample, epoch, identifier, n_samples=16, num_epochs=None, ncol=4)
    # # plotting.save_plot_sample(vis_sample, 0, identifier + '_real', n_samples=16, num_epochs=num_epochs)
    # # # save the generated samples in case they might be useful for comparison
    # plotting.save_samples(vis_sample, identifier, epoch)

    # -- print -- #
    print('epoch, D_loss_curr, G_loss_curr, seq_length')
    print('%d\t%.4f\t%.4f\t%d' % (epoch, D_loss_curr, G_loss_curr, seq_length))

    # # -- compute mmd2 and, if available, prob density -- #
    # if epoch % eval_freq == 0:
    #     # how many samples to evaluate with?
    #     eval_Z = model.sample_Z(eval_size, seq_length, latent_dim, use_time)
    #     eval_sample = np.empty(shape=(eval_size, seq_length, num_signals))
    #     for i in range(batch_multiplier):
    #         eval_sample[i * batch_size:(i + 1) * batch_size, :, :] = sess.run(G_sample, feed_dict={Z: eval_Z[i * batch_size:(i + 1) * batch_size]})
    #     eval_sample = np.float32(eval_sample)
    #     eval_real = np.float32(samples['vali'][np.random.choice(len(samples['vali']), size=batch_multiplier * batch_size), :, :])
    #
    #     eval_eval_real = eval_real[:eval_eval_size]
    #     eval_test_real = eval_real[eval_eval_size:]
    #     eval_eval_sample = eval_sample[:eval_eval_size]
    #     eval_test_sample = eval_sample[eval_eval_size:]
    #
    #     # MMD
    #     # reset ADAM variables
    #     sess.run(tf.initialize_variables(sigma_opt_vars))
    #     sigma_iter = 0
    #     that_change = sigma_opt_thresh * 2
    #     old_that = 0
    #     while that_change > sigma_opt_thresh and sigma_iter < sigma_opt_iter:
    #         new_sigma, that_np, _ = sess.run([sigma, that, sigma_solver],
    #                                          feed_dict={eval_real_PH: eval_eval_real, eval_sample_PH: eval_eval_sample})
    #         that_change = np.abs(that_np - old_that)
    #         old_that = that_np
    #         sigma_iter += 1
    #     opt_sigma = sess.run(sigma)
    #     try:
    #         mmd2, that_np = sess.run(mix_rbf_mmd2_and_ratio(eval_test_real, eval_test_sample, biased=False, sigmas=sigma))
    #     except ValueError:
    #         mmd2 = 'NA'
    #         that = 'NA'
    #
    #     MMD[epoch, ] = mmd2

    # -- save model parameters -- #
    model.dump_parameters(sub_id + '_' + str(seq_length) + '_' + str(epoch), sess)

np.save('./experiments/plots/gs/' + identifier + '_' + 'MMD.npy', MMD)

end = time() - begin
print('Training terminated | training time = %ds' % end)
@@ -0,0 +1,21 @@
### from https://github.com/eugenium/MMD/blob/master/tf_ops.py
import tensorflow as tf


def sq_sum(t, name=None):
    """The squared Frobenius-type norm of a tensor, sum(t ** 2)."""
    with tf.name_scope(name, "SqSum", [t]):
        t = tf.convert_to_tensor(t, name='t')
        return 2 * tf.nn.l2_loss(t)


def dot(x, y, name=None):
    """The dot product of two vectors x and y."""
    with tf.name_scope(name, "Dot", [x, y]):
        x = tf.convert_to_tensor(x, name='x')
        y = tf.convert_to_tensor(y, name='y')

        x.get_shape().assert_has_rank(1)
        y.get_shape().assert_has_rank(1)

        return tf.squeeze(tf.matmul(tf.expand_dims(x, 0), tf.expand_dims(y, 1)))
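For reference, a minimal sketch (not part of the commit) of what the two helpers compute, assuming TensorFlow 1.x and that sq_sum and dot are in scope:

import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
y = tf.constant([4.0, 5.0, 6.0])
with tf.Session() as sess:
    print(sess.run(sq_sum(x)))  # 14.0 = 1 + 4 + 9
    print(sess.run(dot(x, y)))  # 32.0 = 1*4 + 2*5 + 3*6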
@@ -0,0 +1,108 @@
#!/usr/bin/env ipython
# Utility functions that don't fit in other scripts
import argparse
import json


def rgan_options_parser():
    """
    Define a parser to parse options from the command line, with defaults.
    Refer to this function for definitions of the various variables.
    """
    parser = argparse.ArgumentParser(description='Train a GAN to generate sequential, real-valued data.')
    # meta-option
    parser.add_argument('--settings_file', help='json file of settings, overrides everything else', type=str, default='')
    # options pertaining to data
    parser.add_argument('--data', help='what kind of data to train with?',
                        default='gp_rbf',
                        choices=['gp_rbf', 'sine', 'mnist', 'load'])
    # parser.add_argument('--num_samples', type=int, help='how many training examples \
    #                     to generate?', default=28*5*100)
    # parser.add_argument('--num_samples_t', type=int, help='how many testing examples \
    #                     for anomaly detection?', default=28 * 5 * 100)
    parser.add_argument('--seq_length', type=int, default=30)
    parser.add_argument('--num_signals', type=int, default=1)
    # NB: argparse's type=bool treats any non-empty string (including 'False')
    # as True; the boolean flags below behave as intended only via their
    # defaults or a settings file.
    parser.add_argument('--normalise', type=bool, default=False, help='normalise the \
                        training/vali/test data (during split)?')
    # parser.add_argument('--AD', type=bool, default=False, help='should we conduct anomaly detection?')

    ### for gp_rbf
    parser.add_argument('--scale', type=float, default=0.1)
    ### for sine (should be using subparsers for this...)
    parser.add_argument('--freq_low', type=float, default=1.0)
    parser.add_argument('--freq_high', type=float, default=5.0)
    parser.add_argument('--amplitude_low', type=float, default=0.1)
    parser.add_argument('--amplitude_high', type=float, default=0.9)
    ### for mnist
    parser.add_argument('--multivariate_mnist', type=bool, default=False)
    parser.add_argument('--full_mnist', type=bool, default=False)
    ### for loading
    parser.add_argument('--data_load_from', type=str, default='')
    ### for eICU
    parser.add_argument('--resample_rate_in_min', type=int, default=15)
    # hyperparameters of the model
    parser.add_argument('--hidden_units_g', type=int, default=100)
    parser.add_argument('--hidden_units_d', type=int, default=100)
    parser.add_argument('--hidden_units_e', type=int, default=100)
    parser.add_argument('--kappa', type=float, help='weight between final output \
                        and intermediate steps in discriminator cost (1 = all \
                        intermediate)', default=1)
    parser.add_argument('--latent_dim', type=int, default=5, help='dimensionality \
                        of the latent/noise space')
    parser.add_argument('--weight', type=float, default=0.5, help='weight of score')
    parser.add_argument('--degree', type=int, default=1, help='norm degree')
    parser.add_argument('--batch_mean', type=bool, default=False, help='append the mean \
                        of the batch to all variables for calculating discriminator loss')
    parser.add_argument('--learn_scale', type=bool, default=False, help='make the \
                        "scale" parameter at the output of the generator learnable (else fixed \
                        to 1)')
    # options pertaining to training
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--batch_size', type=int, default=28)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--D_rounds', type=int, default=5, help='number of rounds \
                        of discriminator training')
    parser.add_argument('--G_rounds', type=int, default=1, help='number of rounds \
                        of generator training')
    parser.add_argument('--E_rounds', type=int, default=1, help='number of rounds \
                        of encoder training')
    # parser.add_argument('--use_time', type=bool, default=False, help='enforce \
    #                     latent dimension 0 to correspond to time')
    parser.add_argument('--shuffle', type=bool, default=True)
    parser.add_argument('--eval_mul', type=bool, default=False)
    parser.add_argument('--eval_an', type=bool, default=False)
    parser.add_argument('--eval_single', type=bool, default=False)
    parser.add_argument('--wrong_labels', type=bool, default=False, help='augment \
                        discriminator loss with real examples with wrong (~shuffled, sort of) labels')
    # options pertaining to evaluation and exploration
    parser.add_argument('--identifier', type=str, default='test', help='identifier \
                        string for output files')
    parser.add_argument('--sub_id', type=str, default='test', help='identifier \
                        string for loading parameters')
    # options pertaining to differential privacy
    parser.add_argument('--dp', type=bool, default=False, help='train discriminator \
                        with differentially private SGD?')
    parser.add_argument('--l2norm_bound', type=float, default=1e-5,
                        help='bound on norm of individual gradients for DP training')
    parser.add_argument('--batches_per_lot', type=int, default=1,
                        help='number of batches per lot (for DP)')
    parser.add_argument('--dp_sigma', type=float, default=1e-5,
                        help='sigma for noise added (for DP)')

    return parser

def load_settings_from_file(settings):
    """
    Handle loading settings from a JSON file, filling in missing settings from
    the command line defaults, but otherwise overwriting them.
    """
    settings_path = './experiments/settings/' + settings['settings_file'] + '.txt'
    print('Loading settings from', settings_path)
    settings_loaded = json.load(open(settings_path, 'r'))
    # check for settings missing in the file
    for key in settings.keys():
        if key not in settings_loaded:
            print(key, 'not found in loaded settings - adopting value from command line defaults:', settings[key])
    # overwrite parsed/default settings with those read from file, allowing for
    # (potentially new) default settings not present in the file
    settings.update(settings_loaded)
    return settings
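As a quick sanity check, the parser can be exercised with an explicit argument list instead of sys.argv. A minimal sketch (not part of the commit); the override values are illustrative:

# Parse hypothetical overrides and inspect the resulting settings dict.
parser = rgan_options_parser()
settings = vars(parser.parse_args(['--seq_length', '48', '--batch_size', '64']))
print(settings['seq_length'], settings['batch_size'])  # 48 64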