diff --git a/configs/48k.json b/configs/48k.json
index ca53784..3f1acba 100644
--- a/configs/48k.json
+++ b/configs/48k.json
@@ -12,7 +12,7 @@
     "learning_rate": 0.0002,
     "betas": [0.8, 0.99],
     "eps": 1e-9,
-    "batch_size": 16,
+    "batch_size": 32,
     "fp16_run": true,
     "lr_decay": 0.999875,
    "segment_size": 16384,
diff --git a/hifigan/data/collate.py b/hifigan/data/collate.py
index 5cc49a8..d76f906 100644
--- a/hifigan/data/collate.py
+++ b/hifigan/data/collate.py
@@ -11,29 +11,29 @@ def __call__(self, batch):
             torch.LongTensor([x["wav"].size(1) for x in batch]),
             dim=0, descending=True)
 
-        max_x_mel_len = max([x["mel"].size(1) for x in batch])
+        max_x_wav_len = max([x["wav"].size(1) for x in batch])
         max_y_wav_len = max([x["wav"].size(1) for x in batch])
 
-        x_mel_lengths = torch.LongTensor(len(batch))
+        x_wav_lengths = torch.LongTensor(len(batch))
         y_wav_lengths = torch.LongTensor(len(batch))
 
-        x_mel_padded = torch.zeros(len(batch), batch[0]["mel"].size(0), max_x_mel_len, dtype=torch.float32)
+        x_wav_padded = torch.zeros(len(batch), 1, max_x_wav_len, dtype=torch.float32)
         y_wav_padded = torch.zeros(len(batch), 1, max_y_wav_len, dtype=torch.float32)
 
         for i in range(len(ids_sorted_decreasing)):
             row = batch[ids_sorted_decreasing[i]]
 
-            mel = row["mel"]
-            x_mel_padded[i, :, :mel.size(1)] = mel
-            x_mel_lengths[i] = mel.size(1)
+            wav = row["wav"]
+            x_wav_padded[i, :, :wav.size(1)] = wav
+            x_wav_lengths[i] = wav.size(1)
 
             wav = row["wav"]
             y_wav_padded[i, :, :wav.size(1)] = wav
             y_wav_lengths[i] = wav.size(1)
 
         ret = {
-            "x_mel_values": x_mel_padded,
-            "x_mel_lengths": x_mel_lengths,
+            "x_wav_values": x_wav_padded,
+            "x_wav_lengths": x_wav_lengths,
             "y_wav_values": y_wav_padded,
             "y_wav_lengths": y_wav_lengths,
         }
diff --git a/hifigan/data/dataset.py b/hifigan/data/dataset.py
index c556001..ce5eb86 100644
--- a/hifigan/data/dataset.py
+++ b/hifigan/data/dataset.py
@@ -50,15 +50,14 @@ def get_item(self, index: int):
         audio_wav = load_audio(audio_path, sr=self.sampling_rate)
         audio_wav = audio_wav.unsqueeze(0)
 
-        audio_spec = spectrogram_torch(audio_wav, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False)
-        audio_spec = torch.squeeze(audio_spec, 0)
+        # audio_spec = spectrogram_torch(audio_wav, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False)
+        # audio_spec = torch.squeeze(audio_spec, 0)
 
-        audio_mel = spec_to_mel_torch(audio_spec, self.filter_length, self.n_mel_channels, self.sampling_rate, self.mel_fmin, self.mel_fmax)
-        audio_mel = torch.squeeze(audio_mel, 0)
+        # audio_mel = spec_to_mel_torch(audio_spec, self.filter_length, self.n_mel_channels, self.sampling_rate, self.mel_fmin, self.mel_fmax)
+        # audio_mel = torch.squeeze(audio_mel, 0)
 
         return {
-            "wav": audio_wav,
-            "mel": audio_mel,
+            "wav": audio_wav
         }
 
     def __getitem__(self, index):
diff --git a/hifigan/model/hifigan.py b/hifigan/model/hifigan.py
index 068716f..497da99 100644
--- a/hifigan/model/hifigan.py
+++ b/hifigan/model/hifigan.py
@@ -22,6 +22,7 @@
 from .losses import discriminator_loss, kl_loss, feature_loss, generator_loss
 from .. import utils
 from .commons import slice_segments, rand_slice_segments, sequence_mask
+from .pipeline import AudioPipeline
 
 class HifiGAN(pl.LightningModule):
     def __init__(self, **kwargs):
@@ -43,12 +44,25 @@ def __init__(self, **kwargs):
         )
         self.net_scale_d = MultiScaleDiscriminator(use_spectral_norm=self.hparams.model.use_spectral_norm)
 
+        self.audio_pipeline = AudioPipeline(freq=self.hparams.data.sampling_rate,
+                                            n_fft=self.hparams.data.filter_length,
+                                            n_mel=self.hparams.data.n_mel_channels,
+                                            win_length=self.hparams.data.win_length,
+                                            hop_length=self.hparams.data.hop_length,
+                                            device=self.device)
+        for param in self.audio_pipeline.parameters():
+            param.requires_grad = False
+
         # metrics
         self.valid_mel_loss = torchmetrics.MeanMetric()
 
     def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, optimizer_idx: int):
-        x_mel, x_mel_lengths = batch["x_mel_values"], batch["x_mel_lengths"]
+        x_wav, x_wav_lengths = batch["x_wav_values"], batch["x_wav_lengths"]
         y_wav, y_wav_lengths = batch["y_wav_values"], batch["y_wav_lengths"]
+
+        with torch.inference_mode():
+            x_mel = self.audio_pipeline(x_wav.squeeze(1))
+            x_mel_lengths = (x_wav_lengths / self.hparams.data.hop_length).long()
 
         x_mel, ids_slice = rand_slice_segments(x_mel, x_mel_lengths, self.hparams.train.segment_size // self.hparams.data.hop_length)
         y_wav = slice_segments(y_wav, ids_slice * self.hparams.data.hop_length, self.hparams.train.segment_size)  # slice
@@ -157,8 +171,12 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, optimize
     def validation_step(self, batch, batch_idx):
         self.net_g.eval()
 
-        x_mel, x_mel_lengths = batch["x_mel_values"], batch["x_mel_lengths"]
+        x_wav, x_wav_lengths = batch["x_wav_values"], batch["x_wav_lengths"]
         y_wav, y_wav_lengths = batch["y_wav_values"], batch["y_wav_lengths"]
+
+        with torch.inference_mode():
+            x_mel = self.audio_pipeline(x_wav.squeeze(1))
+            x_mel_lengths = (x_wav_lengths / self.hparams.data.hop_length).long()
 
         y_spec = spectrogram_torch_audio(y_wav.squeeze(1),
                                          self.hparams.data.filter_length,
diff --git a/hifigan/model/pipeline.py b/hifigan/model/pipeline.py
new file mode 100644
index 0000000..4499d73
--- /dev/null
+++ b/hifigan/model/pipeline.py
@@ -0,0 +1,49 @@
+import torch
+import torchaudio
+import torchaudio.transforms as T
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch import optim
+import random
+
+from ..mel_processing import hann_window
+
+class AudioPipeline(torch.nn.Module):
+    def __init__(
+        self,
+        freq=16000,
+        n_fft=1024,
+        n_mel=128,
+        win_length=1024,
+        hop_length=256,
+        device="cpu",
+    ):
+        super().__init__()
+
+        self.freq = freq
+        self.device = device
+
+        pad = int((n_fft - hop_length) / 2)
+        self.spec = T.Spectrogram(n_fft=n_fft, win_length=win_length, hop_length=hop_length,
+                                  pad=pad, power=None, center=False, pad_mode='reflect', normalized=False, onesided=True)
+
+        # self.strech = T.TimeStretch(hop_length=hop_length, n_freq=freq)
+        self.spec_aug = torch.nn.Sequential(
+            T.FrequencyMasking(freq_mask_param=80),
+            T.TimeMasking(time_mask_param=80),
+        )
+
+        self.mel_scale = T.MelScale(n_mels=n_mel, sample_rate=freq, n_stft=n_fft // 2 + 1)
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        shift_waveform = waveform
+        # Convert to a power spectrogram
+        spec = self.spec(shift_waveform)
+        spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
+        # Apply SpecAugment
+        spec = self.spec_aug(spec)
+        # Convert to mel-scale
+        mel = self.mel_scale(spec)
+        return mel
\ No newline at end of file
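
Below is a minimal usage sketch of the on-the-fly mel extraction this patch introduces, mirroring the training_step changes above. It is illustrative only and not part of the patch: the sampling rate, filter length, mel count, and hop length are placeholder values standing in for the corresponding entries in configs/48k.json (which the diff does not show), and the import path simply follows the hifigan/model/pipeline.py layout added here.

import torch

from hifigan.model.pipeline import AudioPipeline

# Placeholder hyperparameters; the real values come from configs/48k.json.
sampling_rate, n_fft, n_mel, hop_length = 48000, 2048, 128, 512

pipeline = AudioPipeline(freq=sampling_rate, n_fft=n_fft, n_mel=n_mel,
                         win_length=n_fft, hop_length=hop_length)
for param in pipeline.parameters():
    param.requires_grad = False  # feature extraction stays frozen, as in HifiGAN.__init__

# A padded waveform batch shaped like the updated collate output: [B, 1, T].
x_wav = torch.randn(4, 1, sampling_rate)
x_wav_lengths = torch.tensor([48000, 44100, 40000, 36000])

with torch.inference_mode():
    x_mel = pipeline(x_wav.squeeze(1))                   # [B, n_mel, frames]
    x_mel_lengths = (x_wav_lengths / hop_length).long()  # frame count per example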