diff --git a/rest_dFC/functions/post_analysis_funcs.py b/rest_dFC/functions/post_analysis_funcs.py index 80a76ed..21fbdac 100644 --- a/rest_dFC/functions/post_analysis_funcs.py +++ b/rest_dFC/functions/post_analysis_funcs.py @@ -14,6 +14,7 @@ import statsmodels.api as sm from statsmodels.formula.api import ols from sklearn.manifold import TSNE +from math import ceil import matplotlib.pyplot as plt import matplotlib as mpl @@ -116,7 +117,75 @@ def plot_sample_dFC(D, x, else: plt.show() + +def plot_rois( + node_networks, + nodes_locs, + save_image=False, + output_root=None + ): + + networks = list(np.unique(node_networks)) + + fig_width = 25 + fig_height = len(networks) + + fig, axes = plt.subplots(ceil(len(networks)/3), 3, figsize=(fig_width, fig_height), + facecolor='w', edgecolor='k') + axes = axes.ravel() + + fig.subplots_adjust( + bottom=0.1, + top=0.85, + left=0.1, + right=0.9, + wspace=0.03, + hspace=0.3 + ) + + for i, target_network in enumerate(networks): + + locs = [] + node_values = [] + for node_id, node_network in enumerate(node_networks): + if node_network == target_network: + node_values.append(1) + locs.append(nodes_locs[node_id]) + + node_values = np.array(node_values) + locs = np.array(locs) + + plot_markers( + node_values=node_values, + node_coords=locs, + node_size=100, + node_cmap='Reds', + node_vmax=1, + node_vmin=0, + annotate=True, + colorbar=False, axes=axes[i], + ) + + title = f"Resting State Networks" + # set subplot titles + for i, network in enumerate(networks): + axes[i].title.set_text(f"{network} network") + axes[i].title.set_size(20) + axes[i].title.set_weight('bold') + + if save_image: + folder = output_root[:output_root.rfind('/')] + if not os.path.exists(folder): + os.makedirs(folder) + fig.savefig(output_root+title.replace(" ", "_")+'.'+save_fig_format, + dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format + ) + plt.close() + else: + plt.show() + + def pairwise_cat_plots(data=None, x=None, y=None, z=None, title='', label_dict={}, diff --git a/rest_dFC/visualization.py b/rest_dFC/visualization.py index d9c7074..f594e3b 100644 --- a/rest_dFC/visualization.py +++ b/rest_dFC/visualization.py @@ -57,6 +57,17 @@ save_image=save_image, output_root=output_root+'FCS/' ) +################################# RSNs visualization ################################# + +measure = ALL_RESULTS['measure_lst'][0] + +plot_rois( + node_networks, + measure.TS_info['nodes_locs'], + save_image=save_image, + output_root=f"{output_root}RSNs/" +) + ################################# dFC values distributions ################################# dFC_dist_plot = True