Skip to content

Commit

Permalink
Improve connectivity plots (PennLINC#988)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsalo authored Nov 1, 2023
1 parent 959d210 commit f20a70f
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 20 deletions.
112 changes: 96 additions & 16 deletions xcp_d/interfaces/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np
import pandas as pd
from nilearn.maskers import NiftiLabelsMasker
from nilearn.plotting import plot_matrix
from nipype import logging
from nipype.interfaces.base import (
BaseInterfaceInputSpec,
Expand Down Expand Up @@ -628,12 +627,16 @@ def _run_interface(self, runtime):


class _ConnectPlotInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc="bold file")
atlas_names = InputMultiObject(
traits.Str,
mandatory=True,
desc="List of atlases. Aligned with the list of time series in time_series_tsv.",
)
atlas_tsvs = InputMultiObject(
traits.Str,
mandatory=True,
desc="The dseg.tsv associated with each atlas.",
)
correlations_tsv = InputMultiObject(
File(exists=True),
mandatory=True,
Expand All @@ -658,10 +661,69 @@ class ConnectPlot(SimpleInterface):
input_spec = _ConnectPlotInputSpec
output_spec = _ConnectPlotOutputSpec

def plot_matrix(self, corr_mat, network_labels, ax):
"""Plot matrix in subplot Axes."""
assert corr_mat.shape[0] == len(network_labels)
assert corr_mat.shape[1] == len(network_labels)

# Determine order of nodes while retaining original order of networks
unique_labels = []
for label in network_labels:
if label not in unique_labels:
unique_labels.append(label)

mapper = {label: f"{i:03d}_{label}" for i, label in enumerate(unique_labels)}
mapped_network_labels = [mapper[label] for label in network_labels]
community_order = np.argsort(mapped_network_labels)

# Sort parcels by community
corr_mat = corr_mat[community_order, :]
corr_mat = corr_mat[:, community_order]
np.fill_diagonal(corr_mat, 0)

# Get the community name associated with each network
labels = np.array(network_labels)[community_order]
unique_labels = sorted(list(set(labels)))
unique_labels = []
for label in labels:
if label not in unique_labels:
unique_labels.append(label)

# Find the locations for the community-separating lines
break_idx = [0]
end_idx = None
for label in unique_labels:
start_idx = np.where(labels == label)[0][0]
if end_idx:
break_idx.append(np.mean([start_idx, end_idx]))

end_idx = np.where(labels == label)[0][-1]

break_idx.append(len(labels))
break_idx = np.array(break_idx)

# Find the locations for the labels in the middles of the communities
label_idx = np.mean(np.vstack((break_idx[1:], break_idx[:-1])), axis=0)

# Plot the correlation matrix
ax.imshow(corr_mat, vmin=-1, vmax=1, cmap="seismic")

# Add lines separating networks
for idx in break_idx[1:-1]:
ax.axes.axvline(idx, color="black")
ax.axes.axhline(idx, color="black")

# Add network names
ax.axes.set_yticks(label_idx)
ax.axes.set_xticks(label_idx)
ax.axes.set_yticklabels(unique_labels)
ax.axes.set_xticklabels(unique_labels, rotation=90)
return ax

def _run_interface(self, runtime):
ATLAS_LOOKUP = {
"4S252Parcels": {
"title": "4S 252 Parcels",
"4S152Parcels": {
"title": "4S 152 Parcels",
"axes": [0, 0],
},
"4S452Parcels": {
Expand All @@ -678,29 +740,47 @@ def _run_interface(self, runtime):
},
}

# Generate a plot of each matrix's correlation coefficients
fig, axes = plt.subplots(2, 2)
fig.set_size_inches(20, 20)
font = {"weight": "normal", "size": 20}
COMMUNITY_LOOKUP = {
"4S152Parcels": "network_label",
"4S452Parcels": "network_label",
"Glasser": "community_yeo",
"Gordon": "community",
}

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(20, 20))
for atlas_name, subdict in ATLAS_LOOKUP.items():
atlas_idx = self.inputs.atlas_names.index(atlas_name)
atlas_file = self.inputs.correlations_tsv[atlas_idx]
dseg_file = self.inputs.atlas_tsvs[atlas_idx]

column_name = COMMUNITY_LOOKUP[atlas_name]
dseg_df = pd.read_table(dseg_file)
corrs_df = pd.read_table(atlas_file, index_col="Node")

plot_matrix(
mat=corrs_df.to_numpy(),
colorbar=False,
vmax=1,
vmin=-1,
axes=axes[subdict["axes"][0], subdict["axes"][1]],
if atlas_name.startswith("4S"):
atlas_mapper = {
"CIT168Subcortical": "Subcortical",
"ThalamusHCP": "Thalamus",
"SubcorticalHCP": "Subcortical",
}
network_labels = dseg_df[column_name].fillna(dseg_df["atlas_name"]).tolist()
network_labels = [atlas_mapper.get(network, network) for network in network_labels]
else:
network_labels = dseg_df[column_name].fillna("None").tolist()

ax = axes[subdict["axes"][0], subdict["axes"][1]]
ax = self.plot_matrix(
corr_mat=corrs_df.to_numpy(),
network_labels=network_labels,
ax=ax,
)
axes[subdict["axes"][0], subdict["axes"][1]].set_title(
ax.set_title(
subdict["title"],
fontdict=font,
fontdict={"weight": "normal", "size": 20},
)

fig.tight_layout()

# Write the results out
self._results["connectplot"] = fname_presuffix(
"connectivityplot",
Expand Down
4 changes: 2 additions & 2 deletions xcp_d/tests/test_workflows_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_init_functional_connectivity_nifti_wf(ds001419_data, tmp_path_factory):
censoring_df.to_csv(temporal_mask, sep="\t", index=False)

# Load atlases
atlas_names = ["4S1052Parcels", "4S252Parcels", "4S452Parcels", "Gordon", "Glasser"]
atlas_names = ["4S1052Parcels", "4S152Parcels", "4S452Parcels", "Gordon", "Glasser"]
atlas_files = [get_atlas_nifti(atlas_name)[0] for atlas_name in atlas_names]
atlas_labels_files = [get_atlas_nifti(atlas_name)[1] for atlas_name in atlas_names]

Expand Down Expand Up @@ -239,7 +239,7 @@ def test_init_functional_connectivity_cifti_wf(ds001419_data, tmp_path_factory):
censoring_df.to_csv(temporal_mask, sep="\t", index=False)

# Load atlases
atlas_names = ["4S1052Parcels", "4S252Parcels", "4S452Parcels", "Gordon", "Glasser"]
atlas_names = ["4S1052Parcels", "4S152Parcels", "4S452Parcels", "Gordon", "Glasser"]
atlas_files = [get_atlas_cifti(atlas_name)[0] for atlas_name in atlas_names]
atlas_labels_files = [get_atlas_cifti(atlas_name)[1] for atlas_name in atlas_names]

Expand Down
4 changes: 2 additions & 2 deletions xcp_d/workflows/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,8 +636,8 @@ def init_functional_connectivity_nifti_wf(
# fmt:off
workflow.connect([
(inputnode, connectivity_plot, [
("denoised_bold", "in_file"),
("atlas_names", "atlas_names"),
("atlas_labels_files", "atlas_tsvs"),
]),
(functional_connectivity, connectivity_plot, [("correlations", "correlations_tsv")]),
])
Expand Down Expand Up @@ -860,8 +860,8 @@ def init_functional_connectivity_cifti_wf(
# fmt:off
workflow.connect([
(inputnode, connectivity_plot, [
("denoised_bold", "in_file"),
("atlas_names", "atlas_names"),
("atlas_labels_files", "atlas_tsvs"),
]),
(functional_connectivity, connectivity_plot, [("correlations", "correlations_tsv")]),
])
Expand Down

0 comments on commit f20a70f

Please sign in to comment.