From 46c9d5f5bd951550cd2bdb325ecc83ca4e9e9e0c Mon Sep 17 00:00:00 2001 From: ostix360 Date: Thu, 27 Jun 2024 21:06:05 +0200 Subject: [PATCH] fix possible bugs --- audioenhancer/constants.py | 3 +-- scripts/inference.py | 6 +++--- scripts/train.py | 9 +++------ 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/audioenhancer/constants.py b/audioenhancer/constants.py index 7dda589..a1646d5 100644 --- a/audioenhancer/constants.py +++ b/audioenhancer/constants.py @@ -3,7 +3,6 @@ """ SAMPLING_RATE = 16000 -UPSAMPLE_RATE = 32000 MAX_AUDIO_LENGTH = 10 BATCH_SIZE = 3 EPOCH = 1 @@ -12,4 +11,4 @@ SAVE_STEPS = 1000 INPUT_FREQ = 16000 -OUTPUT_FREQ = 16000 +OUTPUT_FREQ = 32000 diff --git a/scripts/inference.py b/scripts/inference.py index c672c4d..10832b7 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -7,7 +7,7 @@ import torch import torchaudio -from audioenhancer.constants import SAMPLING_RATE, MAX_AUDIO_LENGTH +from audioenhancer.constants import SAMPLING_RATE, MAX_AUDIO_LENGTH, INPUT_FREQ from audioenhancer.model.audio_ae.auto_encoder import AutoEncoder1d parser = argparse.ArgumentParser() @@ -81,7 +81,7 @@ def load(waveform_path): """ waveform, sample_rate = torchaudio.load(waveform_path) resampler = torchaudio.transforms.Resample( - sample_rate, args.sampling_rate, dtype=waveform.dtype + sample_rate, INPUT_FREQ, dtype=waveform.dtype ) waveform = resampler(waveform) if waveform.shape[0] == 1: @@ -118,5 +118,5 @@ def load(waveform_path): output = output.detach().cpu() audio = audio.squeeze(0).detach().cpu() -torchaudio.save("./data/input.mp3", audio.T, args.sampling_rate, channels_first=False) +torchaudio.save("./data/input.mp3", audio.T, SAMPLING_RATE, channels_first=False) torchaudio.save("./data/output.mp3", output.T, args.sampling_rate, channels_first=False) diff --git a/scripts/train.py b/scripts/train.py index bcb270c..9fd1c34 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -190,9 +190,8 @@ x = batch[0].to(device, dtype=dtype) y = batch[1].to(device, dtype=dtype) - up_y = batch[2].to(device, dtype=dtype) - y_hat = model(y) + y_hat = model(x) loss = sum(loss(y_hat, up_y) for loss in loss_fn) loss.backward() @@ -255,10 +254,8 @@ for batch in test_loader: x = batch[0].to(device, dtype=dtype) y = batch[1].to(device, dtype=dtype) - up_y = batch[2].to(device, dtype=dtype) - - y_hat = model(y) - loss = sum(loss(y_hat, up_y) for loss in loss_fn) + y_hat = model(x) + loss = sum(loss(y_hat, y) for loss in loss_fn) # batch_disc = torch.cat([y, y_hat], dim=0) # disc_pred = discriminator(batch_disc) # disc_pred = torch.sigmoid(disc_pred).squeeze()