diff --git a/museval/metrics.py b/museval/metrics.py index d076dc9..ee30a38 100644 --- a/museval/metrics.py +++ b/museval/metrics.py @@ -57,7 +57,7 @@ MAX_SOURCES = 100 -def validate(reference_sources, estimated_sources): +def validate(reference_sources, estimated_sources, compute_permutation): """Checks that the input data to a metric are valid, and throws helpful errors if not. @@ -91,13 +91,22 @@ def validate(reference_sources, estimated_sources): "will all be empty np.ndarrays" ) elif _any_source_silent(reference_sources): - raise ValueError( - "All the reference sources should be non-silent (not " - "all-zeros), but at least one of the reference " - "sources is all 0s, which introduces ambiguity to the" - " evaluation. (Otherwise we can add infinitely many " - "all-zero sources.)" - ) + if compute_permutation: + raise ValueError( + "When the estimates are not labeled and we need to" + "compute the optimum permutation, all the reference " + "sources should be non-silent (not all-zeros), but at " + "least one of the reference sources is all 0s, which " + "introduces ambiguity to the evaluation. (Otherwise we " + "can add infinitely many all-zero sources.)" + ) + else: + warnings.warn( + "At least one of the reference sources is all 0s. This " + "will generate NaN values in the metrics of the silent " + "sources and this track will not be taken into account " + "when computing the aggregated metrics of these sources." + ) if estimated_sources.size == 0: warnings.warn( @@ -106,13 +115,24 @@ def validate(reference_sources, estimated_sources): "will all be empty np.ndarrays" ) elif _any_source_silent(estimated_sources): - raise ValueError( - "All the estimated sources should be non-silent (not " - "all-zeros), but at least one of the estimated " - "sources is all 0s. Since we require each reference " - "source to be non-silent, having a silent estimated " - "source will result in an underdetermined system." - ) + if compute_permutation: + raise ValueError( + "When the estimates are not labeled and we need to" + "compute the optimum permutation, all the estimated " + "sources should be non-silent (not all-zeros), but at " + "least one of the estimated sources is all 0s. Since " + "we require each reference source to be non-silent, " + "having a silent estimated source will result in an " + "underdetermined system." + ) + else: + warnings.warn( + "At least one of the estimated sources is all 0s. This " + "might generate NaN values in the metrics of the silent " + "sources in which case this track will not be taken into " + "account when computing the aggregated metrics of these " + "sources." + ) if ( estimated_sources.shape[0] > MAX_SOURCES @@ -240,7 +260,7 @@ def bss_eval( reference_sources = np.atleast_3d(reference_sources) # validate input - validate(reference_sources, estimated_sources) + validate(reference_sources, estimated_sources, compute_permutation) # If empty matrices were supplied, return empty lists (special case) if reference_sources.size == 0 or estimated_sources.size == 0: @@ -301,26 +321,21 @@ def compute_Cj(win=slice(0, nsampl)): ref_slice = reference_sources[:, win] est_slice = estimated_sources[:, win] - if not _any_source_silent(ref_slice) and not _any_source_silent(est_slice): - for jtrue in range(nsrc): - for k, jest in enumerate(candidate_permutations[:, jtrue]): - # if we have a silent frame set results as np.nan - if not done[jtrue, jest]: - s_true, e_spat, e_interf, e_artif = _bss_decomp_mtifilt( - reference_sources[:, win], - estimated_sources[jest, win], - jtrue, - C[jest], - Cj[jtrue, jest, 0], - ) - s_r[:, jtrue, jest, t] = _bss_crit( - s_true, e_spat, e_interf, e_artif, bsseval_sources_version - ) - done[jtrue, jest] = True - else: - a = np.empty((4, nsrc, nsrc)) - a[:] = np.nan - s_r[:, :, :, t] = a + for jtrue in range(nsrc): + for k, jest in enumerate(candidate_permutations[:, jtrue]): + # if we have a silent frame set results as np.nan + if not done[jtrue, jest]: + s_true, e_spat, e_interf, e_artif = _bss_decomp_mtifilt( + reference_sources[:, win], + estimated_sources[jest, win], + jtrue, + C[jest], + Cj[jtrue, jest, 0], + ) + s_r[:, jtrue, jest, t] = _bss_crit( + s_true, e_spat, e_interf, e_artif, bsseval_sources_version + ) + done[jtrue, jest] = True # select the best ordering if framewise_filters: @@ -333,9 +348,12 @@ def compute_Cj(win=slice(0, nsampl)): mean_sir = np.empty((len(candidate_permutations), 1)) axis_mean = None dum = np.arange(nsrc) - for i, perm in enumerate(candidate_permutations): - mean_sir[i] = np.mean(s_r[SIR, dum, perm, :], axis=axis_mean) - popt = candidate_permutations[np.argmax(mean_sir, axis=0)].T + if compute_permutation: + for i, perm in enumerate(candidate_permutations): + mean_sir[i] = np.mean(s_r[SIR, dum, perm, :], axis=axis_mean) + popt = candidate_permutations[np.argmax(mean_sir, axis=0)].T + else: + popt = candidate_permutations[[0]].T # now prepare the output if not framewise_filters: @@ -657,9 +675,13 @@ def _bss_crit(s_true, e_spat, e_interf, e_artif, bsseval_sources_version): def _safe_db(num, den): - """Properly handle the potential +Inf db SIR instead of raising a + """Properly handle the potential +Inf, -Inf, and NaN db SIR instead of raising a RuntimeWarning. """ + if den == 0 and num == 0: + return np.float64(np.NaN) if den == 0: - return np.inf + return np.float64(np.inf) + if num == 0: + return np.float64(- np.inf) return 10 * np.log10(num / den) diff --git a/tests/test_bsseval.py b/tests/test_bsseval.py index 4a4bb7c..313fb6a 100644 --- a/tests/test_bsseval.py +++ b/tests/test_bsseval.py @@ -77,7 +77,7 @@ def test_empty_input(is_framewise, is_sources, nb_win, nb_hop): def test_silent_input(references, estimates, is_framewise, is_sources, nb_win, nb_hop): estimates = np.zeros(references.shape) - with pytest.raises(ValueError): + with pytest.warns(UserWarning): metrics.bss_eval( references, estimates,