Skip to content

Commit

Permalink
extract_loudness: Fixing acids-ircam#31 - adjust how to apply a-weigh…
Browse files Browse the repository at this point in the history
…ting and make the function differentiable
  • Loading branch information
nielsrolf committed Nov 26, 2021
1 parent aaaf17d commit 6fbe4d6
Showing 1 changed file with 47 additions and 17 deletions.
64 changes: 47 additions & 17 deletions ddsp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,53 @@ def scale_function(x):
return 2 * torch.sigmoid(x)**(math.log(10)) + 1e-7


def extract_loudness(signal, sampling_rate, block_size, n_fft=2048):
S = li.stft(
signal,
n_fft=n_fft,
hop_length=block_size,
win_length=n_fft,
center=True,
)
S = np.log(abs(S) + 1e-7)
f = li.fft_frequencies(sampling_rate, n_fft)
a_weight = li.A_weighting(f)

S = S + a_weight.reshape(-1, 1)

S = np.mean(S, 0)[..., :-1]

return S
def amplitude_to_db(amplitude):
amin = 1e-20 # Avoid log(0) instabilities.
db = torch.log10(torch.clamp(amplitude, min=amin))
db *= 20.0
return db


def extract_loudness(audio, sampling_rate, block_size=None, n_fft=2048, frame_rate=None):
assert (block_size is None) != (frame_rate is None), "Specify exactly one of block_size or frame_rate"

if frame_rate is not None:
block_size = sample_rate // frame_rate
else:
frame_rate = int(sampling_rate / block_size)

if sampling_rate % frame_rate != 0:
raise ValueError(
'frame_rate: {} must evenly divide sample_rate: {}.'
'For default frame_rate: 250Hz, suggested sample_rate: 16kHz or 48kHz'
.format(frame_rate, sampling_rate))

if isinstance(audio, np.ndarray):
audio = torch.tensor(audio)

# Temporarily a batch dimension for single examples.
is_1d = (len(audio.shape) == 1)
audio = audio[None, :] if is_1d else audio

# Take STFT.
overlap = 1 - block_size / n_fft
amplitude = torch.stft(audio, n_fft=n_fft, hop_length=block_size, center=True, pad_mode='reflect', return_complex=True).abs()

# Compute power.
power_db = amplitude_to_db(amplitude)

# Perceptual weighting.
frequencies = li.fft_frequencies(sr=sampling_rate, n_fft=n_fft)
a_weighting = li.A_weighting(frequencies)[None,:,None]
loudness = power_db + a_weighting

loudness = torch.mean(torch.pow(10, loudness / 10.0), axis=1)
loudness = 10.0 * torch.log10(torch.clamp(loudness, min=1e-20))

# Remove temporary batch dimension.
loudness = loudness[0] if is_1d else loudness

return loudness


def extract_pitch(signal, sampling_rate, block_size):
Expand Down

0 comments on commit 6fbe4d6

Please sign in to comment.