Skip to content

Commit

Permalink
Merge pull request #11 from mtorabi59/main
Browse files Browse the repository at this point in the history
Update MA demo
  • Loading branch information
mtorabi59 authored Nov 20, 2023
2 parents 24e2c2d + ded01a2 commit c21bd09
Show file tree
Hide file tree
Showing 19 changed files with 823 additions and 183 deletions.
11 changes: 11 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,14 @@
__pycache__
*.pyc
*.cpython
sample_data/sub-0001_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz
sample_data/sub-0001_task-restingstate_acq-mb3_desc-confounds_regressors.tsv
sample_data/sub-0002_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz
sample_data/sub-0002_task-restingstate_acq-mb3_desc-confounds_regressors.tsv
sample_data/sub-0003_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz
sample_data/sub-0003_task-restingstate_acq-mb3_desc-confounds_regressors.tsv
sample_data/sub-0004_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz
sample_data/sub-0004_task-restingstate_acq-mb3_desc-confounds_regressors.tsv
sample_data/sub-0005_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz
sample_data/sub-0005_task-restingstate_acq-mb3_desc-confounds_regressors.tsv

390 changes: 217 additions & 173 deletions multi_analysis_demo.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions multi_analysis_dfc/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def nifti2array(nifti_file,
and global signal regression
is applied.
'''
from nilearn.input_data import NiftiLabelsMasker
from nilearn.maskers import NiftiLabelsMasker
from nilearn import datasets
from nilearn.plotting import find_parcellation_cut_coords
from nilearn.interfaces.fmriprep import load_confounds
Expand Down Expand Up @@ -232,7 +232,7 @@ def nifti2timeseries(
session=None,
):
'''
this function is only for single subject data loading
this function is only for single subject and single session data loading
it uses nilearn maskers to extract ROI signals from nifti files
and returns a TIME_SERIES object
Expand Down Expand Up @@ -278,7 +278,7 @@ def multi_nifti2timeseries(
session=None,
):
'''
loading data of multiple subjects from their niifti files
loading data of multiple subjects, but single session, from their nifti files
'''
BOLD_multi = None
for subj_id, nifti_file in zip(subj_id_list, nifti_files_list):
Expand Down
2 changes: 1 addition & 1 deletion multi_analysis_dfc/dfc_methods/sliding_window_clustr.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def estimate_FCS(self, time_series):
if dFC_raw.n_time<self.params['n_subj_clstrs']:
print( \
'Number of subject-level clusters cannot be more than SW dFC samples! n_subj_clstrs was changed to ' \
+ str(dFC_raw.n_time))
+ str(dFC_raw.n_time) + '. This change will cause problems in similarity assessment.')
self.params['n_subj_clstrs'] = dFC_raw.n_time

FCS, _ = self.cluster_FC( \
Expand Down
71 changes: 65 additions & 6 deletions multi_analysis_dfc/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,26 @@
@author: Mohammad Torabi
"""

from tracemalloc import start
import numpy as np
from nilearn import glm
from scipy import signal

################################# Preprocessing Functions ####################################

def events_time_to_labels(events, Fs, num_time, event_types=[], return_0_1=False):
def events_time_to_labels(events, TR_mri, num_time_mri, event_types=[], oversampling=50, return_0_1=False):
'''
event_types is a list of event types to be considered. If None, 0 and 1s will be returned.
Assigns the longest event in each TR to that TR (in the interval from last TR to current TR).
It assumes that the first time point is TR0 which corresponds to [0 sec, TR sec] interval.
oversampling: number of samples per TR_mri to improve the time resolution of tasks
'''
assert events[0, 0]=='onset', 'The first column of the events file should be the onset!'
assert events[0, 1]=='duration', 'The second column of the events file should be the duration!'
assert events[0, 2]=='trial_type', 'The third column of the events file should be the trial type!'

event_labels = np.zeros((num_time, 1))
Fs = float(1 / TR_mri) * oversampling
num_time_task = int(num_time_mri * oversampling)
event_labels = np.zeros((num_time_task, 1))
for i in range(events.shape[0]):
# skip the first row which is the header
if i==0:
Expand All @@ -30,11 +34,66 @@ def events_time_to_labels(events, Fs, num_time, event_types=[], return_0_1=False
if events[i, 2] in event_types:
start_time = float(events[i, 0])
end_time = float(events[i, 0]) + float(events[i, 1])
start_TR = np.round(start_time*Fs)
end_TR = np.round(end_time*Fs)
start_TR = int(np.rint(start_time * Fs))
end_TR = int(np.rint(end_time * Fs))
event_labels[start_TR:end_TR] = event_types.index(events[i, 2])

if return_0_1:
event_labels = np.multiply(event_labels!=0, 1)

return event_labels
return event_labels, Fs


################################# Validation Functions ####################################


def event_conv_hrf(event_signal, TR_mri, TR_task):
time_length_HRF = 32.0 # in sec
hrf_model = 'glover' # 'spm' or 'glover'

TR_HRF = TR_task
oversampling = TR_mri/TR_HRF # more samples per TR than the func data to have a better HRF resolution, same as for event_labels
if hrf_model=='glover':
HRF = glm.first_level.glover_hrf(tr=TR_mri, oversampling=oversampling, time_length=time_length_HRF, onset=0.0)
elif hrf_model=='spm':
HRF = glm.first_level.spm_hrf(tr=TR_mri, oversampling=oversampling, time_length=time_length_HRF, onset=0.0)

events_hrf = signal.convolve(HRF, event_signal, mode='full')[:len(event_signal)]

return events_hrf


def event_labels_conv_hrf(event_labels, TR_mri, TR_task):
'''
event_labels: event labels including 0 and event ids at the time each event happens
TR_mri: TR of MRI
TR_task: TR of task
assums that 0 is the resting state
return: event labels convolved with HRF for each event type
the convolved event labels have the same length as the event_labels
event type i convolved with HRF is in events_hrf[:, i-1]
'''

event_labels = np.array(event_labels)
print(event_labels.shape)
L = event_labels.shape[0]
event_ids = np.unique(event_labels)
event_ids = event_ids.astype(int)
print(event_ids)
events_hrf = np.zeros((L, len(event_ids)-1)) # 0 is not an event, is the resting state
print(events_hrf.shape)
for i, event_id in enumerate(event_ids):
print(event_id)
# 0 is not an event, is the resting state
if event_id == 0:
continue
event_signal = np.zeros(L)
event_signal[event_labels[:, 0]==event_id] = 1.0

# -1 because the first event is the resting state
events_hrf[:, i-1] = event_conv_hrf(event_signal, TR_mri, TR_task)

# time_length_task = len(event_labels)*TR_task

return events_hrf
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import statsmodels.api as sm
from statsmodels.formula.api import ols
from sklearn.manifold import TSNE
from math import ceil

import matplotlib.pyplot as plt
import matplotlib as mpl
Expand Down Expand Up @@ -116,7 +117,75 @@ def plot_sample_dFC(D, x,
else:
plt.show()


def plot_rois(
node_networks,
nodes_locs,
save_image=False,
output_root=None
):

networks = list(np.unique(node_networks))

fig_width = 25
fig_height = len(networks)

fig, axes = plt.subplots(ceil(len(networks)/3), 3, figsize=(fig_width, fig_height),
facecolor='w', edgecolor='k')

axes = axes.ravel()

fig.subplots_adjust(
bottom=0.1,
top=0.85,
left=0.1,
right=0.9,
wspace=0.03,
hspace=0.3
)

for i, target_network in enumerate(networks):

locs = []
node_values = []
for node_id, node_network in enumerate(node_networks):
if node_network == target_network:
node_values.append(1)
locs.append(nodes_locs[node_id])

node_values = np.array(node_values)
locs = np.array(locs)

plot_markers(
node_values=node_values,
node_coords=locs,
node_size=100,
node_cmap='Reds',
node_vmax=1,
node_vmin=0,
annotate=True,
colorbar=False, axes=axes[i],
)

title = f"Resting State Networks"
# set subplot titles
for i, network in enumerate(networks):
axes[i].title.set_text(f"{network} network")
axes[i].title.set_size(20)
axes[i].title.set_weight('bold')

if save_image:
folder = output_root[:output_root.rfind('/')]
if not os.path.exists(folder):
os.makedirs(folder)
fig.savefig(output_root+title.replace(" ", "_")+'.'+save_fig_format,
dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format
)
plt.close()
else:
plt.show()


def pairwise_cat_plots(data=None, x=None, y=None, z=None,
title='',
label_dict={},
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 11 additions & 0 deletions BIC_codes/visualization.py → rest_dFC/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@
save_image=save_image, output_root=output_root+'FCS/'
)

################################# RSNs visualization #################################

measure = ALL_RESULTS['measure_lst'][0]

plot_rois(
node_networks,
measure.TS_info['nodes_locs'],
save_image=save_image,
output_root=f"{output_root}RSNs/"
)

################################# dFC values distributions #################################
dFC_dist_plot = True

Expand Down
Loading

0 comments on commit c21bd09

Please sign in to comment.