diff --git a/mNSF/training_multiSample.py b/mNSF/training_multiSample.py
index e4b7192..b0ead65 100644
--- a/mNSF/training_multiSample.py
+++ b/mNSF/training_multiSample.py
@@ -314,92 +314,92 @@ def _train_model_fixed_lr(self, list_tro,list_Dtrain, list_D__, ckpt_mgr,Dval=No
                             ptic = process_time(), wtic = time(), ckpt_freq=50, test_cvdNorm=False,
                             kernel_hp_update_freq=10, status_freq=10,
                             span=100, tol=1e-4, tol_norm = 0.4, pickle_freq=None, check_convergence: bool = True):
-    """
-    train_step
-    Dtrain, Dval : tensorflow Datasets produced by prepare_datasets_tf func
-    ckpt_mgr must store at least 2 checkpoints (max_to_keep)
-    Ntr: total number of training observations, needed to adjust KL term in ELBO
-    S: number of samples to approximate the ELBO
-    verbose: should status updates be printed
-    num_epochs: maximum passes through the data after which optimization will be stopped
-    ptic,wtic: process and wall time baselines
-    kernel_hp_update_freq: how often to update the kernel hyperparameters (eg every 10 epochs)
-      updating less than once per epoch improves speed but reduces numerical stability
-    status_freq: how often to check for convergence and print updates
-    ckpt_freq: how often to save tensorflow checkpoints to disk
-    span: when checking for convergence, how many recent observations to consider
-    tol: numerical (relative) change below which convergence is declared
-    pickle_freq: how often to save the entire object to disk as a pickle file
-    """
-    ptic,wtic = self.checkpoint(ckpt_mgr, process_time()-ptic, time()-wtic)
-    self.loss["train"] = rpad(self.loss["train"],num_epochs+1)
-    if pickle_freq is None: #only pickle at the end
-      pickle_freq = num_epochs
-    msg = '{:04d} train: {:.3e}'
-    if Dval:
-      msg += ', val: {:.3e}'
-      self.loss["val"] = rpad(self.loss["val"],num_epochs+1)
-    msg2 = "" #modified later to include rel_chg
-    cvg = 0 #increment each time we think it has converged
-    cvg_normalized=0
-    cc = ConvergenceChecker(span)
-    while (not self.converged) and (self.epoch < num_epochs):
-      epoch_loss = tf.keras.metrics.Mean()
-      chol=(self.epoch % kernel_hp_update_freq==0)
-      trl=0.0
-      nsample=len(list_Dtrain)
-      for ksample in range(0,nsample):
-        list_tro[ksample].model.Z=list_D__[ksample]["Z"]
-        Dtrain_ksample = list_Dtrain[ksample]
-        for D in Dtrain_ksample: #iterate through each of the batches
-          epoch_loss.update_state(list_tro[ksample].model.train_step( D, list_tro[ksample].optimizer, list_tro[ksample].optimizer_k,
-                                  Ntot=list_tro[ksample].model.delta.shape[1], chol=True))
-      trl = trl + epoch_loss.result().numpy()
-      W_updated=list_tro[ksample].model.W-list_tro[ksample].model.W
-      for ksample in range(0,nsample):
-        W_updated = W_updated+ (list_tro[ksample].model.W / nsample)
-      self.epoch.assign_add(1)
-      i = self.epoch.numpy()
-      self.loss["train"][i] = trl
-
-      ## check for nan in any sample loadings
-      for fit_i in list_tro:
-        if np.isnan(fit_i.model.W).any():
-          print('NaN in sample ' + str(list_tro.index(fit_i) + 1))
-
-      if not np.isfinite(trl): ### modified
-        print("training loss calculated at the point of divergence: ")
-        print(trl)
-        raise NumericalDivergenceError###!!!NumericalDivergenceError
-      #if not np.isfinite(trl) or trl>self.loss["train"][1]: ### modified
-      # raise NumericalDivergenceError###!!!NumericalDivergenceError
-      if i%status_freq==0 or i==num_epochs:
-        if Dval:
-          val_loss = self.model.validation_step(Dval, S=S, chol=False).numpy()
-          self.loss["val"][i] = val_loss
-        if i>span and check_convergence: #checking for convergence
-          rel_chg = cc.relative_change(self.loss["train"],idx=i)
-          print("rel_chg")
-          print(rel_chg)
-          msg2 = ", chg: {:.2e}".format(-rel_chg)
-          if abs(rel_chg)<tol: cvg+=1
-          else: cvg=0
-          if test_cvdNorm:
-            rel_chg_normalized = cc.relative_chg_normalized(self.loss["train"], idx_current=i)
-            print("rel_chg_normalized")
-            print(rel_chg_normalized)
-            if (-rel_chg_normalized)<tol_norm: cvg_normalized+=1
-          if cvg>=2 or cvg_normalized>=2: #i.e. either convergence or normalized convergence has been detected twice in a row
-            self.converged=True
-            pickle_freq = i #ensures final pickling will happen
-            self.loss = truncate_history(self.loss, i)
-        if verbose:
-          if Dval: print(msg.format(i,trl,val_loss)+msg2)
-          else: print(msg.format(i,trl)+msg2)
-      if i%ckpt_freq==0:
+        """
+        train_step
+        Dtrain, Dval : tensorflow Datasets produced by prepare_datasets_tf func
+        ckpt_mgr must store at least 2 checkpoints (max_to_keep)
+        Ntr: total number of training observations, needed to adjust KL term in ELBO
+        S: number of samples to approximate the ELBO
+        verbose: should status updates be printed
+        num_epochs: maximum passes through the data after which optimization will be stopped
+        ptic,wtic: process and wall time baselines
+        kernel_hp_update_freq: how often to update the kernel hyperparameters (eg every 10 epochs)
+          updating less than once per epoch improves speed but reduces numerical stability
+        status_freq: how often to check for convergence and print updates
+        ckpt_freq: how often to save tensorflow checkpoints to disk
+        span: when checking for convergence, how many recent observations to consider
+        tol: numerical (relative) change below which convergence is declared
+        pickle_freq: how often to save the entire object to disk as a pickle file
+        """
         ptic,wtic = self.checkpoint(ckpt_mgr, process_time()-ptic, time()-wtic)
-      if self.pickle_path and i%pickle_freq==0:
-        ptic,wtic = self.pickle(process_time()-ptic, time()-wtic)
+        self.loss["train"] = rpad(self.loss["train"],num_epochs+1)
+        if pickle_freq is None: #only pickle at the end
+          pickle_freq = num_epochs
+        msg = '{:04d} train: {:.3e}'
+        if Dval:
+          msg += ', val: {:.3e}'
+          self.loss["val"] = rpad(self.loss["val"],num_epochs+1)
+        msg2 = "" #modified later to include rel_chg
+        cvg = 0 #increment each time we think it has converged
+        cvg_normalized=0
+        cc = ConvergenceChecker(span)
+        while (not self.converged) and (self.epoch < num_epochs):
+          epoch_loss = tf.keras.metrics.Mean()
+          chol=(self.epoch % kernel_hp_update_freq==0)
+          trl=0.0
+          nsample=len(list_Dtrain)
+          for ksample in range(0,nsample):
+            list_tro[ksample].model.Z=list_D__[ksample]["Z"]
+            Dtrain_ksample = list_Dtrain[ksample]
+            for D in Dtrain_ksample: #iterate through each of the batches
+              epoch_loss.update_state(list_tro[ksample].model.train_step( D, list_tro[ksample].optimizer, list_tro[ksample].optimizer_k,
+                                      Ntot=list_tro[ksample].model.delta.shape[1], chol=True))
+          trl = trl + epoch_loss.result().numpy()
+          W_updated=list_tro[ksample].model.W-list_tro[ksample].model.W
+          for ksample in range(0,nsample):
+            W_updated = W_updated+ (list_tro[ksample].model.W / nsample)
+          self.epoch.assign_add(1)
+          i = self.epoch.numpy()
+          self.loss["train"][i] = trl
+
+          ## check for nan in any sample loadings
+          for fit_i in list_tro:
+            if np.isnan(fit_i.model.W).any():
+              print('NaN in sample ' + str(list_tro.index(fit_i) + 1))
+
+          if not np.isfinite(trl): ### modified
+            print("training loss calculated at the point of divergence: ")
+            print(trl)
+            raise NumericalDivergenceError###!!!NumericalDivergenceError
+          #if not np.isfinite(trl) or trl>self.loss["train"][1]: ### modified
+          # raise NumericalDivergenceError###!!!NumericalDivergenceError
+          if i%status_freq==0 or i==num_epochs:
+            if Dval:
+              val_loss = self.model.validation_step(Dval, S=S, chol=False).numpy()
+              self.loss["val"][i] = val_loss
+            if i>span and check_convergence: #checking for convergence
+              rel_chg = cc.relative_change(self.loss["train"],idx=i)
+              print("rel_chg")
+              print(rel_chg)
+              msg2 = ", chg: {:.2e}".format(-rel_chg)
+              if abs(rel_chg)<tol: cvg+=1
+              else: cvg=0
+              if test_cvdNorm:
+                rel_chg_normalized = cc.relative_chg_normalized(self.loss["train"], idx_current=i)
+                print("rel_chg_normalized")
+                print(rel_chg_normalized)
+                if (-rel_chg_normalized)<tol_norm: cvg_normalized+=1
+              if cvg>=2 or cvg_normalized>=2: #i.e. either convergence or normalized convergence has been detected twice in a row
+                self.converged=True
+                pickle_freq = i #ensures final pickling will happen
+                self.loss = truncate_history(self.loss, i)
+            if verbose:
+              if Dval: print(msg.format(i,trl,val_loss)+msg2)
+              else: print(msg.format(i,trl)+msg2)
+          if i%ckpt_freq==0:
+            ptic,wtic = self.checkpoint(ckpt_mgr, process_time()-ptic, time()-wtic)
+          if self.pickle_path and i%pickle_freq==0:
+            ptic,wtic = self.pickle(process_time()-ptic, time()-wtic)
 
   def find_checkpoint(self, ckpt_freq, back=1, epoch0=0):
     """
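
Note on the convergence rule in this hunk: convergence is declared only after the relative change of the training loss stays below tol at two consecutive status checks (cvg>=2, or cvg_normalized>=2 when test_cvdNorm is enabled). The sketch below restates that rule outside of TensorFlow. It is illustrative only: relative_change here is a simplified, hypothetical stand-in for ConvergenceChecker.relative_change (the real implementation lives elsewhere in mNSF), and the loss curve is synthetic.

# Illustrative sketch of the convergence rule above -- not the mNSF implementation.
# `relative_change` is a hypothetical stand-in for ConvergenceChecker.relative_change:
# it compares the mean loss over the last `span` epochs to the mean over the
# preceding `span` epochs.
import numpy as np

def relative_change(loss, idx, span=100):
    recent = np.mean(loss[idx - span + 1 : idx + 1])
    previous = np.mean(loss[idx - 2 * span + 1 : idx - span + 1])
    return (recent - previous) / abs(previous)

def has_converged(loss, span=100, status_freq=10, tol=1e-4):
    # Mirrors the loop logic: check every `status_freq` epochs once enough
    # history has accumulated, and declare convergence after two consecutive
    # sub-tolerance checks (cvg >= 2).
    cvg = 0
    for i in range(2 * span, len(loss), status_freq):
        rel_chg = relative_change(loss, i, span)
        cvg = cvg + 1 if abs(rel_chg) < tol else 0
        if cvg >= 2:
            return True
    return False

# Synthetic loss curve that decays and then plateaus.
loss = 2.0 * np.exp(-0.05 * np.arange(1000)) + 1.0
print(has_converged(loss))  # True once the plateau is long enough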