Skip to content

Commit

Permalink
Merge pull request #17 from mtorabi59/main
Browse files Browse the repository at this point in the history
rename folder
  • Loading branch information
mtorabi59 authored Feb 8, 2024
2 parents d3fc715 + 3f9a8f2 commit 65eeb77
Show file tree
Hide file tree
Showing 19 changed files with 290 additions and 131 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from copy import deepcopy
import os

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
1 change: 0 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.10161176.svg
:target: https://zenodo.org/doi/10.5281/zenodo.10161176


pydfc
=======
An implementation of several well-known dynamic Functional Connectivity (dFC) assessment methods.
Expand Down
30 changes: 30 additions & 0 deletions pydfc/comparison/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,36 @@ def title2file_name(title):
'''
return title.replace(" ", "_")


def plot_TS(X, Fs, title=''):
'''
X = (n_time, n_features)
'''
if X.shape[1] > 10:
print('Too many features to plot')
return
fig_width = 35
fig_height = 2 * X.shape[1]
fig, axes = plt.subplots(X.shape[1], 1, figsize=(fig_width, fig_height),
facecolor='w', edgecolor='k')
time = np.arange(0, X.shape[0])/Fs
for i in range(0, X.shape[1]):
axes[i].plot(time, X[:, i], linewidth=4)
axes[i].set_title('TC'+str(i), fontdict={'fontsize': 15, 'fontweight': 'bold'})
plt.suptitle(title)
plt.xlabel('Time (s)')
plt.show()
# if save_image:
# folder = output_root[:output_root.rfind('/')]
# if not os.path.exists(folder):
# os.makedirs(folder)
# fig.savefig(output_root+title.replace(" ", "_")+'.'+save_fig_format,
# dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format
# )
# plt.close()
# else:
# plt.show()

def plot_sample_dFC(D, x,
title='',
cmap='seismic',
Expand Down
48 changes: 44 additions & 4 deletions pydfc/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
@author: Mohammad Torabi
"""

from re import S
from tkinter import N
import numpy as np
import hdf5storage
import scipy.io as sio
import os

from .dfc_utils import intersection, label2network
Expand All @@ -24,6 +21,9 @@ def find_subj_list(data_root, sessions):
the files must follow the format: subjectID_sessionID
only these files should be in the data_root
'''
if data_root[-1] != '/':
data_root += '/'

ALL_FILES = os.listdir(data_root)
FOLDERS = [item for item in ALL_FILES if os.path.isdir(data_root+item)]

Expand All @@ -43,7 +43,7 @@ def find_subj_list(data_root, sessions):
for subj in SUBJECTS:
kept_subjs.append(subj)
for session in sessions:
if not os.path.exists(data_root+subj+'_'+session):
if not os.path.exists(f"{data_root}{subj}_{session}"):
failed_subjs.append(subj)
kept_subjs.remove(subj)
break
Expand Down Expand Up @@ -307,4 +307,44 @@ def multi_nifti2timeseries(
return BOLD_multi


def load_TS(
data_root,
file_name,
SESSIONs,
subj_id2load=None
):
'''
load a TIME_SERIES object from a .npy file
if SESSIONs is a list, it will load all the sessions,
if it is a string, it will load that session
if subj_id2load is None, it will load all the subjects
'''
# check if SESSIONs is a list or a string
flag = False
if type(SESSIONs) is str:
SESSIONs = [SESSIONs]
flag = True

if subj_id2load is None:
SUBJECTS = find_subj_list(data_root, sessions=SESSIONs)
else:
SUBJECTS = [subj_id2load]

TS = {}
for session in SESSIONs:
TS[session] = None
for subj in SUBJECTS:
subj_fldr = f"{subj}_{session}"
time_series = np.load(f"{data_root}/{subj_fldr}/{file_name}", allow_pickle='True').item()
if TS[session] is None:
TS[session] = time_series
else:
TS[session].concat_ts(time_series)

if flag:
return TS[SESSIONs[0]]
return TS



####################################################################################################################################
22 changes: 18 additions & 4 deletions pydfc/dfc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def dFC_mask(dFC_mat, mask):
def dFC_mat2vec(C_t):
'''
C_t must be an array of matrices or a single matrix
diagonal values not included. if you want to include
diagonal values will not be included. if you want to include
them set k=0
if C_t is a single matrix, F will be one dim
changing F will not change C_t
Expand Down Expand Up @@ -491,10 +491,13 @@ def visualize_conn_mat(C, axis=None, title='',
cmap='seismic',
V_MIN=None, V_MAX=None,
node_networks=None,
title_fontsize=18
title_fontsize=18,
loc_x=None, loc_y=None,
):
'''
C is (regions, regions)
you can use loc_x and loc_y to set the location of the image
loc_x and loc_y are lists of two elements, [start, end]
'''

if axis is None:
Expand All @@ -508,8 +511,19 @@ def visualize_conn_mat(C, axis=None, title='',
if V_MIN is None:
V_MIN = -1*V_MAX

im = axis.imshow(C, interpolation='nearest', aspect='equal', cmap=cmap, # 'viridis' or 'jet'
vmin=V_MIN, vmax=V_MAX)
if loc_x is None or loc_y is None:
im = axis.imshow(
C,
interpolation='nearest', aspect='equal', cmap=cmap, # 'viridis' or 'jet'
vmin=V_MIN, vmax=V_MAX
)
else:
im = axis.imshow(
C,
interpolation='nearest', aspect='equal', cmap=cmap, # 'viridis' or 'jet'
vmin=V_MIN, vmax=V_MAX,
extent=[loc_x[0], loc_x[1], loc_y[0], loc_y[1]]
)

# cluster node networks
if not node_networks is None:
Expand Down
114 changes: 105 additions & 9 deletions pydfc/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
import numpy as np
from nilearn import glm
from scipy import signal
import matplotlib.pyplot as plt
from .dfc_utils import rank_norm
from .dfc_utils import visualize_conn_mat
from .dfc_utils import TR_intersection

################################# Preprocessing Functions ####################################

Expand Down Expand Up @@ -44,12 +48,80 @@ def events_time_to_labels(events, TR_mri, num_time_mri, event_types=[], oversamp
return event_labels, Fs


################################# Visualization Functions ####################################

def plot_task_dFC(task_labels, dFC_lst, event_types, Fs_mri, TR_step=12):
'''
task_labels: numpy array of shape (num_time_task, num_event_types) containing the event or task labels
this function assumes that the task data has the same Fs as the dFC data, i.e. MRI data
and that the time points of the task data are aligned with the time points of the dFC data
'''
conn_mat_size = 20
scale_task_plot = 20

# plot task_data['event_labels']
fig = plt.figure(figsize=(50, 200))

ax = plt.gca()

time = np.arange(0, task_labels.shape[0])/Fs_mri
for i in range(0, task_labels.shape[1]):
ax.plot(time, task_labels[:, i]*scale_task_plot, label=event_types[i], linewidth=4)
plt.legend()
plt.xlabel('Time (s)')

comman_TRs = TR_intersection(dFC_lst)
TRs_dFC = comman_TRs[::TR_step]

for dFC_id, dFC in enumerate(dFC_lst):
dFC_mat = rank_norm(dFC.get_dFC_mat(), global_norm=True)
TR_array = dFC.TR_array
for i in range(0, len(TR_array), 1):

C = dFC_mat[i, :, :]
TR = TR_array[i]
if not TR in TRs_dFC:
continue
visualize_conn_mat(
C=C, axis=ax, title='',
cmap='plasma',
V_MIN=0, V_MAX=None,
node_networks=None,
title_fontsize=18,
loc_x = [TR/Fs_mri-conn_mat_size/2, TR/Fs_mri+conn_mat_size/2],
loc_y = [(1+dFC_id)*conn_mat_size, (2+dFC_id)*conn_mat_size],
)

x1, y1 = [TR/Fs_mri, TR/Fs_mri], [conn_mat_size, 0]
ax.plot(x1, y1, color='k', linestyle='-', linewidth=2)

plt.show()


################################# PCA Functions ####################################

# def BOLD


################################# Prediction Functions ####################################

from sklearn.linear_model import LinearRegression

def linear_reg(X, y):
'''
X = (n_samples, n_features)
y = (n_samples, n_targets)
'''
reg = LinearRegression().fit(X, y)
print(reg.score(X, y))
return reg.predict(X)

################################# Validation Functions ####################################


def event_conv_hrf(event_signal, TR_mri, TR_task):
time_length_HRF = 32.0 # in sec
hrf_model = 'glover' # 'spm' or 'glover'
hrf_model = 'spm' # 'spm' or 'glover'

TR_HRF = TR_task
oversampling = TR_mri/TR_HRF # more samples per TR than the func data to have a better HRF resolution, same as for event_labels
Expand All @@ -76,24 +148,48 @@ def event_labels_conv_hrf(event_labels, TR_mri, TR_task):
'''

event_labels = np.array(event_labels)
print(event_labels.shape)
L = event_labels.shape[0]
event_ids = np.unique(event_labels)
event_ids = event_ids.astype(int)
print(event_ids)
events_hrf = np.zeros((L, len(event_ids)-1)) # 0 is not an event, is the resting state
print(events_hrf.shape)
events_hrf = np.zeros((L, len(event_ids))) # 0 is the resting state
for i, event_id in enumerate(event_ids):
print(event_id)
# 0 is not an event, is the resting state
if event_id == 0:
continue
event_signal = np.zeros(L)
event_signal[event_labels[:, 0]==event_id] = 1.0

# -1 because the first event is the resting state
events_hrf[:, i-1] = event_conv_hrf(event_signal, TR_mri, TR_task)
events_hrf[:, i] = event_conv_hrf(event_signal, TR_mri, TR_task)

# the time points that are not in any event are considered as resting state
events_hrf[np.sum(events_hrf[:, 1:], axis=1)==0.0, 0] = 1.0

# time_length_task = len(event_labels)*TR_task

return events_hrf
return events_hrf


def downsample_events_hrf(events_hrf, TR_mri, TR_task, method='uniform'):
'''
method:
uniform
resample
decimate
no major difference was observed between these methods
'''
events_hrf_ds = []
for i in range(events_hrf.shape[1]):
if method=='uniform':
events_hrf_ds.append(
events_hrf[::int(TR_mri/TR_task), i]
)
elif method=='resample':
events_hrf_ds.append(
signal.resample(events_hrf[:, i], int(events_hrf.shape[0]*TR_task/TR_mri))
)
elif method=='decimate':
events_hrf_ds.append(
signal.decimate(events_hrf[:, i], int(TR_mri/TR_task))
)
events_hrf_ds = np.array(events_hrf_ds).T
return events_hrf_ds
2 changes: 1 addition & 1 deletion pydfc/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def get_subj_ts(self, subjs_id=None):


def append_ts(self, new_time_series, time_array=None, subj_id=None):
# append new time series to existing ones
# append new time series numpy array to existing ones
# truncate and node selection , etc will be automatically applied to new TS;
# However, at first the new TS must have the same properties as the original properties of
# the existing TSs
Expand Down
Loading

0 comments on commit 65eeb77

Please sign in to comment.