From f274decfd10bc8b13eaaf32aa0cd70148fcd7a7f Mon Sep 17 00:00:00 2001 From: "Katherine L. Bottenhorn" Date: Mon, 27 Nov 2023 15:43:31 -0800 Subject: [PATCH] add scaling as an option --- idconn/nbs.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/idconn/nbs.py b/idconn/nbs.py index 3e2b48f..26ed551 100644 --- a/idconn/nbs.py +++ b/idconn/nbs.py @@ -233,7 +233,7 @@ def pynbs(matrices, outcome, num_node=None, diagonal=False, alpha=0.05, predict= def kfold_nbs( - matrices, outcome, confounds=None, alpha=0.05, groups=None, num_node=None, diagonal=False, n_splits=10, n_iterations=10 + matrices, outcome, confounds=None, alpha=0.05, groups=None, num_node=None, diagonal=False, scale_x=False, scale_y=False, n_splits=10, n_iterations=10 ): """Calculates the Network Based Statistic (Zalesky et al., 20##) on connectivity matrices provided of shape ((subject x session)x node x node) @@ -333,8 +333,7 @@ def kfold_nbs( manager = enlighten.get_manager() ticks = manager.counter(total=n_splits * n_iterations, desc="Progress", unit="folds") for train_idx, test_idx in cv.split(edges, split_y): - x_scaler = Normalizer() - y_scaler = Normalizer() + cv_results.at[i, "split"] = (train_idx, test_idx) # assert len(train_a_idx) == len(train_b_idx) @@ -382,15 +381,21 @@ def kfold_nbs( test_y, test_edges = residualize(X=test_edges, y=test_y, confounds=test_confounds) else: pass + if scale_x: + x_scaler = Normalizer() + train_edges = x_scaler.fit_transform(train_edges) + test_edges = x_scaler.transform(test_edges) + if scale_y: + if np.unique(outcome).shape[0] == 2: + pass + else: + y_scaler = Normalizer() + train_y = y_scaler.fit_transform(train_y.reshape(-1, 1)) + test_y = y_scaler.transform(test_y.reshape(-1, 1)) + + - train_edges = x_scaler.fit_transform(train_edges) - test_edges = x_scaler.transform(test_edges) - - if np.unique(outcome).shape[0] == 2: - pass - else: - train_y = y_scaler.fit_transform(train_y.reshape(-1, 1)) - test_y = y_scaler.transform(test_y.reshape(-1, 1)) + # perform NBS wooooooooo # note: output is a dataframe :)