From 5446afc62e1e95c4a4f6dbc701190e963f762aca Mon Sep 17 00:00:00 2001
From: Jourdelune
Date: Thu, 4 Jul 2024 10:37:07 +0200
Subject: [PATCH] [update] switch to diffusion encodec

---
 .gitignore                            |  4 ++-
 audioenhancer/constants.py            |  2 +-
 audioenhancer/dataset/loader.py       | 43 +++++++++++++--------------
 audioenhancer/inference.py            | 42 ++++++++++++--------------
 audioenhancer/model/audio_ae/model.py | 25 +++++-----------
 scripts/train.py                      | 10 +------
 6 files changed, 52 insertions(+), 74 deletions(-)

diff --git a/.gitignore b/.gitignore
index f98bf2b..4723027 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,4 +165,6 @@ test.py
 data/output.*
 *.pth
 runs/
-data/input.mp3
\ No newline at end of file
+data/input.mp3
+
+flagged
diff --git a/audioenhancer/constants.py b/audioenhancer/constants.py
index f5129d6..2b0b0bf 100644
--- a/audioenhancer/constants.py
+++ b/audioenhancer/constants.py
@@ -4,7 +4,7 @@
 SAMPLING_RATE = 44100
 MAX_AUDIO_LENGTH = 10
 
-BATCH_SIZE = 2
+BATCH_SIZE = 4
 EPOCH = 1
 LOGGING_STEPS = 10
 GRADIENT_ACCUMULATION_STEPS = 2
diff --git a/audioenhancer/dataset/loader.py b/audioenhancer/dataset/loader.py
index e3a13fd..acff97f 100644
--- a/audioenhancer/dataset/loader.py
+++ b/audioenhancer/dataset/loader.py
@@ -7,9 +7,8 @@
 import os
 import random
 
-import dac
 import torch
-import torchaudio
+from audiocraft.models import MultiBandDiffusion
 from audiotools import AudioSignal
 from audiotools import transforms as tfm
 from torch.utils.data import Dataset
@@ -69,12 +68,9 @@ def __init__(
         self._input_freq = input_freq
         self._output_freq = output_freq
 
-        model_path = dac.utils.download(model_type="44khz")
-        self.autoencoder = dac.DAC.load(model_path).to("cuda")
-        self.autoencoder.eval()
-        self.autoencoder.requires_grad_(False)
+        self.autoencoder = MultiBandDiffusion.get_mbd_musicgen()
 
-        self._prob = overall_prob / len(transform)
+        self._prob = overall_prob / (len(transform) + 1)
 
         self._transform = tfm.Compose([trsfm(prob=self._prob) for trsfm in transform])
         self._transform2 = tfm.Compose(
@@ -126,13 +122,8 @@ def __getitem__(self, index: int) -> tuple:
             compressed_waveform.clone(), **kwargs
         )
 
-        compressed_waveform = self.autoencoder.preprocess(
-            compressed_waveform.audio_data, compressed_waveform.sample_rate
-        )
-
-        base_waveform = self.autoencoder.preprocess(
-            base_waveform.audio_data, base_waveform.sample_rate
-        )
+        compressed_waveform = compressed_waveform.audio_data
+        base_waveform = base_waveform.audio_data
 
         if base_waveform.shape[-1] < self._pad_length_output:
             base_waveform = torch.nn.functional.pad(
@@ -158,16 +149,22 @@ def __getitem__(self, index: int) -> tuple:
         if base_waveform.shape[0] == 1:
             base_waveform = base_waveform.repeat(2, 1, 1)
 
-        encoded_compressed_waveform, _, _, _, _, _ = self.autoencoder.encode(
-            compressed_waveform
-        )
+        if random.random() < self._prob:
+            strength = torch.rand(compressed_waveform.shape[:2]) * 0.01
+            strength_expanded = (
+                strength.unsqueeze(2)
+                .expand(-1, -1, compressed_waveform.shape[2])
+                .cuda()
+            )
+            noise = torch.randn_like(compressed_waveform).cuda()
+            compressed_waveform = compressed_waveform + noise * strength_expanded
 
-        encoded_base_waveform, _, codes, _, _, _ = self.autoencoder.encode(
-            base_waveform
+        encoded_compressed_waveform = self.autoencoder.get_condition(
+            compressed_waveform, sample_rate=self.autoencoder.sample_rate
         )
 
-        return (
-            encoded_compressed_waveform,
-            encoded_base_waveform,
-            codes,
+        encoded_base_waveform = self.autoencoder.get_condition(
+            base_waveform, sample_rate=self.autoencoder.sample_rate
         )
+
+        return (encoded_compressed_waveform, encoded_base_waveform)
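
Note: SynthDataset now yields a pair of MultiBandDiffusion conditioning latents, (encoded_compressed_waveform, encoded_base_waveform), instead of DAC latents plus codes. A minimal consumption sketch; the SynthDataset constructor arguments below are placeholders (the real call is in scripts/train.py), and a single-process DataLoader is assumed because the dataset keeps the codec on the GPU:

    from torch.utils.data import DataLoader

    from audioenhancer.constants import BATCH_SIZE
    from audioenhancer.dataset.loader import SynthDataset

    # Placeholder arguments; see scripts/train.py for the actual construction.
    dataset = SynthDataset("./data")
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

    # Each item is (degraded latents, clean latents) for the same clip; the
    # transformer is trained to map the first onto the second.
    compressed_latents, clean_latents = next(iter(loader))
    print(compressed_latents.shape, clean_latents.shape)
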
diff --git a/audioenhancer/inference.py b/audioenhancer/inference.py
index 4f8093c..900ed8c 100644
--- a/audioenhancer/inference.py
+++ b/audioenhancer/inference.py
@@ -4,12 +4,12 @@
 
 import os
 
-import dac
 import torch
 import torchaudio
 from einops import rearrange
 
 from audioenhancer.model.audio_ae.model import model_xtransformer_small as model
+from audiocraft.models import MultiBandDiffusion
 
 
 class Inference:
@@ -22,8 +22,7 @@ def __init__(self, model_path: str, sampling_rate: int):
 
         self._sampling_rate = sampling_rate
 
-        autoencoder_path = dac.utils.download(model_type="44khz")
-        self._autoencoder = dac.DAC.load(autoencoder_path).to(self.device)
+        self.autoencoder = MultiBandDiffusion.get_mbd_musicgen()
 
     def load(self, waveform_path):
         """
@@ -45,6 +44,7 @@ def load(self, waveform_path):
 
         return waveform.to(self.device)
 
+    @torch.no_grad()
     def inference(self, audio_path: str, chunk_duration: int = 10):
         """Run inference on the given audio file.
 
@@ -68,35 +68,31 @@ def inference(self, audio_path: str, chunk_duration: int = 10):
                 0,
             )
 
-            with torch.no_grad():
-                encoded, encoded_q, _, _, _, _ = self._autoencoder.encode(
-                    chunk.transpose(0, 1)
-                )
-
-                # create input for the model
-                decoded = self._autoencoder.decode(encoded_q)
+            encoded = self.autoencoder.get_condition(
+                chunk.transpose(0, 1), sample_rate=self._sampling_rate
+            )
 
-                decoded = decoded.transpose(0, 1)
+            # create input for the model
+            decoded = self.autoencoder.generate(encoded)
 
-                ae_input = torch.cat([ae_input, decoded], dim=2)
+            decoded = decoded.transpose(0, 1)
 
-                encoded = encoded.unsqueeze(0)
-                c, d = encoded.shape[1], encoded.shape[2]
-                encoded = rearrange(encoded, "b c d t -> b (t c) d")
+            ae_input = torch.cat([ae_input, decoded], dim=2)
 
-                pred = self.model(encoded)
+            encoded = encoded.unsqueeze(0)
+            c, d = encoded.shape[1], encoded.shape[2]
+            encoded = rearrange(encoded, "b c d t -> b (t c) d")
 
-                pred = rearrange(pred, "b (t c) d -> b c d t", c=c, d=d)
-                pred = pred.squeeze(0)
+            pred = self.model(encoded)
 
-                # quantize
-                z_q, _, _, _, _ = self._autoencoder.quantizer(pred, None)
+            pred = rearrange(pred, "b (t c) d -> b c d t", c=c, d=d)
+            pred = pred.squeeze(0)
 
-                decoded = self._autoencoder.decode(z_q)
+            decoded = self.autoencoder.generate(pred)
 
-                decoded = decoded.transpose(0, 1)
+            decoded = decoded.transpose(0, 1)
 
-                output = torch.cat([output, decoded], dim=2)
+            output = torch.cat([output, decoded], dim=2)
 
         # fix runtime error: numpy
         output = output.squeeze(0).detach().cpu()
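
Note: the rearrange bookkeeping around the transformer call is the easiest part of this hunk to misread. A self-contained sketch with made-up sizes, showing that the round trip restores the latent layout the decoder expects:

    import torch
    from einops import rearrange

    # Illustrative sizes only: (batch, channels, latent_dim, frames) after unsqueeze(0).
    encoded = torch.randn(1, 2, 128, 250)
    b, c, d, t = encoded.shape

    # Fold frames and channels into a single sequence axis for the transformer.
    tokens = rearrange(encoded, "b c d t -> b (t c) d")
    assert tokens.shape == (b, t * c, d)

    # Invert the rearrangement on the model output to restore the latent layout.
    restored = rearrange(tokens, "b (t c) d -> b c d t", c=c, d=d)
    assert restored.shape == encoded.shape
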
diff --git a/audioenhancer/model/audio_ae/model.py b/audioenhancer/model/audio_ae/model.py
index 7a15d75..07916d8 100644
--- a/audioenhancer/model/audio_ae/model.py
+++ b/audioenhancer/model/audio_ae/model.py
@@ -15,7 +15,6 @@
 from auraloss.freq import MultiResolutionSTFTLoss
 from x_transformers import ContinuousTransformerWrapper, Decoder, Encoder
 
-from audioenhancer.model.audio_ae.latent import LatentProcessor
 from audioenhancer.model.audio_ae.vdiffusion import CustomVDiffusion
 
 model = DiffusionModel(
@@ -162,11 +161,11 @@
 )
 
 model_xtransformer = ContinuousTransformerWrapper(
-    dim_in=1024,
-    dim_out=1024,
+    dim_in=128,
+    dim_out=128,
     max_seq_len=0,
     attn_layers=Encoder(
-        dim=1024,
+        dim=128,
         depth=16,
         heads=8,
         attn_flash=True,
@@ -180,13 +179,13 @@
 )
 
 model_xtransformer_small = ContinuousTransformerWrapper(
-    dim_in=1024,
-    dim_out=1024,
+    dim_in=128,
+    dim_out=128,
     max_seq_len=0,
     attn_layers=Encoder(
-        dim=1024,
-        depth=12,
-        heads=8,
+        dim=512,
+        depth=18,
+        heads=18,
         attn_flash=True,
         cross_attend=False,
         zero_init_branch_output=True,
@@ -196,11 +195,3 @@
         use_scalenorm=True,
     ),
 )
-
-mamba_model = LatentProcessor(
-    in_dim=1024,
-    out_dim=1024,
-    num_code_book=9,
-    latent_dim=2048,
-    num_layer=8,
-)
diff --git a/scripts/train.py b/scripts/train.py
index d1a8100..3a059a9 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -59,7 +59,7 @@
 
 args = parser.parse_args()
 
-dtype = torch.float32
+dtype = torch.bfloat16
 
 # Load the dataset
 dataset = SynthDataset(
@@ -129,7 +129,6 @@
     ],
     lr=1e-4,
     betas=(0.95, 0.999),
-    eps=1e-6,
     weight_decay=1e-3,
 )
 
@@ -158,13 +157,6 @@
 # print number of parameters
 print(f"Number of parameters: {sum(p.numel() for p in model.parameters()) / 1e6}M")
 
-import dac
-
-autoencoder_path = dac.utils.download(model_type="44khz")
-autoencoder = dac.DAC.load(autoencoder_path).to("cuda")
-autoencoder.eval()
-autoencoder.requires_grad_(False)
-
 
 def eval_model(model, test_loader):
     model.eval()
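
Note: train.py now sets dtype = torch.bfloat16, but this diff does not show where the dtype is applied. A minimal sketch of one common pattern, bf16 autocast around the forward pass with the same optimizer hyperparameters as above; the Linear stand-in model and the CPU device are assumptions made only to keep the example runnable:

    import torch

    dtype = torch.bfloat16

    model = torch.nn.Linear(128, 128)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=1e-4, betas=(0.95, 0.999), weight_decay=1e-3
    )

    x = torch.randn(4, 128)
    target = torch.randn(4, 128)

    # The forward pass runs in bfloat16 while the optimizer keeps float32 weights;
    # unlike float16, bf16 autocast needs no GradScaler.
    with torch.autocast(device_type="cpu", dtype=dtype):
        loss = torch.nn.functional.mse_loss(model(x), target)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
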