Commit: Update training_multiSample.py
yiwang12 authored Nov 1, 2024
1 parent 3d79d0d commit c8a8487
Showing 1 changed file with 24 additions and 12 deletions.
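In brief: this commit adds an optional vec_batch argument to _train_model_fixed_lr. When vec_batch is None, the per-sample training loop behaves as before; when vec_batch[ksample] is truthy, sample ksample's model first copies beta0, beta, W, amplitude, length_scale, and scale_diag from sample ksample-1's model before its own batches are trained, i.e. a warm start from the preceding sample. A hedged sketch of a call follows; the trainer object (here named tro_multiSample) and the list_tro/list_Dtrain/list_D__/ckpt_mgr inputs come from the usual mNSF setup and are assumed, not shown:

    # Hypothetical invocation; construction of tro_multiSample and its inputs elided.
    vec_batch = [False, True, True]  # samples 1 and 2 warm-start from the model before them
    tro_multiSample._train_model_fixed_lr(list_tro, list_Dtrain, list_D__, ckpt_mgr,
                                          num_epochs=500, vec_batch=vec_batch)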
mNSF/training_multiSample.py (24 additions & 12 deletions)
@@ -390,7 +390,7 @@ def _train_model_fixed_lr(self, list_tro,list_Dtrain, list_D__, ckpt_mgr,Dval=No
           verbose=True,num_epochs=500,
           ptic = process_time(), wtic = time(), ckpt_freq=50, test_cvdNorm=False,
           kernel_hp_update_freq=10, status_freq=10, chol=True,
-          span=100, tol=1e-4, tol_norm = 0.4, pickle_freq=None, check_convergence: bool = True):
+          span=100, tol=1e-4, tol_norm = 0.4, pickle_freq=None, check_convergence: bool = True, vec_batch = None):
     """train_step
     Dtrain, Dval : tensorflow Datasets produced by prepare_datasets_tf func
     ckpt_mgr must store at least 2 checkpoints (max_to_keep)
@@ -427,18 +427,30 @@ def _train_model_fixed_lr(self, list_tro,list_Dtrain, list_D__, ckpt_mgr,Dval=No
       #chol=(self.epoch % kernel_hp_update_freq==0)
       trl=0.0
       nsample=len(list_Dtrain)
-      for ksample in range(0,nsample):
-        list_tro[ksample].model.Z=list_D__[ksample]["Z"]
-        Dtrain_ksample = list_Dtrain[ksample]
-        for D in Dtrain_ksample: #iterate through each of the batches
-          epoch_loss.update_state(list_tro[ksample].model.train_step( D, list_tro[ksample].optimizer, list_tro[ksample].optimizer_k,
+      if vec_batch is None:
+        for ksample in range(0,nsample):
+          list_tro[ksample].model.Z=list_D__[ksample]["Z"]
+          Dtrain_ksample = list_Dtrain[ksample]
+          for D in Dtrain_ksample: #iterate through each of the batches
+            epoch_loss.update_state(list_tro[ksample].model.train_step( D, list_tro[ksample].optimizer, list_tro[ksample].optimizer_k,
                                 Ntot=list_tro[ksample].model.delta.shape[1], chol=chol))
+          trl = trl + epoch_loss.result().numpy()
+      else:
+        for ksample in range(0,nsample):
+          list_tro[ksample].model.Z=list_D__[ksample]["Z"]
+          Dtrain_ksample = list_Dtrain[ksample]
+          if vec_batch[ksample]:
+            list_tro[ksample].model.beta0.assign(list_tro[ksample-1].model.beta0)
+            list_tro[ksample].model.beta.assign(list_tro[ksample-1].model.beta)
+            list_tro[ksample].model.W.assign(list_tro[ksample-1].model.W.numpy())
+            list_tro[ksample].model.amplitude.assign(list_tro[ksample-1].model.amplitude())
+            list_tro[ksample].model.length_scale.assign(list_tro[ksample-1].model.length_scale())
+            list_tro[ksample].model.scale_diag.assign(list_tro[ksample-1].model.scale_diag())
+            #Loadings weights
+          for D in Dtrain_ksample: #iterate through each of the batches
+            epoch_loss.update_state(list_tro[ksample].model.train_step( D, list_tro[ksample].optimizer, list_tro[ksample].optimizer_k,
+                                Ntot=list_tro[ksample].model.delta.shape[1], chol=chol))
+          trl = trl + epoch_loss.result().numpy()
-        #print("ksample")
-        #print(ksample)
-        #print(D["X"].shape)
-        #print(list_tro[ksample].model.delta.shape[1])
-        #print(tf.config.experimental.get_memory_info('GPU:0'))
-        trl = trl + epoch_loss.result().numpy()
       W_updated=list_tro[ksample].model.W-list_tro[ksample].model.W
       #print(trl)
       for ksample in range(0,nsample):
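The new code path copies variables from the previous sample's model with tf.Variable.assign before that sample's batches run. Below is a minimal, self-contained sketch of this warm-start pattern, with toy tf.Variable models standing in for the real mNSF process models; ToyModel, its shapes, and the vec_batch values are illustrative assumptions, not part of the repository:

    import tensorflow as tf

    class ToyModel:
        """Stand-in for an mNSF process model; only a few variables."""
        def __init__(self):
            self.W = tf.Variable(tf.random.normal([4, 2]))   # loadings
            self.beta0 = tf.Variable(tf.zeros([2]))          # intercepts
            self.beta = tf.Variable(tf.zeros([2]))           # coefficients

    models = [ToyModel() for _ in range(3)]
    # vec_batch[k] truthy => sample k inherits sample k-1's current parameters
    vec_batch = [False, True, True]

    for k in range(len(models)):
        if vec_batch[k]:
            # Same .assign() pattern as the commit: copy values in place rather
            # than rebinding attributes, so each optimizer keeps tracking the
            # same tf.Variable objects.
            models[k].W.assign(models[k - 1].W)
            models[k].beta0.assign(models[k - 1].beta0)
            models[k].beta.assign(models[k - 1].beta)
        # ... per-sample train_step calls over that sample's batches go here ...

One design note: as diffed, the assignments sit inside the per-epoch training loop, so a flagged sample appears to be re-synchronized to its predecessor at every epoch, not just once at initialization.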
