Commit 10ebd5d: data augmentation
jstzwj committed Nov 20, 2022
1 parent a0b89f0
Showing 5 changed files with 83 additions and 17 deletions.
2 changes: 1 addition & 1 deletion configs/48k.json
@@ -12,7 +12,7 @@
     "learning_rate": 0.0002,
     "betas": [0.8, 0.99],
     "eps": 1e-9,
-    "batch_size": 16,
+    "batch_size": 32,
     "fp16_run": true,
     "lr_decay": 0.999875,
     "segment_size": 16384,
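
For context, these keys live in the "train" block of the JSON config; hifigan.py reads them as self.hparams.train.*. A minimal loading sketch using only the standard json module (the repo's actual hparams wrapper is not shown in this diff):

    import json

    # Generic sketch; the repository presumably wraps this in its own hparams utility.
    with open("configs/48k.json") as f:
        cfg = json.load(f)
    print(cfg["train"]["batch_size"])    # 32 after this commit
    print(cfg["train"]["segment_size"])  # 16384-sample training slices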
16 changes: 8 additions & 8 deletions hifigan/data/collate.py
@@ -11,29 +11,29 @@ def __call__(self, batch):
             torch.LongTensor([x["wav"].size(1) for x in batch]),
             dim=0, descending=True)

-        max_x_mel_len = max([x["mel"].size(1) for x in batch])
+        max_x_wav_len = max([x["wav"].size(1) for x in batch])
         max_y_wav_len = max([x["wav"].size(1) for x in batch])

-        x_mel_lengths = torch.LongTensor(len(batch))
+        x_wav_lengths = torch.LongTensor(len(batch))
         y_wav_lengths = torch.LongTensor(len(batch))

-        x_mel_padded = torch.zeros(len(batch), batch[0]["mel"].size(0), max_x_mel_len, dtype=torch.float32)
+        x_wav_padded = torch.zeros(len(batch), 1, max_x_wav_len, dtype=torch.float32)
         y_wav_padded = torch.zeros(len(batch), 1, max_y_wav_len, dtype=torch.float32)

         for i in range(len(ids_sorted_decreasing)):
             row = batch[ids_sorted_decreasing[i]]

-            mel = row["mel"]
-            x_mel_padded[i, :, :mel.size(1)] = mel
-            x_mel_lengths[i] = mel.size(1)
+            wav = row["wav"]
+            x_wav_padded[i, :, :wav.size(1)] = wav
+            x_wav_lengths[i] = wav.size(1)

             wav = row["wav"]
             y_wav_padded[i, :, :wav.size(1)] = wav
             y_wav_lengths[i] = wav.size(1)

         ret = {
-            "x_mel_values": x_mel_padded,
-            "x_mel_lengths": x_mel_lengths,
+            "x_wav_values": x_wav_padded,
+            "x_wav_lengths": x_wav_lengths,
             "y_wav_values": y_wav_padded,
             "y_wav_lengths": y_wav_lengths,
         }
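
After this change the collate function pads only raw waveforms; no mel tensors cross the DataLoader boundary. A minimal wiring sketch (DataLoader is standard PyTorch; "dataset" and "collate" are placeholders for whatever the repo actually exports):

    from torch.utils.data import DataLoader

    # "dataset" and "collate" are placeholders for the repo's actual classes.
    loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)
    batch = next(iter(loader))
    print(batch["x_wav_values"].shape)  # (B, 1, max_x_wav_len), zero-padded
    print(batch["x_wav_lengths"])       # true lengths, kept for masking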
11 changes: 5 additions & 6 deletions hifigan/data/dataset.py
@@ -50,15 +50,14 @@ def get_item(self, index: int):
         audio_wav = load_audio(audio_path, sr=self.sampling_rate)
         audio_wav = audio_wav.unsqueeze(0)

-        audio_spec = spectrogram_torch(audio_wav, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False)
-        audio_spec = torch.squeeze(audio_spec, 0)
+        # audio_spec = spectrogram_torch(audio_wav, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False)
+        # audio_spec = torch.squeeze(audio_spec, 0)

-        audio_mel = spec_to_mel_torch(audio_spec, self.filter_length, self.n_mel_channels, self.sampling_rate, self.mel_fmin, self.mel_fmax)
-        audio_mel = torch.squeeze(audio_mel, 0)
+        # audio_mel = spec_to_mel_torch(audio_spec, self.filter_length, self.n_mel_channels, self.sampling_rate, self.mel_fmin, self.mel_fmax)
+        # audio_mel = torch.squeeze(audio_mel, 0)

         return {
-            "wav": audio_wav,
-            "mel": audio_mel,
+            "wav": audio_wav
         }

     def __getitem__(self, index):
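
Dropping the spectrogram_torch / spec_to_mel_torch calls moves mel extraction off the CPU data workers and into the model's new AudioPipeline, so the same utterance can receive a fresh SpecAugment mask every time it is sampled. What the dataset now yields, sketched under the assumption that load_audio returns a 1-D float tensor ("dataset" is again a placeholder):

    # Illustrative shapes only.
    item = dataset.get_item(0)
    print(item["wav"].shape)   # (1, num_samples) after unsqueeze(0)
    print("mel" in item)       # False: mels are computed inside the model now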
22 changes: 20 additions & 2 deletions hifigan/model/hifigan.py
@@ -22,6 +22,7 @@
 from .losses import discriminator_loss, kl_loss,feature_loss, generator_loss
 from .. import utils
 from .commons import slice_segments, rand_slice_segments, sequence_mask
+from .pipeline import AudioPipeline

 class HifiGAN(pl.LightningModule):
     def __init__(self, **kwargs):
@@ -43,12 +44,25 @@ def __init__(self, **kwargs):
         )
         self.net_scale_d = MultiScaleDiscriminator(use_spectral_norm=self.hparams.model.use_spectral_norm)

+        self.audio_pipeline = AudioPipeline(freq=self.hparams.data.sampling_rate,
+                                            n_fft=self.hparams.data.filter_length,
+                                            n_mel=self.hparams.data.n_mel_channels,
+                                            win_length=self.hparams.data.win_length,
+                                            hop_length=self.hparams.data.hop_length,
+                                            device=self.device)
+        for param in self.audio_pipeline.parameters():
+            param.requires_grad = False
+
         # metrics
         self.valid_mel_loss = torchmetrics.MeanMetric()

     def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, optimizer_idx: int):
-        x_mel, x_mel_lengths = batch["x_mel_values"], batch["x_mel_lengths"]
+        x_wav, x_wav_lengths = batch["x_wav_values"], batch["x_wav_lengths"]
         y_wav, y_wav_lengths = batch["y_wav_values"], batch["y_wav_lengths"]

+        with torch.inference_mode():
+            x_mel = self.audio_pipeline(x_wav.squeeze(1))
+            x_mel_lengths = (x_wav_lengths / self.hparams.data.hop_length).long()
+
         x_mel, ids_slice = rand_slice_segments(x_mel, x_mel_lengths, self.hparams.train.segment_size // self.hparams.data.hop_length)
         y_wav = slice_segments(y_wav, ids_slice * self.hparams.data.hop_length, self.hparams.train.segment_size) # slice
@@ -157,8 +171,12 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, optimizer_idx: int):
     def validation_step(self, batch, batch_idx):
         self.net_g.eval()

-        x_mel, x_mel_lengths = batch["x_mel_values"], batch["x_mel_lengths"]
+        x_wav, x_wav_lengths = batch["x_wav_values"], batch["x_wav_lengths"]
         y_wav, y_wav_lengths = batch["y_wav_values"], batch["y_wav_lengths"]

+        with torch.inference_mode():
+            x_mel = self.audio_pipeline(x_wav.squeeze(1))
+            x_mel_lengths = (x_wav_lengths / self.hparams.data.hop_length).long()
+
         y_spec = spectrogram_torch_audio(y_wav.squeeze(1),
                                          self.hparams.data.filter_length,
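
The slicing arithmetic is what keeps mel and waveform aligned: rand_slice_segments takes segment_size // hop_length mel frames, and the waveform target is sliced starting at ids_slice * hop_length samples. A worked example with segment_size = 16384 from configs/48k.json and hop_length = 256 (the config's hop size is not visible in this diff; 256 is AudioPipeline's default, assumed here):

    segment_size = 16384                     # samples, from configs/48k.json
    hop_length = 256                         # assumed: AudioPipeline's default
    mel_frames = segment_size // hop_length  # 64 mel frames per training slice
    # A mel slice starting at frame k covers samples [k * 256, k * 256 + 16384),
    # which is exactly the window slice_segments cuts out of y_wav.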
49 changes: 49 additions & 0 deletions hifigan/model/pipeline.py
@@ -0,0 +1,49 @@
+import torch
+import torchaudio
+import torchaudio.transforms as T
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch import optim
+import random
+
+from ..mel_processing import hann_window
+
+class AudioPipeline(torch.nn.Module):
+    def __init__(
+        self,
+        freq=16000,
+        n_fft=1024,
+        n_mel=128,
+        win_length=1024,
+        hop_length=256,
+        device="cpu",
+    ):
+        super().__init__()
+
+        self.freq = freq
+        self.device = device
+
+        pad = int((n_fft - hop_length) / 2)
+        self.spec = T.Spectrogram(n_fft=n_fft, win_length=win_length, hop_length=hop_length,
+                                  pad=pad, power=None, center=False, pad_mode='reflect', normalized=False, onesided=True)
+
+        # self.strech = T.TimeStretch(hop_length=hop_length, n_freq=freq)
+        self.spec_aug = torch.nn.Sequential(
+            T.FrequencyMasking(freq_mask_param=80),
+            T.TimeMasking(time_mask_param=80),
+        )
+
+        self.mel_scale = T.MelScale(n_mels=n_mel, sample_rate=freq, n_stft=n_fft // 2 + 1)
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        shift_waveform = waveform  # currently a no-op alias
+        # Complex STFT (power=None); the magnitude is taken manually below
+        spec = self.spec(shift_waveform)
+        spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
+        # Apply SpecAugment (frequency and time masking)
+        spec = self.spec_aug(spec)
+        # Convert to mel scale
+        mel = self.mel_scale(spec)
+        return mel
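A minimal smoke test for the new module, assuming a (batch, samples) float input, which matches the x_wav.squeeze(1) call sites in hifigan.py:

    import torch
    from hifigan.model.pipeline import AudioPipeline

    pipe = AudioPipeline(freq=16000, n_fft=1024, n_mel=128,
                         win_length=1024, hop_length=256)
    wav = torch.randn(4, 48000)   # (batch, samples) of white noise
    mel = pipe(wav)
    print(mel.shape)              # torch.Size([4, 128, 187]): one mel frame per hop

Note that with power=None the Spectrogram transform returns a complex STFT, which is why forward takes the magnitude by hand before masking; FrequencyMasking and TimeMasking then zero out random bands, and MelScale maps the n_fft // 2 + 1 = 513 linear bins down to 128 mel bins.
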