diff --git a/audioenhancer/constants.py b/audioenhancer/constants.py
index a79f389..2ef161f 100644
--- a/audioenhancer/constants.py
+++ b/audioenhancer/constants.py
@@ -4,10 +4,10 @@
 SAMPLING_RATE = 44100
 MAX_AUDIO_LENGTH = 10
-BATCH_SIZE = 8
+BATCH_SIZE = 2
 EPOCH = 1
 LOGGING_STEPS = 10
-GRADIENT_ACCUMULATION_STEPS = 1
+GRADIENT_ACCUMULATION_STEPS = 3
 SAVE_STEPS = 100
 EVAL_STEPS = 100
diff --git a/audioenhancer/dataset/loader.py b/audioenhancer/dataset/loader.py
index 2f9c6f3..128e6db 100644
--- a/audioenhancer/dataset/loader.py
+++ b/audioenhancer/dataset/loader.py
@@ -8,8 +8,8 @@
 import dac
 import torch
-import torchaudio
 from audiotools import AudioSignal
+from audiotools import transforms as tfm
 from torch.utils.data import Dataset
@@ -23,6 +23,18 @@ def __init__(
         mono: bool = True,
         input_freq: int = 16000,
         output_freq: int = 16000,
+        transform: list = [
+            tfm.CorruptPhase,
+            tfm.FrequencyNoise,
+            tfm.HighPass,
+            tfm.LowPass,
+            tfm.MuLawQuantization,
+            tfm.NoiseFloor,
+            tfm.Quantization,
+            tfm.Smoothing,
+            tfm.TimeNoise,
+        ],
+        overall_prob: float = 0.5,
     ):
         """Initializes the dataset.
@@ -32,6 +44,8 @@ def __init__(
             mono (bool): Whether to load the audio as mono.
             input_freq (int): The input frequency of the audio.
             output_freq (int): The output frequency of the audio.
+            transform (list): The list of transforms to apply to the audio.
+            overall_prob (float): The overall probability of applying the transforms.
         """

         super().__init__()
@@ -58,6 +72,9 @@ def __init__(
         self.autoencoder.eval()
         self.autoencoder.requires_grad_(False)

+        prob = overall_prob / len(transform)
+        self._transform = tfm.Compose([trsfm(prob=prob) for trsfm in transform])
+
     def __len__(self) -> int:
         """Returns the number of waveforms in the dataset.
@@ -91,6 +108,9 @@ def __getitem__(self, index: int) -> tuple:
         base_waveform = base_waveform.resample(self.autoencoder.sample_rate)
         compressed_waveform = compressed_waveform.resample(self.autoencoder.sample_rate)

+        kwargs = self._transform.instantiate(signal=compressed_waveform.clone())
+        compressed_waveform = self._transform(compressed_waveform.clone(), **kwargs)
+
         compressed_waveform = compressed_waveform[:, :, : self._pad_length_input]
         base_waveform = base_waveform[:, :, : self._pad_length_output]
diff --git a/audioenhancer/model/audio_ae/model.py b/audioenhancer/model/audio_ae/model.py
index a79fc81..d7e18c4 100644
--- a/audioenhancer/model/audio_ae/model.py
+++ b/audioenhancer/model/audio_ae/model.py
@@ -165,8 +165,8 @@
     dim_out=1024,
     max_seq_len=0,
     attn_layers=Encoder(
-        dim=512,
-        depth=12,
+        dim=1024,
+        depth=16,
         heads=8,
         attn_flash=True,
         cross_attend=False,
@@ -174,5 +174,6 @@
         rotary_pos_emb=True,
         ff_swish=True,
         ff_glu=True,
+        use_scalenorm=True,
     ),
 )
diff --git a/scripts/train.py b/scripts/train.py
index 1899f72..11661aa 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -157,12 +157,6 @@ def eval_model(model, test_loader):
         y = batch[1].to(device, dtype=dtype)
         c, d = x.shape[1], x.shape[2]

-        # normalize x over the last dimension
-
-        mean_x = x.mean(dim=-1, keepdim=True)
-        std_x = x.std(dim=-1, keepdim=True)
-        x = (x - mean_x) / std_x
-
         # rearrange x and y
         x = rearrange(x, "b c d t -> b (t c) d")
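
For reference, the new `transform`/`overall_prob` arguments wire `audiotools` augmentations into `SynthDataset.__getitem__` as shown in the loader hunk above. The snippet below is a minimal standalone sketch of the same call pattern (compose per-transform probabilities, sample parameters with `instantiate`, then apply); the input file name `example.wav` and the `overall_prob` value are illustrative placeholders, not part of the change.

```python
from audiotools import AudioSignal
from audiotools import transforms as tfm

# Same transform classes the loader now defaults to.
transform_classes = [
    tfm.CorruptPhase,
    tfm.FrequencyNoise,
    tfm.HighPass,
    tfm.LowPass,
    tfm.MuLawQuantization,
    tfm.NoiseFloor,
    tfm.Quantization,
    tfm.Smoothing,
    tfm.TimeNoise,
]

overall_prob = 0.5
# Spread the overall probability budget across the individual transforms.
prob = overall_prob / len(transform_classes)
augment = tfm.Compose([trsfm(prob=prob) for trsfm in transform_classes])

signal = AudioSignal("example.wav")  # hypothetical input clip

# instantiate() samples the random parameters for this call; the composed
# transform is then applied with those parameters, as in __getitem__.
kwargs = augment.instantiate(signal=signal.clone())
corrupted = augment(signal.clone(), **kwargs)
```

Dividing `overall_prob` by the number of transforms makes the expected number of corruptions applied per clip equal to `overall_prob`, while each individual transform still fires with the smaller per-transform probability `prob`.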