From 6fbe4d6eaabcbdfa5ad3597223323cb2962b1a58 Mon Sep 17 00:00:00 2001 From: Niels Warncke Date: Fri, 26 Nov 2021 22:12:11 +0100 Subject: [PATCH] extract_loudness: Fixing #31 - adjust how to apply a-weighting and make the function differentiable --- ddsp/core.py | 64 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/ddsp/core.py b/ddsp/core.py index 6b40721..0a6b95f 100644 --- a/ddsp/core.py +++ b/ddsp/core.py @@ -77,23 +77,53 @@ def scale_function(x): return 2 * torch.sigmoid(x)**(math.log(10)) + 1e-7 -def extract_loudness(signal, sampling_rate, block_size, n_fft=2048): - S = li.stft( - signal, - n_fft=n_fft, - hop_length=block_size, - win_length=n_fft, - center=True, - ) - S = np.log(abs(S) + 1e-7) - f = li.fft_frequencies(sampling_rate, n_fft) - a_weight = li.A_weighting(f) - - S = S + a_weight.reshape(-1, 1) - - S = np.mean(S, 0)[..., :-1] - - return S +def amplitude_to_db(amplitude): + amin = 1e-20 # Avoid log(0) instabilities. + db = torch.log10(torch.clamp(amplitude, min=amin)) + db *= 20.0 + return db + + +def extract_loudness(audio, sampling_rate, block_size=None, n_fft=2048, frame_rate=None): + assert (block_size is None) != (frame_rate is None), "Specify exactly one of block_size or frame_rate" + + if frame_rate is not None: + block_size = sample_rate // frame_rate + else: + frame_rate = int(sampling_rate / block_size) + + if sampling_rate % frame_rate != 0: + raise ValueError( + 'frame_rate: {} must evenly divide sample_rate: {}.' + 'For default frame_rate: 250Hz, suggested sample_rate: 16kHz or 48kHz' + .format(frame_rate, sampling_rate)) + + if isinstance(audio, np.ndarray): + audio = torch.tensor(audio) + + # Temporarily a batch dimension for single examples. + is_1d = (len(audio.shape) == 1) + audio = audio[None, :] if is_1d else audio + + # Take STFT. + overlap = 1 - block_size / n_fft + amplitude = torch.stft(audio, n_fft=n_fft, hop_length=block_size, center=True, pad_mode='reflect', return_complex=True).abs() + + # Compute power. + power_db = amplitude_to_db(amplitude) + + # Perceptual weighting. + frequencies = li.fft_frequencies(sr=sampling_rate, n_fft=n_fft) + a_weighting = li.A_weighting(frequencies)[None,:,None] + loudness = power_db + a_weighting + + loudness = torch.mean(torch.pow(10, loudness / 10.0), axis=1) + loudness = 10.0 * torch.log10(torch.clamp(loudness, min=1e-20)) + + # Remove temporary batch dimension. + loudness = loudness[0] if is_1d else loudness + + return loudness def extract_pitch(signal, sampling_rate, block_size):