Skip to content

Commit

Permalink
Improve component plots (#1045)
Browse files Browse the repository at this point in the history
* Improve component plots maybe.

* Update static_figures.py

* Update static_figures.py

* Update static_figures.py

* Update static_figures.py

* Update static_figures.py

* Update static_figures.py

* Update static_figures.py

* Update static_figures.py

* Update static_figures.py

* Put title on top axis.

* Use suptitle.

* Update static_figures.py

* Update static_figures.py

* Update static_figures.py

* Update static_figures.py

* Update static_figures.py

* Add time ticks to time series subplot.

* Update static_figures.py

* Update static_figures.py
  • Loading branch information
tsalo authored Feb 29, 2024
1 parent 8d7e5ff commit 7c0b91f
Showing 1 changed file with 159 additions and 83 deletions.
242 changes: 159 additions & 83 deletions tedana/reporting/static_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import logging
import os
from io import BytesIO

import matplotlib
import nibabel as nb
import numpy as np

matplotlib.use("AGG")
Expand Down Expand Up @@ -180,9 +182,136 @@ def carpet_plot(
)


def comp_figures(ts, mask, comptable, mmix, io_generator, png_cmap):
def plot_component(
*,
stat_img,
component_timeseries,
power_spectrum,
frequencies,
tr,
classification_color,
png_cmap,
title,
out_file,
):
"""Create a figure with a component's spatial map, time series, and power spectrum.
Parameters
----------
stat_img : :obj:`nibabel.Nifti1Image`
Image of the component's spatial map
component_timeseries : (T,) array_like
Time series of the component
power_spectrum : (T,) array_like
Power spectrum of the component's time series
frequencies : (T,) array_like
Frequencies for the power spectrum
tr : float
Repetition time of the time series
classification_color : str
Color to use for the time series and power spectrum
png_cmap : str
Colormap to use for the spatial map
title : str
Title for the figure
out_file : str
Path to save the figure
"""
Create static figures that highlight certain aspects of tedana processing.
import matplotlib.image as mpimg
from matplotlib import gridspec

# Set range to ~1/10th of max positive or negative beta
imgmax = 0.1 * np.max(np.abs(stat_img.get_fdata()))

# Save the figure to an in-memory file object
display = plotting.plot_stat_map(
stat_img,
bg_img=None,
display_mode="mosaic",
cut_coords=5,
vmax=imgmax,
cmap=png_cmap,
symmetric_cbar=True,
colorbar=False,
draw_cross=False,
annotate=False,
)
display.annotate(size=30)
example_ax = list(display.axes.values())[0]
nilearn_fig = example_ax.ax.figure

with BytesIO() as buf:
nilearn_fig.savefig(buf, format="png")
buf.seek(0)

# Read the image back into an image array
img = mpimg.imread(buf)

plt.close(nilearn_fig)

# Make the width of the original image the width of the new figure,
# but add top and bottom axes that each take up 10% of the height
width = 10
img_hw_ratio = img.shape[0] / img.shape[1]
img_dims = (width, (width * img_hw_ratio * 1.6))

# Create a new figure and gridspec
fig = plt.figure(figsize=img_dims)
fig.suptitle(title, fontsize=14)
gs = gridspec.GridSpec(3, 1, height_ratios=[2, 10, 2], hspace=0.2)

# Create three subplots
# First is the time series of the component
ax_ts = fig.add_subplot(gs[0])
ax_ts.plot(component_timeseries, color=classification_color)
ax_ts.set_xlim(0, len(component_timeseries) - 1)
ax_ts.set_yticks([])

max_xticks = 10
xloc = plt.MaxNLocator(max_xticks)
ax_ts.xaxis.set_major_locator(xloc)

ax_ts2 = ax_ts.twiny()
ax1_xs = ax_ts.get_xticks()

ax2_xs = []
for x in ax1_xs:
# Limit to 2 decimal places
seconds_val = round(x * tr, 2)
ax2_xs.append(seconds_val)

ax_ts2.set_xticks(ax1_xs)
ax_ts2.set_xlim(ax_ts.get_xbound())
ax_ts2.set_xticklabels(ax2_xs)
ax_ts2.set_xlabel("seconds")

# Second is the cached image of the spatial map
ax_map = fig.add_subplot(gs[1])
ax_map.axis("off")
ax_map.imshow(img)

# Third is the power spectrum of the component's time series
ax_fft = fig.add_subplot(gs[2])
ax_fft.plot(frequencies, power_spectrum, color=classification_color)
ax_fft.set_title("One-Sided FFT")
ax_fft.set_xlabel("Frequency (Hz)")
ax_fft.set_xlim(0, frequencies.max())
ax_fft.set_yticks([])

# Get the current positions of the second and last subplots
# pos_ts = ax_ts.get_position()
# pos_freq = ax_fft.get_position()

# Adjust the positions of the second and last subplots
# ax_ts.set_position([pos_ts.x0, pos_ts.y0 - 0.1, pos_ts.width, pos_ts.height])
# ax_fft.set_position([pos_freq.x0, pos_freq.y0 - 0.2, pos_freq.width, pos_freq.height])

fig.savefig(out_file)
plt.close(fig)


def comp_figures(ts, mask, comptable, mmix, io_generator, png_cmap):
"""Create static figures that highlight certain aspects of tedana processing.
This includes a figure for each component showing the component time course,
the spatial weight map and a fast Fourier transform of the time course.
Expand All @@ -202,127 +331,74 @@ def comp_figures(ts, mask, comptable, mmix, io_generator, png_cmap):
io_generator : :obj:`tedana.io.OutputGenerator`
Output Generator object to use for this workflow
"""
# Get the lenght of the timeseries
n_vols = len(mmix)

# Flip signs of mixing matrix as needed
mmix = mmix * comptable["optimal sign"].values

# regenerate the beta images
ts_b = stats.get_coeffs(ts, mmix, mask)
ts_b = ts_b.reshape(io_generator.reference_img.shape[:3] + ts_b.shape[1:])
# trim edges from ts_b array
ts_b = _trim_edge_zeros(ts_b)

# Mask out remaining zeros
ts_b = np.ma.masked_where(ts_b == 0, ts_b)
component_maps_arr = stats.get_coeffs(ts, mmix, mask)
component_maps_arr = component_maps_arr.reshape(
io_generator.reference_img.shape[:3] + component_maps_arr.shape[1:],
)

# Get repetition time from reference image
tr = io_generator.reference_img.header.get_zooms()[-1]

# Create indices for 6 cuts, based on dimensions
cuts = [ts_b.shape[dim] // 6 for dim in range(3)]
expl_text = ""

# Remove trailing ';' from rationale column
# comptable["rationale"] = comptable["rationale"].str.rstrip(";")
for compnum in comptable.index.values:
if comptable.loc[compnum, "classification"] == "accepted":
line_color = "g"
expl_text = "accepted reason(s): " + str(comptable.loc[compnum, "classification_tags"])

elif comptable.loc[compnum, "classification"] == "rejected":
line_color = "r"
expl_text = "rejected reason(s): " + str(comptable.loc[compnum, "classification_tags"])

elif comptable.loc[compnum, "classification"] == "ignored":
line_color = "k"
expl_text = "ignored reason(s): " + str(comptable.loc[compnum, "classification_tags"])

else:
# Classification not added
# If new, this will keep code running
line_color = "0.75"
expl_text = "other classification"

allplot = plt.figure(figsize=(10, 9))
ax_ts = plt.subplot2grid((5, 6), (0, 0), rowspan=1, colspan=6, fig=allplot)

ax_ts.set_xlabel("TRs")
ax_ts.set_xlim(0, n_vols)
plt.yticks([])
# Make a second axis with units of time (s)
max_xticks = 10
xloc = plt.MaxNLocator(max_xticks)
ax_ts.xaxis.set_major_locator(xloc)

ax_ts2 = ax_ts.twiny()
ax1_xs = ax_ts.get_xticks()

ax2_xs = []
for x in ax1_xs:
# Limit to 2 decimal places
seconds_val = round(x * tr, 2)
ax2_xs.append(seconds_val)
ax_ts2.set_xticks(ax1_xs)
ax_ts2.set_xlim(ax_ts.get_xbound())
ax_ts2.set_xticklabels(ax2_xs)
ax_ts2.set_xlabel("seconds")

ax_ts.plot(mmix[:, compnum], color=line_color)

# Title will include variance from comptable
comp_var = f"{comptable.loc[compnum, 'variance explained']:.2f}"
comp_kappa = f"{comptable.loc[compnum, 'kappa']:.2f}"
comp_rho = f"{comptable.loc[compnum, 'rho']:.2f}"

plt_title = (
f"Comp. {compnum}: variance: {comp_var}%, kappa: {comp_kappa}, "
f"rho: {comp_rho}, {expl_text}"
)
component_img = nb.Nifti1Image(
component_maps_arr[:, :, :, compnum],
affine=io_generator.reference_img.affine,
header=io_generator.reference_img.header,
)

title = ax_ts.set_title(plt_title)
title.set_y(1.5)

# Set range to ~1/10th of max positive or negative beta
imgmax = 0.1 * np.abs(ts_b[:, :, :, compnum]).max()
imgmin = imgmax * -1

for idx, _ in enumerate(cuts):
for imgslice in range(1, 6):
ax = plt.subplot2grid((5, 6), (idx + 1, imgslice - 1), rowspan=1, colspan=1)
ax.axis("off")

if idx == 0:
to_plot = np.rot90(ts_b[imgslice * cuts[idx], :, :, compnum])
if idx == 1:
to_plot = np.rot90(ts_b[:, imgslice * cuts[idx], :, compnum])
if idx == 2:
to_plot = ts_b[:, :, imgslice * cuts[idx], compnum]

ax_im = ax.imshow(to_plot, vmin=imgmin, vmax=imgmax, aspect="equal", cmap=png_cmap)

# Add a color bar to the plot.
ax_cbar = allplot.add_axes([0.8, 0.3, 0.03, 0.37])
cbar = allplot.colorbar(ax_im, ax_cbar)
cbar.set_label("Component Beta", rotation=90)
cbar.ax.yaxis.set_label_position("left")
component_timeseries = mmix[:, compnum]

# Get fft and freqs for this subject
# Get fft and freqs for this component
# adapted from @dangom
spectrum, freqs = utils.get_spectrum(mmix[:, compnum], tr)

# Plot it
ax_fft = plt.subplot2grid((5, 6), (4, 0), rowspan=1, colspan=6)
ax_fft.plot(freqs, spectrum)
ax_fft.set_title("One Sided fft")
ax_fft.set_xlabel("Hz")
ax_fft.set_xlim(freqs[0], freqs[-1])
plt.yticks([])

# Fix spacing so TR label does overlap with other plots
allplot.subplots_adjust(hspace=0.4)
spectrum, freqs = utils.get_spectrum(component_timeseries, tr)

plot_name = f"{io_generator.prefix}comp_{str(compnum).zfill(3)}.png"
compplot_name = os.path.join(io_generator.out_dir, "figures", plot_name)
plt.savefig(compplot_name)
plt.close()

plot_component(
stat_img=component_img,
component_timeseries=component_timeseries,
power_spectrum=spectrum,
frequencies=freqs,
tr=tr,
classification_color=line_color,
png_cmap=png_cmap,
title=plt_title,
out_file=compplot_name,
)


def pca_results(criteria, n_components, all_varex, io_generator):
Expand Down

0 comments on commit 7c0b91f

Please sign in to comment.