diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 897c2837cc..312c3949b3 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -15,18 +15,25 @@ def test_calculate_pc_metrics(small_sorting_analyzer): res2 = pd.DataFrame(res2) for metric_name in res1.columns: - if metric_name != "nn_unit_id": - assert not np.all(np.isnan(res1[metric_name].values)) - assert not np.all(np.isnan(res2[metric_name].values)) - - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # ax.plot(res1[metric_name].values) - # ax.plot(res2[metric_name].values) - # ax.plot(res2[metric_name].values - res1[metric_name].values) - # plt.show() + values1 = res1[metric_name].values + values2 = res1[metric_name].values - np.testing.assert_almost_equal(res1[metric_name].values, res2[metric_name].values, decimal=4) + if metric_name != "nn_unit_id": + assert not np.all(np.isnan(values1)) + assert not np.all(np.isnan(values2)) + + if values1.dtype.kind == "f": + np.testing.assert_almost_equal(values1, values2, decimal=4) + # import matplotlib.pyplot as plt + # fig, axs = plt.subplots(nrows=2, share=True) + # ax =a xs[0] + # ax.plot(res1[metric_name].values) + # ax.plot(res2[metric_name].values) + # ax =a xs[1] + # ax.plot(res2[metric_name].values - res1[metric_name].values) + # plt.show() + else: + assert np.array_equal(values1, values2) def test_pca_metrics_multi_processing(small_sorting_analyzer):