Skip to content

Commit

Permalink
[update] add ds transformation
Browse files: browse the repository at this point in the history
  • Loading branch information
Jourdelune committed Jul 2, 2024
1 parent 1ed359b commit fd55a39
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
4 changes: 2 additions & 2 deletions audioenhancer/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

SAMPLING_RATE = 44100
MAX_AUDIO_LENGTH = 10
BATCH_SIZE = 8
BATCH_SIZE = 2
EPOCH = 1
LOGGING_STEPS = 10
GRADIENT_ACCUMULATION_STEPS = 1
GRADIENT_ACCUMULATION_STEPS = 3
SAVE_STEPS = 100
EVAL_STEPS = 100

Expand Down
22 changes: 21 additions & 1 deletion audioenhancer/dataset/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

import dac
import torch
import torchaudio
from audiotools import AudioSignal
from audiotools import transforms as tfm
from torch.utils.data import Dataset


Expand All @@ -23,6 +23,18 @@ def __init__(
mono: bool = True,
input_freq: int = 16000,
output_freq: int = 16000,
transform: list = [
tfm.CorruptPhase,
tfm.FrequencyNoise,
tfm.HighPass,
tfm.LowPass,
tfm.MuLawQuantization,
tfm.NoiseFloor,
tfm.Quantization,
tfm.Smoothing,
tfm.TimeNoise,
],
overall_prob: float = 0.5,
):
"""Initializes the dataset.
Expand All @@ -32,6 +44,8 @@ def __init__(
mono (bool): Whether to load the audio as mono.
input_freq (int): The input frequency of the audio.
output_freq (int): The output frequency of the audio.
transform (list): The list of transforms to apply to the audio.
overall_prob (float): The overall probability of applying the transforms.
"""

super().__init__()
Expand All @@ -58,6 +72,9 @@ def __init__(
self.autoencoder.eval()
self.autoencoder.requires_grad_(False)

prob = overall_prob / len(transform)
self._transform = tfm.Compose([trsfm(prob=prob) for trsfm in transform])

def __len__(self) -> int:
"""Returns the number of waveforms in the dataset.
Expand Down Expand Up @@ -91,6 +108,9 @@ def __getitem__(self, index: int) -> tuple:
base_waveform = base_waveform.resample(self.autoencoder.sample_rate)
compressed_waveform = compressed_waveform.resample(self.autoencoder.sample_rate)

kwargs = self._transform.instantiate(signal=compressed_waveform.clone())
compressed_waveform = self._transform(compressed_waveform.clone(), **kwargs)

compressed_waveform = compressed_waveform[:, :, : self._pad_length_input]
base_waveform = base_waveform[:, :, : self._pad_length_output]

Expand Down
5 changes: 3 additions & 2 deletions audioenhancer/model/audio_ae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,15 @@
dim_out=1024,
max_seq_len=0,
attn_layers=Encoder(
dim=512,
depth=12,
dim=1024,
depth=16,
heads=8,
attn_flash=True,
cross_attend=False,
zero_init_branch_output=True,
rotary_pos_emb=True,
ff_swish=True,
ff_glu=True,
use_scalenorm=True,
),
)
6 changes: 0 additions & 6 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,6 @@ def eval_model(model, test_loader):
y = batch[1].to(device, dtype=dtype)
c, d = x.shape[1], x.shape[2]

# normalize x over the last dimension

mean_x = x.mean(dim=-1, keepdim=True)
std_x = x.std(dim=-1, keepdim=True)
x = (x - mean_x) / std_x

# rearrange x and y
x = rearrange(x, "b c d t -> b (t c) d")

Expand Down

0 comments on commit fd55a39

Please sign in to comment.