
Commit

Add ref_batch parameter to neuroHarmonize to use SITE/scanner as reference for batch adjustments
melhemr committed Apr 21, 2022
1 parent 1f5155e commit bf0d224
Showing 3 changed files with 97 additions and 19 deletions.
24 changes: 20 additions & 4 deletions neuroHarmonize/harmonizationApply.py
@@ -2,7 +2,6 @@
import pickle
import numpy as np
import pandas as pd
import nibabel as nib
from statsmodels.gam.api import BSplines
from .neuroCombat import make_design_matrix, adjust_data_final

@@ -24,6 +23,9 @@ def harmonizationApply(data, covars, model,return_stand_mean=False):
model : a dictionary of model parameters
the output of a call to harmonizationLearn()
ref_batch : batch (site or scanner) to be used as the reference for batch adjustment; read from the model returned by harmonizationLearn()
- None by default
Returns
-------
@@ -39,6 +41,19 @@ def harmonizationApply(data, covars, model,return_stand_mean=False):
cat_cols = []
num_cols = [covars.columns.get_loc(c) for c in covars.columns if c!='SITE']
covars = np.array(covars, dtype='object')

if (not ('ref_batch' in model)) or (model['ref_batch'] is None):
ref_level=None
else:
ref_indices = np.argwhere((covars[:,batch_col]==model['ref_batch']).squeeze())
if ref_indices.shape[0]==0:
ref_level=None
print('[neuroCombat] batch.ref not found. Setting to None.')
covars[:,batch_col]=np.unique(covars[:,batch_col],return_inverse=True)[-1]
else:
covars[:,batch_col] = np.unique(covars[:,batch_col],return_inverse=True)[-1]
ref_level = int(covars[ref_indices[0],batch_col])

# load the smoothing model
smooth_model = model['smooth_model']
smooth_cols = smooth_model['smooth_cols']
@@ -59,14 +74,15 @@ def harmonizationApply(data, covars, model,return_stand_mean=False):
'n_batch': len(batch_levels),
'n_sample': int(covars.shape[0]),
'sample_per_batch': sample_per_batch.astype('int'),
'batch_info': [list(np.where(covars[:,batch_col]==idx)[0]) for idx in batch_levels]
'batch_info': [list(np.where(covars[:,batch_col]==idx)[0]) for idx in batch_levels],
'ref_level': ref_level
}
covars[~isTrainSite, batch_col] = 0
covars[:,batch_col] = covars[:,batch_col].astype(int)
###
# isolate array of data in training site
# apply ComBat without re-learning model parameters
design = make_design_matrix(covars, batch_col, cat_cols, num_cols,nb_class = len(model['SITE_labels']))
design = make_design_matrix(covars, batch_col, cat_cols, num_cols,ref_level,nb_class = len(model['SITE_labels']))
design[~isTrainSite,0:len(model['SITE_labels'])] = np.nan
### additional setup if smoothing is performed
if smooth_model['perform_smoothing']:
@@ -94,7 +110,7 @@ def harmonizationApply(data, covars, model,return_stand_mean=False):
bayes_data = np.full(s_data.shape,np.nan)
else:
bayes_data = adjust_data_final(s_data, design, model['gamma_star'], model['delta_star'],
stand_mean, var_pooled, info_dict)
stand_mean, var_pooled, info_dict, data)
bayes_data[:,~isTrainSite] = np.nan

# transpose data to return to original shape
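The net effect of the changes in this file: harmonizationApply takes no new argument; it reads the reference batch from the learned model and threads it through make_design_matrix and adjust_data_final. A minimal end-to-end sketch of the intended call pattern — the site labels ('SITE_A', 'SITE_B') and array shapes are hypothetical, not part of this commit:

```python
import numpy as np
import pandas as pd
from neuroHarmonize import harmonizationLearn, harmonizationApply

rng = np.random.default_rng(0)

# training set: 200 subjects x 50 features across two sites
train_data = rng.normal(size=(200, 50))
train_covars = pd.DataFrame({
    'SITE': ['SITE_A'] * 120 + ['SITE_B'] * 80,
    'AGE': rng.uniform(18, 80, 200),
})

# learn ComBat parameters, holding SITE_A fixed as the reference batch
model, train_adj = harmonizationLearn(train_data, train_covars,
                                      ref_batch='SITE_A')

# held-out subjects; sites must have appeared in training
test_data = rng.normal(size=(60, 50))
test_covars = pd.DataFrame({
    'SITE': ['SITE_A'] * 25 + ['SITE_B'] * 35,
    'AGE': rng.uniform(18, 80, 60),
})

# no extra argument here: the reference level travels in model['ref_batch']
test_adj = harmonizationApply(test_data, test_covars, model)
```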
28 changes: 23 additions & 5 deletions neuroHarmonize/harmonizationLearn.py
@@ -6,7 +6,7 @@
from .neuroCombat import make_design_matrix, find_parametric_adjustments, adjust_data_final, aprior, bprior

def harmonizationLearn(data, covars, eb=True, smooth_terms=[],
smooth_term_bounds=(None, None), return_s_data=False):
smooth_term_bounds=(None, None),ref_batch=None, return_s_data=False):
"""
Wrapper for neuroCombat function that returns the harmonization model.
@@ -36,6 +36,9 @@ def harmonizationLearn(data, covars, eb=True, smooth_terms=[],
useful when holdout data covers a different range than the training data
specify the bounds as (minimum, maximum)
currently not supported for models with multiple smooth terms
ref_batch (Optional) : batch (site or scanner) to be used as the reference for batch adjustment
- None by default
return_s_data (Optional) : bool, default False
whether to return s_data, the standardized data array
@@ -75,6 +78,20 @@ def harmonizationLearn(data, covars, eb=True, smooth_terms=[],
'df_gam': None
}
covars = np.array(covars, dtype='object')

if ref_batch is None:
ref_level=None
else:
ref_indices = np.argwhere((covars[:,batch_col]==ref_batch).squeeze())
if ref_indices.shape[0]==0:
ref_level=None
ref_batch=None
print('[neuroCombat] batch.ref not found. Setting to None.')
covars[:,batch_col]=np.unique(covars[:,batch_col],return_inverse=True)[-1]
else:
covars[:,batch_col] = np.unique(covars[:,batch_col],return_inverse=True)[-1]
ref_level = int(covars[ref_indices[0],batch_col])

### additional setup code from neuroCombat implementation:
# convert batch col to integer
covars[:,batch_col] = np.unique(covars[:,batch_col],return_inverse=True)[-1]
@@ -85,10 +102,11 @@ def harmonizationLearn(data, covars, eb=True, smooth_terms=[],
'n_batch': len(batch_levels),
'n_sample': int(covars.shape[0]),
'sample_per_batch': sample_per_batch.astype('int'),
'batch_info': [list(np.where(covars[:,batch_col]==idx)[0]) for idx in batch_levels]
'batch_info': [list(np.where(covars[:,batch_col]==idx)[0]) for idx in batch_levels],
'ref_level': ref_level
}
###
design = make_design_matrix(covars, batch_col, cat_cols, num_cols)
design = make_design_matrix(covars, batch_col, cat_cols, num_cols,ref_level)
### additional setup if smoothing is performed
if smooth_model['perform_smoothing']:
# create cubic spline basis for smooth terms
@@ -124,15 +142,15 @@ def harmonizationLearn(data, covars, eb=True, smooth_terms=[],
else:
gamma_star = LS_dict['gamma_hat']
delta_star = np.array(LS_dict['delta_hat'])
bayes_data = adjust_data_final(s_data, design, gamma_star, delta_star, stand_mean, var_pooled, info_dict)
bayes_data = adjust_data_final(s_data, design, gamma_star, delta_star, stand_mean, var_pooled, info_dict, data)
# save model parameters in single object
model = {'design': design, 'SITE_labels': batch_labels,
'var_pooled':var_pooled, 'B_hat':B_hat, 'grand_mean': grand_mean,
'gamma_star': gamma_star, 'delta_star': delta_star, 'info_dict': info_dict,
'gamma_hat': LS_dict['gamma_hat'], 'delta_hat': np.array(LS_dict['delta_hat']),
'gamma_bar': LS_dict['gamma_bar'], 't2': LS_dict['t2'],
'a_prior': LS_dict['a_prior'], 'b_prior': LS_dict['b_prior'],
'smooth_model': smooth_model, 'eb': eb}
'smooth_model': smooth_model, 'eb': eb, 'ref_batch':ref_batch}
# transpose data to return to original shape
bayes_data = bayes_data.T

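As the neuroCombat.py changes below show, find_parametric_adjustments pins gamma_star to zero and delta_star to one at the reference level, and adjust_data_final copies the raw values back for the reference samples, so data from the reference batch should pass through harmonizationLearn unchanged while the other batches are shifted toward it. A quick property check, again with hypothetical labels and shapes:

```python
import numpy as np
import pandas as pd
from neuroHarmonize import harmonizationLearn

rng = np.random.default_rng(1)
data = rng.normal(size=(100, 20))
covars = pd.DataFrame({
    'SITE': ['SITE_A'] * 60 + ['SITE_B'] * 40,
    'AGE': rng.uniform(18, 80, 100),
})

model, data_adj = harmonizationLearn(data, covars, ref_batch='SITE_A')

# reference rows come back untouched (up to float32 casting) ...
assert np.allclose(data_adj[:60], data[:60])
# ... while non-reference rows are adjusted toward the reference batch
assert not np.allclose(data_adj[60:], data[60:])
```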
64 changes: 54 additions & 10 deletions neuroHarmonize/neuroCombat.py
@@ -1,19 +1,20 @@
"""
ComBat for correcting batch effects in neuroimaging data
"""
#from __future__ import division
#from __future__ import absolute_import, print_function

from __future__ import absolute_import, print_function
import pandas as pd
import numpy as np
import numpy.linalg as la

import math
import copy

def neuroCombat(data,
covars,
batch_col,
discrete_cols=None,
continuous_cols=None):
continuous_cols=None,
ref_batch=None):
"""
Run ComBat to correct for batch effects in neuroimaging data
@@ -37,6 +38,9 @@ def neuroCombat(data,
continuous_cols : string or list of strings
- variables which are continous that you want to predict
- e.g. depression sub-scores
ref_batch : batch (site or scanner) to be used as the reference for batch adjustment
- None by default
Returns
-------
@@ -78,12 +82,27 @@ def neuroCombat(data,
cat_cols = [np.where(covar_labels==c_var)[0][0] for c_var in discrete_cols]
num_cols = [np.where(covar_labels==n_var)[0][0] for n_var in continuous_cols]

# convert batch col to integer
if ref_batch is None:
ref_level=None
else:
ref_indices = np.argwhere((covars[:,batch_col]==ref_batch).squeeze())
if ref_indices.shape[0]==0:
ref_level=None
ref_batch=None
print('[neuroCombat] batch.ref not found. Setting to None.')
covars[:,batch_col]=np.unique(covars[:,batch_col],return_inverse=True)[-1]
else:
covars[:,batch_col] = np.unique(covars[:,batch_col],return_inverse=True)[-1]
ref_level = int(covars[ref_indices[0],batch_col])

# convert batch col to integer
covars[:,batch_col] = np.unique(covars[:,batch_col],return_inverse=True)[-1]
# covars[:,batch_col] = np.unique(covars[:,batch_col],return_inverse=True)[-1]
# create dictionary that stores batch info
(batch_levels, sample_per_batch) = np.unique(covars[:,batch_col],return_counts=True)
info_dict = {
'batch_levels': batch_levels.astype('int'),
'ref_level': ref_level,
'n_batch': len(batch_levels),
'n_sample': int(covars.shape[0]),
'sample_per_batch': sample_per_batch.astype('int'),
@@ -92,7 +111,7 @@ def neuroCombat(data,

# create design matrix
print('Creating design matrix..')
design = make_design_matrix(covars, batch_col, cat_cols, num_cols)
design = make_design_matrix(covars, batch_col, cat_cols, num_cols,ref_level)

# standardize data across features
print('Standardizing data across features..')
@@ -115,7 +134,7 @@ def neuroCombat(data,

return bayes_data.T

def make_design_matrix(Y, batch_col, cat_cols, num_cols,nb_class=None):
def make_design_matrix(Y, batch_col, cat_cols, num_cols,ref_level,nb_class=None):
"""
Return Matrix containing the following parts:
- one-hot matrix of batch variable (full)
@@ -142,6 +161,9 @@ def to_categorical(y, nb_classes=None):
else:
batch = np.unique(Y[:,batch_col],return_inverse=True)[-1]
batch_onehot = to_categorical(batch, len(np.unique(batch)))

if ref_level is not None:
batch_onehot[:,ref_level] = np.ones(batch_onehot.shape[0])

hstack_list.append(batch_onehot)

@@ -164,11 +186,21 @@ def standardize_across_features(X, design, info_dict):
n_batch = info_dict['n_batch']
n_sample = info_dict['n_sample']
sample_per_batch = info_dict['sample_per_batch']
batch_info = info_dict['batch_info']
ref_level = info_dict['ref_level']

B_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), X.T)
grand_mean = np.dot((sample_per_batch/ float(n_sample)).T, B_hat[:n_batch,:])
var_pooled = np.dot(((X - np.dot(design, B_hat).T)**2), np.ones((n_sample, 1)) / float(n_sample))

if ref_level is not None:
grand_mean = np.transpose(B_hat[ref_level,:])
X_ref = X[:,batch_info[ref_level]]
design_ref = design[batch_info[ref_level],:]
n_sample_ref = sample_per_batch[ref_level]
var_pooled = np.dot(((X_ref - np.dot(design_ref, B_hat).T)**2), np.ones((n_sample_ref, 1)) / float(n_sample_ref))
else:
grand_mean = np.dot((sample_per_batch/ float(n_sample)).T, B_hat[:n_batch,:])
var_pooled = np.dot(((X - np.dot(design, B_hat).T)**2), np.ones((n_sample, 1)) / float(n_sample))

stand_mean = np.dot(grand_mean.T.reshape((len(grand_mean), 1)), np.ones((1, n_sample)))
tmp = np.array(design.copy())
tmp[:,:n_batch] = 0
@@ -241,6 +273,7 @@ def it_sol(sdat, g_hat, d_hat, g_bar, t2, a, b, conv=0.0001):

def find_parametric_adjustments(s_data, LS, info_dict):
batch_info = info_dict['batch_info']
ref_level = info_dict['ref_level']

gamma_star, delta_star = [], []
for i, batch_idxs in enumerate(batch_info):
@@ -250,14 +283,22 @@ def find_parametric_adjustments(s_data, LS, info_dict):

gamma_star.append(temp[0])
delta_star.append(temp[1])

gamma_star = np.array(gamma_star)
delta_star = np.array(delta_star)

if ref_level is not None:
gamma_star[ref_level,:] = np.zeros(gamma_star.shape[-1])
delta_star[ref_level,:] = np.ones(delta_star.shape[-1])

return np.array(gamma_star), np.array(delta_star)

def adjust_data_final(s_data, design, gamma_star, delta_star, stand_mean, var_pooled, info_dict):
def adjust_data_final(s_data, design, gamma_star, delta_star, stand_mean, var_pooled, info_dict,data):
sample_per_batch = info_dict['sample_per_batch']
n_batch = info_dict['n_batch']
n_sample = info_dict['n_sample']
batch_info = info_dict['batch_info']
ref_level = info_dict['ref_level']

batch_design = design[:,:n_batch]

@@ -276,4 +317,7 @@ def adjust_data_final(s_data, design, gamma_star, delta_star, stand_mean, var_po
vpsq = np.sqrt(var_pooled).reshape((len(var_pooled), 1))
bayesdata = bayesdata * np.dot(vpsq, np.ones((1, n_sample))) + stand_mean

if ref_level is not None:
bayesdata[:, batch_info[ref_level]] = data[:,batch_info[ref_level]]

return bayesdata
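The core of the mechanism is compact: make_design_matrix overwrites the reference batch's one-hot column with ones, so the reference intercept absorbs the role of the grand mean (hence grand_mean = B_hat[ref_level,:] in standardize_across_features) and the remaining batch columns become offsets relative to the reference. A numpy-only sketch of that design-matrix trick, using a hypothetical three-batch layout:

```python
import numpy as np

batch = np.array([0, 0, 1, 1, 2])    # 5 samples across 3 hypothetical batches
batch_onehot = np.eye(3)[batch]      # plain one-hot encoding
ref_level = 0
# mirror the ref_level branch added to make_design_matrix
batch_onehot[:, ref_level] = np.ones(batch_onehot.shape[0])
print(batch_onehot)
# [[1. 0. 0.]
#  [1. 0. 0.]
#  [1. 1. 0.]
#  [1. 1. 0.]
#  [1. 0. 1.]]
```

With this coding, the fitted row B_hat[ref_level,:] is an absolute mean for the reference batch, and forcing gamma_star to zero and delta_star to one at ref_level in find_parametric_adjustments guarantees the reference data itself is never moved.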
