diff --git a/constants.py b/constants.py
index 6aa9caa..d98e431 100644
--- a/constants.py
+++ b/constants.py
@@ -2,7 +2,7 @@
 Constants for the project.
 """
 
-SAMPLING_RATE = 16000
+SAMPLING_RATE = 16000*2
 MAX_AUDIO_LENGTH = 30
 BATCH_SIZE = 5
 EPOCH = 10
diff --git a/inference.py b/inference.py
index 8b0b5dc..36ca3f9 100644
--- a/inference.py
+++ b/inference.py
@@ -6,6 +6,8 @@ import torchaudio
 
 from model.soundstream import SoundStream
 
+from constants import SAMPLING_RATE
+
 
 model = SoundStream(
     D=256,
@@ -18,7 +20,7 @@ def load(waveform_path):
 
     waveform, sample_rate = torchaudio.load(waveform_path)
 
-    resampler = torchaudio.transforms.Resample(sample_rate, 16000, dtype=waveform.dtype)
+    resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE, dtype=waveform.dtype)
     waveform = resampler(waveform)
 
     waveform = waveform.mean(dim=0, keepdim=True)
@@ -29,4 +31,4 @@ def load(waveform_path):
 
 audio = load("data/test.mp3")
 output = model(audio)
-torchaudio.save("data/output.wav", output[0], 16000)
+torchaudio.save("data/output.wav", output[0], SAMPLING_RATE)