[update] use only l1 loss
Jourdelune committed Jul 5, 2024
1 parent d35f3e3 commit 08486de
Showing 2 changed files with 62 additions and 14 deletions.
39 changes: 38 additions & 1 deletion audioenhancer/inference.py
@@ -11,6 +11,41 @@

 from audioenhancer.model.audio_ae.model import model_xtransformer as model
 
+import librosa
+import numpy as np
+from scipy.signal import butter, filtfilt
+import soundfile as sf
+
+
+def remove_noise(audio_path, output_path):
+    # Load the audio file
+    y, sr = librosa.load(audio_path)
+
+    # Compute the RMS (root mean square) level of the original signal
+    rms_original = np.sqrt(np.mean(y**2))
+
+    # Low-pass filter parameters
+    cutoff = 7000  # cutoff frequency of 7000 Hz
+    order = 2  # filter order of 2
+    nyquist = 0.5 * sr
+    normal_cutoff = cutoff / nyquist
+
+    # Build the low-pass filter
+    b, a = butter(order, normal_cutoff, btype="low", analog=False)
+
+    # Apply the filter (filtfilt runs forward and backward, so zero phase)
+    y_filtered = filtfilt(b, a, y)
+
+    # Blend the original and filtered signals
+    y_output = 0.85 * y + 0.15 * y_filtered  # 85% original, 15% filtered
+
+    # Normalize the output back to the original RMS level
+    rms_output = np.sqrt(np.mean(y_output**2))
+    y_output = y_output * (rms_original / rms_output)
+
+    # Save the result
+    sf.write(output_path, y_output, sr)


 class Inference:
     def __init__(self, model_path: str, sampling_rate: int):
@@ -112,7 +147,9 @@ def inference(self, audio_path: str, chunk_duration: int = 3):
             channels_first=False,
         )
         torchaudio.save(
-            "./data/output.mp3", output.T, self._sampling_rate, channels_first=False
+            "./data/output1.mp3", output.T, self._sampling_rate, channels_first=False
         )
 
+        # remove_noise("./data/output1.mp3", "./data/output.mp3")
+
         return os.path.abspath("./data/output.mp3")
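Note on this file: `remove_noise` is added but its call in `inference` is left commented out, so the method now saves `./data/output1.mp3` while still returning the path of `./data/output.mp3`, which is only produced once the post-processing is re-enabled. A minimal sketch of calling the helper directly (import path taken from this file; illustration only):

from audioenhancer.inference import remove_noise

# Blend 15% of a 7 kHz low-passed copy into the model output to soften
# high-frequency artifacts, writing the result to the second path.
remove_noise("./data/output1.mp3", "./data/output.mp3")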
37 changes: 24 additions & 13 deletions scripts/train.py
@@ -5,10 +5,11 @@
 import argparse
 import os
 
-import schedulefree
 import auraloss
 import bitsandbytes as bnb
+import schedulefree
 import torch
+import torch.nn as nn
 from audiotools import AudioSignal
 from einops import rearrange
 from torch.nn import MSELoss
@@ -94,7 +95,7 @@
 #     return MSELoss()(logits, target) / 10
 
 loss_fn = [MSELoss()]
-losses = [L1Loss()]
+losses = [nn.L1Loss()]
 
 # split test and train
 test_size = min(len(dataset) * 0.1, EVAL_STEPS)
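`nn.L1Loss()` is PyTorch's mean absolute error criterion; the old `L1Loss()` appears to reference a name this module never imports (only `MSELoss` is pulled from `torch.nn` above), which the new `import torch.nn as nn` resolves. A quick illustration with made-up tensors:

import torch
import torch.nn as nn

criterion = nn.L1Loss()  # mean absolute error, reduction="mean" by default
pred = torch.tensor([0.5, -1.0, 2.0])
target = torch.tensor([0.0, -1.0, 1.0])
print(criterion(pred, target))  # (0.5 + 0.0 + 1.0) / 3 = tensor(0.5000)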
@@ -181,7 +182,7 @@ def eval_model(model, test_loader):
         y_hat = model(x, mask=None)
 
         y_hat = rearrange(y_hat, "b (t c) d -> b c d t", c=c, d=d)
-        loss = 0
+        loss_total = 0
         for i in range(y_hat.shape[0]):
             with torch.no_grad():
                 z_q, _, _, _, _ = dataset.autoencoder.quantizer(y_hat[i].float(), None)
@@ -191,11 +192,16 @@
             y_signal = AudioSignal(
                 base_waveform[i][:, :, : decoded.shape[-1]], sample_rate=SAMPLING_RATE
             )
-            loss += sum(loss(signal, y_signal) for loss in losses)
-
-        loss /= y_hat.shape[0]
-        loss = loss + sum([loss_fn[i](y_hat, y) for i in range(len(loss_fn))]) * 0.2
-        loss_test += loss.detach().cpu().float().numpy()
+            for loss in losses:
+                out = loss(signal.audio_data.cpu(), y_signal.audio_data.cpu())
+                loss_total += out
+
+        loss_total /= y_hat.shape[0]
+        loss_total = (
+            loss_total + sum([loss_fn[i](y_hat, y) for i in range(len(loss_fn))]) * 0
+        )
+        loss_test += loss_total.detach().cpu().float().numpy()
 
     return loss_test / len(test_loader)
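Here, and again in the training loop below, each predicted latent is quantized, decoded back to audio, and compared to the reference waveform with the L1 criteria in `losses` (on CPU, via `audio_data`), while the latent-space `loss_fn` (MSE) term is now scaled by 0 instead of 0.2, so only the waveform L1 contributes, per the commit title. A condensed sketch of that per-batch reduction (names are illustrative, not the exact implementation):

def batch_waveform_l1(signals, y_signals, criteria):
    # Average the waveform-domain losses over the batch; each decoded /
    # reference pair is compared on CPU, as in the loop above.
    total = 0.0
    for signal, y_signal in zip(signals, y_signals):
        for criterion in criteria:
            total += criterion(signal.audio_data.cpu(), y_signal.audio_data.cpu())
    return total / len(signals)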

@@ -221,7 +227,7 @@ def eval_model(model, test_loader):

         y_hat = rearrange(y_hat, "b (t c) d -> b c d t", c=c, d=d)
 
-        loss = 0
+        loss_total = 0
         for i in range(y_hat.shape[0]):
             with torch.no_grad():
                 z_q, _, _, _, _ = dataset.autoencoder.quantizer(y_hat[i].float(), None)
@@ -231,11 +237,16 @@
             y_signal = AudioSignal(
                 base_waveform[i][:, :, : decoded.shape[-1]], sample_rate=SAMPLING_RATE
             )
-            loss += sum(loss(signal, y_signal) for loss in losses)
-
-        loss /= y_hat.shape[0]
-        loss = loss + sum([loss_fn[i](y_hat, y) for i in range(len(loss_fn))]) * 0.2
-        loss.backward()
+            for loss in losses:
+                out = loss(signal.audio_data.cpu(), y_signal.audio_data.cpu())
+                loss_total += out
+
+        loss_total /= y_hat.shape[0]
+        loss_total = (
+            loss_total + sum([loss_fn[i](y_hat, y) for i in range(len(loss_fn))]) * 0
+        )
+        loss_total.backward()
 
         # batch_disc = torch.cat([y, y_hat], dim=0)
         # disc_pred = discriminator(batch_disc)
@@ -246,7 +257,7 @@
         # )
         # logging_desc_loss += disc_loss.detach().cpu().float().numpy()
 
-        logging_loss += loss.detach().cpu().float().numpy()
+        logging_loss += loss_total.detach().cpu().float().numpy()
         # loss += disc_pred[: -y.shape[0]].mean().squeeze()
 
         if (step % GRADIENT_ACCUMULATION_STEPS) == 0:
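`loss_total.backward()` runs on every batch, while the optimizer step (collapsed out of this diff) fires only when `step % GRADIENT_ACCUMULATION_STEPS == 0`. A generic sketch of that accumulation pattern, assuming a standard `step()`/`zero_grad()` pairing since the actual body is not shown (`train_loader`, `compute_loss`, and `optimizer` are placeholders):

for step, (x, y) in enumerate(train_loader, start=1):
    loss = compute_loss(model(x), y)
    loss.backward()  # gradients accumulate across successive backward() calls
    if step % GRADIENT_ACCUMULATION_STEPS == 0:
        optimizer.step()       # apply the accumulated gradients
        optimizer.zero_grad()  # reset for the next accumulation window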

0 comments on commit 08486de
