Skip to content

Commit

Permalink
set_tstop_oct25
Browse files Browse the repository at this point in the history
  • Loading branch information
tianqi-cheng committed Oct 25, 2023
1 parent 45215b2 commit a821c26
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 38 deletions.
11 changes: 3 additions & 8 deletions hnn_core/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,16 +467,11 @@ def savgol_filter(self, h_freq):
self.sfreq)
return self

def plot(self, tmin=None, tmax=None, layer='agg', decim=None, ax=None,
color='k', show=True):
def plot(self, layer='agg', decim=None, ax=None, color='k', show=True):
"""Simple layer-specific plot function.
Parameters
----------
tmin : float or None
Start time of plot (in ms). If None, plot entire simulation.
tmax : float or None
End time of plot (in ms). If None, plot entire simulation.
layer : str
The layer to plot. Can be one of 'agg', 'L2', and 'L5'
decimate : int
Expand All @@ -493,8 +488,8 @@ def plot(self, tmin=None, tmax=None, layer='agg', decim=None, ax=None,
fig : instance of plt.fig
The matplotlib figure handle.
"""
return plot_dipole(self, tmin=tmin, tmax=tmax, ax=ax, layer=layer,
decim=decim, color=color, show=show)
return plot_dipole(self, ax=ax, layer=layer, decim=decim, color=color,
show=show)

def plot_psd(self, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg',
color=None, label=None, ax=None, show=True):
Expand Down
14 changes: 5 additions & 9 deletions hnn_core/extracellular.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,9 @@ def smooth(self, window_len):

return self

def plot_lfp(self, *, trial_no=None, contact_no=None, tmin=None, tmax=None,
ax=None, decim=None, color='cividis', voltage_offset=50,
voltage_scalebar=200, show=True):
def plot_lfp(self, *, trial_no=None, contact_no=None, ax=None, decim=None,
color='cividis', voltage_offset=50, voltage_scalebar=200,
show=True):
"""Plot laminar local field potential time series.
One plot is created for each trial. Multiple trials can be overlaid
Expand All @@ -435,11 +435,6 @@ def plot_lfp(self, *, trial_no=None, contact_no=None, tmin=None, tmax=None,
Trial number(s) to plot
contact_no : int | list of int | slice
Electrode contact number(s) to plot
tmin : float | None
Start time of plot in milliseconds. If None, plot entire
simulation.
tmax : float | None
End time of plot in milliseconds. If None, plot entire simulation.
ax : instance of matplotlib figure | None
The matplotlib axis
decim : int | list of int | None (default)
Expand Down Expand Up @@ -486,7 +481,7 @@ def plot_lfp(self, *, trial_no=None, contact_no=None, tmin=None, tmax=None,

for trial_data in plot_data:
fig = plot_laminar_lfp(
self.times, trial_data, tmin=tmin, tmax=tmax, ax=ax,
self.times, trial_data, ax=ax,
decim=decim, color=color,
voltage_offset=voltage_offset,
voltage_scalebar=voltage_scalebar,
Expand Down Expand Up @@ -534,6 +529,7 @@ class _ExtracellularArrayBuilder(object):
The instance of :class:`hnn_core.extracellular.ExtracellularArray` to
build in NEURON-Python
"""

def __init__(self, array):
self.array = array
self.n_contacts = array.n_contacts
Expand Down
37 changes: 16 additions & 21 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def plt_show(show=True, fig=None, **kwargs):
(fig or plt).show(**kwargs)


def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
ax=None, decim=None, color='cividis',
voltage_offset=50, voltage_scalebar=200, show=True):
def plot_laminar_lfp(times, data, contact_labels, ax=None, decim=None,
color='cividis', voltage_offset=50, voltage_scalebar=200,
show=True):
"""Plot laminar extracellular electrode array voltage time series.
Parameters
Expand All @@ -89,10 +89,6 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
Sampling times (in ms).
data : Two-dimensional Numpy array
The extracellular voltages as an (n_contacts, n_times) array.
tmin : float | None
Start time of plot in milliseconds. If None, plot entire simulation.
tmax : float | None
End time of plot in milliseconds. If None, plot entire simulation.
ax : instance of matplotlib figure | None
The matplotlib axis
decim : int | list of int | None (default)
Expand Down Expand Up @@ -168,11 +164,11 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
trace_offsets = np.arange(n_offsets)[:, np.newaxis] * voltage_offset

for contact_no, trace in enumerate(np.atleast_2d(data)):
plot_data, plot_times = _get_plot_data_trange(times, trace, tmin, tmax)
plot_data = trace
plot_times = times

if decim is not None:
plot_data, plot_times = _decimate_plot_data(decim, plot_data,
plot_times)
plot_data, plot_times = _decimate_plot_data(decim, trace, times)

if isinstance(color, np.ndarray):
col = color[contact_no]
Expand All @@ -182,6 +178,7 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
col = color
ax.plot(plot_times, plot_data + trace_offsets[contact_no],
label=f'C{contact_no}', color=col)
ax.set_xlim(right=plot_times[-1])

if voltage_offset is not None:
ax.set_ylim(-voltage_offset, n_offsets * voltage_offset)
Expand Down Expand Up @@ -220,18 +217,14 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
return ax.get_figure()


def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
def plot_dipole(dpl, ax=None, layer='agg', decim=None,
color='k', label="average", average=False, show=True):
"""Simple layer-specific plot function.
Parameters
----------
dpl : instance of Dipole | list of Dipole instances
The Dipole object.
tmin : float or None
Start time of plot in milliseconds. If None, plot entire simulation.
tmax : float or None
End time of plot in milliseconds. If None, plot entire simulation.
ax : instance of matplotlib figure | None
The matplotlib axis
layer : str
Expand Down Expand Up @@ -288,19 +281,19 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
if layer in dpl_trial.data.keys():

# extract scaled data and times
data, times = _get_plot_data_trange(dpl_trial.times,
dpl_trial.data[layer],
tmin, tmax)
data = dpl_trial.data[layer]
times = dpl_trial.times
if decim is not None:
data, times = _decimate_plot_data(decim, data, times)
data, times = _decimate_plot_data(
decim, dpl_trial.data[layer], dpl_trial.times)
if idx == len(dpl) - 1 and average:
# the average dpl
ax.plot(times, data, color=color, label=label, lw=1.5)
else:
alpha = 0.5 if average else 1.
ax.plot(times, data, color=_lighten_color(color, 0.5),
alpha=alpha, lw=1.)

ax.set_xlim(right=dpl_trial.times[-1])
if average:
ax.legend()

Expand Down Expand Up @@ -477,6 +470,7 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,

ax.set_ylabel("Counts")
ax.legend()
ax.set_xlim(right=cell_response.times[-1])

plt_show(show)
return ax.get_figure()
Expand Down Expand Up @@ -558,7 +552,7 @@ def plot_spikes_raster(cell_response, trial_idx=None, tmin=None, tmax=None,
ax.set_facecolor('k')
ax.set_xlabel('Time (ms)')
ax.get_yaxis().set_visible(False)
ax.set_xlim(left=0)
ax.set_xlim(left=0, right=cell_response.times[-1])

plt_show(show)
return ax.get_figure()
Expand Down Expand Up @@ -1229,6 +1223,7 @@ def plot_laminar_csd(times, data, contact_labels, tmin=None, tmax=None,

ax.set_xlabel('Time (ms)')
ax.set_ylabel('Electrode depth')
ax.set_xlim(right=times[-1])
plt.tight_layout()
plt_show(show)

Expand Down

0 comments on commit a821c26

Please sign in to comment.