diff --git a/src/spikeanalysis/analysis_utils/latency_functions.py b/src/spikeanalysis/analysis_utils/latency_functions.py index a9dad4e..ccaead8 100644 --- a/src/spikeanalysis/analysis_utils/latency_functions.py +++ b/src/spikeanalysis/analysis_utils/latency_functions.py @@ -16,14 +16,14 @@ def latency_core_stats(bsl_fr: float, firing_data: np.array, time_bin_size: floa ) if final_prob <= 10e-6: break - elif n_bin * time_bin_size >= 0.200: # past 200 ms is not really a true latency + elif n_bin * time_bin_size >= 0.400: # past 400 ms is not really a true latency n_bin = np.shape(firing_data)[1] - 2 break - if n_bin == np.shape(firing_data)[1] - 2: # need to go to second last bin - latency[trial] = np.nan - else: - latency[trial] = (n_bin + 1) * time_bin_size + if n_bin == np.shape(firing_data)[1] - 2: # need to go to second last bin + latency[trial] = np.nan + else: + latency[trial] = (n_bin + 1) * time_bin_size return latency @@ -37,7 +37,7 @@ def latency_median(firing_counts: np.array, time_bin_size: float): latency = np.zeros((np.shape(firing_counts)[0])) for trial in range(np.shape(firing_counts)[0]): min_spike_time = np.nonzero(firing_counts[trial])[0] - if len(min_spike_time) == 0: + if len(min_spike_time) == 0 or (np.min(min_spike_time) + 1) * time_bin_size > 0.400: latency[trial] = np.nan else: latency[trial] = (np.min(min_spike_time) + 1) * time_bin_size diff --git a/test/test_latency_functions.py b/test/test_latency_functions.py index 030886c..556cbe2 100644 --- a/test/test_latency_functions.py +++ b/test/test_latency_functions.py @@ -49,9 +49,13 @@ def test_latency_median(): ] ) test_array = np.expand_dims(test_array, axis=0) - lat = lf.latency_median(test_array, time_bin_size=1) + lat = lf.latency_median(test_array, time_bin_size=0.1) print(lat) - assert lat == [3.0] + assert round(lat[0], 2) == 0.30 + + # nan test + lat = lf.latency_median(test_array, time_bin_size=1) + assert np.isnan(lat) def test_latency_nan():