Skip to content

Commit

Permalink
add scaling as an option
Browse files Browse the repository at this point in the history
  • Loading branch information
62442katieb committed Nov 27, 2023
1 parent 8036f5e commit f274dec
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions idconn/nbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def pynbs(matrices, outcome, num_node=None, diagonal=False, alpha=0.05, predict=


def kfold_nbs(
matrices, outcome, confounds=None, alpha=0.05, groups=None, num_node=None, diagonal=False, n_splits=10, n_iterations=10
matrices, outcome, confounds=None, alpha=0.05, groups=None, num_node=None, diagonal=False, scale_x=False, scale_y=False, n_splits=10, n_iterations=10
):
"""Calculates the Network Based Statistic (Zalesky et al., 20##) on connectivity matrices provided
of shape ((subject x session)x node x node)
Expand Down Expand Up @@ -333,8 +333,7 @@ def kfold_nbs(
manager = enlighten.get_manager()
ticks = manager.counter(total=n_splits * n_iterations, desc="Progress", unit="folds")
for train_idx, test_idx in cv.split(edges, split_y):
x_scaler = Normalizer()
y_scaler = Normalizer()

cv_results.at[i, "split"] = (train_idx, test_idx)

# assert len(train_a_idx) == len(train_b_idx)
Expand Down Expand Up @@ -382,15 +381,21 @@ def kfold_nbs(
test_y, test_edges = residualize(X=test_edges, y=test_y, confounds=test_confounds)
else:
pass
if scale_x:
x_scaler = Normalizer()
train_edges = x_scaler.fit_transform(train_edges)
test_edges = x_scaler.transform(test_edges)
if scale_y:
if np.unique(outcome).shape[0] == 2:
pass
else:
y_scaler = Normalizer()
train_y = y_scaler.fit_transform(train_y.reshape(-1, 1))
test_y = y_scaler.transform(test_y.reshape(-1, 1))



train_edges = x_scaler.fit_transform(train_edges)
test_edges = x_scaler.transform(test_edges)

if np.unique(outcome).shape[0] == 2:
pass
else:
train_y = y_scaler.fit_transform(train_y.reshape(-1, 1))
test_y = y_scaler.transform(test_y.reshape(-1, 1))


# perform NBS wooooooooo
# note: output is a dataframe :)
Expand Down

0 comments on commit f274dec

Please sign in to comment.