Skip to content

Commit

Permalink
Update training_multiSample.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwang12 authored Nov 14, 2024
1 parent 5d75626 commit 64b78f4
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions mNSF/training_multiSample.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,9 @@ def _train_model_fixed_lr(self, list_tro,list_Dtrain, list_D__, ckpt_mgr,Dval=No
#epoch=self.epoch
#chol=(self.epoch % kernel_hp_update_freq==0)
trl=0.0
nsample=len(list_Dtrain)
nsample_chunked=len(list_Dtrain)
if list_nchunk is None:
for ksample in range(0,nsample):
for ksample in range(0,nsample_chunked):
list_tro[ksample].model.Z=list_D__[ksample]["Z"]
Dtrain_ksample = list_Dtrain[ksample]
for D in Dtrain_ksample: #iterate through each of the batches
Expand All @@ -437,9 +437,12 @@ def _train_model_fixed_lr(self, list_tro,list_Dtrain, list_D__, ckpt_mgr,Dval=No
trl = trl + epoch_loss.result().numpy()
else:
vec_batch = list()
nsample = len(list_nchunk)
for ksample in range(0,nsample):
vec_batch.append([False]*1 + [True]*(list_nchunk[ksample]-1))
for ksample in range(0,nsample):
print("vec_batch")
print(vec_batch)
for ksample in range(0,nsample_chunked):
list_tro[ksample].model.Z=list_D__[ksample]["Z"]
Dtrain_ksample = list_Dtrain[ksample]
if vec_batch[ksample]:
Expand All @@ -454,8 +457,8 @@ def _train_model_fixed_lr(self, list_tro,list_Dtrain, list_D__, ckpt_mgr,Dval=No
trl = trl + epoch_loss.result().numpy()
W_updated=list_tro[ksample].model.W-list_tro[ksample].model.W
#print(trl)
for ksample in range(0,nsample):
W_updated = W_updated+ (list_tro[ksample].model.W / nsample)
for ksample in range(0,nsample_chunked):
W_updated = W_updated+ (list_tro[ksample].model.W / nsample_chunked)
self.epoch.assign_add(1)
i = self.epoch.numpy()
self.loss["train"][i] = trl
Expand Down

0 comments on commit 64b78f4

Please sign in to comment.