Skip to content

Commit

Permalink
Update test_small_run.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwang12 authored Nov 12, 2024
1 parent 5f2afd5 commit 3f1bbc6
Showing 1 changed file with 8 additions and 19 deletions.
27 changes: 8 additions & 19 deletions tests/test_small_run.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
import pickle
import random
from pathlib import Path

import click
import pandas as pd

from mNSF import process_multiSample, training_multiSample


def load_data(pth: Path, n_sample: int):
    """Load per-sample X/Y CSV pairs and build mNSF data dicts.

    Reads ``X_sample{k}.csv`` / ``Y_sample{k}.csv`` for k = 1..n_sample
    from *pth*, converts each pair via ``process_multiSample.get_D``,
    and returns ``(D, X)`` where ``X`` is the list of the resulting
    dicts' ``"X"`` entries (not the raw CSV frames).
    """
    sample_ids = range(1, n_sample + 1)
    x_frames = [pd.read_csv(pth / f"X_sample{k}.csv") for k in sample_ids]
    y_frames = [pd.read_csv(pth / f"Y_sample{k}.csv") for k in sample_ids]
    D = [process_multiSample.get_D(x, y) for x, y in zip(x_frames, y_frames)]
    return D, [d["X"] for d in D]


def _run(
def run(
data_dir: str | Path,
output_dir: str | Path,
n_loadings: int = 3,
Expand All @@ -25,15 +21,14 @@ def _run(
legacy: bool = False,
):
output_dir, data_dir = Path(output_dir), Path(data_dir)

# step 0 Data loading
D, X = load_data(data_dir, n_sample)
list_nchunk = [2,2]
listDtrain = process_multiSample.get_listDtrain(D,list_nchunk=list_nchunk)
print("len(listDtrain")
list_nchunk = [2, 2]
listDtrain = process_multiSample.get_listDtrain(D, list_nchunk=list_nchunk)
print("len(listDtrain)")
print(len(listDtrain))

list_D_chunked = process_multiSample.get_listD_chunked(D,list_nchunk=list_nchunk)
list_D_chunked = process_multiSample.get_listD_chunked(D, list_nchunk=list_nchunk)
for ksample in range(0, len(list_D_chunked)):
random.seed(10)
ninduced = round(list_D_chunked[ksample]["X"].shape[0] * 0.35)
Expand All @@ -47,11 +42,10 @@ def _run(
fit = process_multiSample.ini_multiSample(list_D_chunked, n_loadings, "nb", chol=False)

# step 2 fit model

(pp := (output_dir / "models" / "pp")).mkdir(parents=True, exist_ok=True)
fit = training_multiSample.train_model_mNSF(fit, pp, listDtrain, list_D_chunked, legacy=legacy, num_epochs=epochs)
(output_dir / "list_fit_smallData.pkl").write_bytes(pickle.dumps(fit))

# step 3 save results
inpf12 = process_multiSample.interpret_npf_v3(fit, X, list_nchunk, S=2, lda_mode=False)
(
Expand All @@ -60,15 +54,12 @@ def _run(
columns=range(1, n_loadings + 1),
).to_csv(output_dir / "loadings_spde_smallData.csv")
)

factors = inpf12["factors"][:, :n_loadings]
for k in range(n_sample):
indices = process_multiSample.get_listSampleID(D)[k].astype(int)
pd.DataFrame(factors[indices, :]).to_csv(output_dir / f"factors_sample{k + 1:02d}_smallData.csv")

print("Done!")


@click.command()
@click.argument("data_dir", type=click.Path(exists=True, dir_okay=True, file_okay=False))
@click.argument("output_dir", type=click.Path(exists=True, dir_okay=True, file_okay=False))
Expand All @@ -84,12 +75,10 @@ def run_cli(
epochs: int = 10,
legacy: bool = True,
):
return _run(data_dir, output_dir, n_loadings, n_sample, epochs, legacy)

return run(data_dir, output_dir, n_loadings, n_sample, epochs, legacy)

def test_small_run():
_run("tests/data", ".", n_loadings=1, n_sample=2, legacy=True)

run("tests/data", ".", n_loadings=1, n_sample=2, legacy=True)

if __name__ == "__main__":
run_cli()

0 comments on commit 3f1bbc6

Please sign in to comment.