Skip to content

Commit

Permalink
Update test_small_run.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwang12 authored Nov 12, 2024
1 parent 30cb4e7 commit 8c65dec
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tests/test_small_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,22 @@ def _run(

# step 0 Data loading
D, X = load_data(data_dir, n_sample)
ninduced
listDtrain = process_multiSample.get_listDtrain(D)

for ksample in range(0, len(D)):
list_D_chunked = process_multiSample.get_listD_chunked(D)
for ksample in range(0, len(list_D_chunked)):
random.seed(10)
ninduced = round(D[ksample]["X"].shape[0] * 0.35)
ninduced = round(list_D_chunked[ksample]["X"].shape[0] * 0.35)
D_tmp = D[ksample]
D[ksample]["Z"] = D_tmp["X"][random.sample(range(0, D_tmp["X"].shape[0] - 1), ninduced), :]

# step 1 initialize model
fit = process_multiSample.ini_multiSample(D, n_loadings, "nb", chol=False)

fit = process_multiSample.ini_multiSample(list_D_chunked, n_loadings, "nb", chol=False)
# step 2 fit model

(pp := (output_dir / "models" / "pp")).mkdir(parents=True, exist_ok=True)
fit = training_multiSample.train_model_mNSF(fit, pp, listDtrain, D, legacy=legacy, num_epochs=epochs)
fit = training_multiSample.train_model_mNSF(fit, pp, listDtrain, list_D_chunked, legacy=legacy, num_epochs=epochs)
(output_dir / "list_fit_smallData.pkl").write_bytes(pickle.dumps(fit))

# step 3 save results
Expand Down

0 comments on commit 8c65dec

Please sign in to comment.