Commit
add bigvgan end to end recipe
Flux9665 committed Sep 29, 2024
1 parent 41e3163 commit 7bcab96
Showing 4 changed files with 85 additions and 22 deletions.
13 changes: 2 additions & 11 deletions Modules/Vocoder/BigVGAN.py
@@ -2,7 +2,6 @@
# Licensed under the MIT license.

# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.

import torch
from alias_free_torch import Activation1d
@@ -63,10 +62,6 @@ def __init__(self,
self.ups[i].apply(init_weights)
self.conv_post.apply(init_weights)

# for Avocodo discriminator
self.out_proj_x1 = torch.nn.Conv1d(upsample_initial_channel // 4, 1, 7, 1, padding=3)
self.out_proj_x2 = torch.nn.Conv1d(upsample_initial_channel // 8, 1, 7, 1, padding=3)

if weights is not None:
self.load_state_dict(weights)

@@ -86,17 +81,13 @@ def forward(self, x):
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
if i == 1:
x1 = self.out_proj_x1(x)
elif i == 2:
x2 = self.out_proj_x2(x)

# post conv
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)

return x, x2, x1
return x

def remove_weight_norm(self):
print('Removing weight norm...')
@@ -128,4 +119,4 @@ def get_padding(kernel_size, dilation=1):
if __name__ == '__main__':
vgan = BigVGAN()
print(f"BigVGAN parameter count: {sum(p.numel() for p in vgan.parameters() if p.requires_grad)}")
print(BigVGAN()(torch.randn([1, 128, 100]))[0].shape)
print(BigVGAN()(torch.randn([1, 128, 100])).shape)
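
A minimal usage sketch of the changed interface (not part of the commit): with the Avocodo projection heads removed, the forward pass returns only the waveform tensor instead of the former (wave, x2, x1) tuple. The input shape follows the __main__ block above; the output length depends on the configured upsampling factors.

import torch

from Modules.Vocoder.BigVGAN import BigVGAN

# illustrative sketch: a mel spectrogram of shape [batch, mel_bins, frames]
# now maps to a single waveform tensor
vocoder = BigVGAN()
mel = torch.randn([1, 128, 100])  # 128 mel bins, 100 frames, as in the __main__ block
wave = vocoder(mel)               # no tuple unpacking needed anymore
print(wave.shape)                 # [1, 1, frames * total upsampling factor]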
11 changes: 1 addition & 10 deletions Modules/Vocoder/HiFiGAN_train_loop.py
@@ -10,7 +10,6 @@

from Modules.Vocoder.AdversarialLoss import discriminator_adv_loss
from Modules.Vocoder.AdversarialLoss import generator_adv_loss
from Modules.Vocoder.FeatureMatchingLoss import feature_loss
from Modules.Vocoder.MelSpecLoss import MelSpectrogramLoss
from Utility.utils import delete_old_checkpoints
from Utility.utils import get_most_recent_checkpoint
Expand All @@ -33,7 +32,7 @@ def train_loop(generator,
batch_size=32,
epochs=100,
resume=False,
generator_steps_per_discriminator_step=2,
generator_steps_per_discriminator_step=1,
generator_warmup=30000,
use_wandb=False,
finetune=False
@@ -85,7 +84,6 @@ def train_loop(generator,
discriminator_losses = list()
generator_losses = list()
mel_losses = list()
feat_match_losses = list()
adversarial_losses = list()

optimizer_g.zero_grad()
@@ -112,11 +110,6 @@
adversarial_losses.append(adversarial_loss.item())
generator_total_loss = generator_total_loss + adversarial_loss * 2 # based on own experience

d_gold_outs, d_gold_fmaps = d(gold_wave)
feature_matching_loss = feature_loss(d_gold_fmaps, d_fmaps)
feat_match_losses.append(feature_matching_loss.item())
generator_total_loss = generator_total_loss + feature_matching_loss

if torch.isnan(generator_total_loss):
print("Loss turned to NaN, skipping. The GAN possibly collapsed.")
continue
@@ -177,8 +170,6 @@
log_dict = dict()
log_dict["Generator Loss"] = round(sum(generator_losses) / len(generator_losses), 3)
log_dict["Mel Loss"] = round(sum(mel_losses) / len(mel_losses), 3)
if len(feat_match_losses) > 0:
log_dict["Feature Matching Loss"] = round(sum(feat_match_losses) / len(feat_match_losses), 3)
if len(adversarial_losses) > 0:
log_dict["Adversarial Loss"] = round(sum(adversarial_losses) / len(adversarial_losses), 3)
if len(discriminator_losses) > 0:
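
With the feature-matching term dropped, the generator objective in this loop reduces to the mel-spectrogram loss plus the adversarial term scaled by 2 once adversarial training kicks in after the warmup (per the "based on own experience" comment above). A hedged sketch of that composition with stand-in values, not the repo's exact variables:

import torch

# stand-ins only; in the real loop these come from MelSpectrogramLoss and
# generator_adv_loss applied to the discriminator outputs on the fake wave
mel_loss = torch.tensor(0.45)
adversarial_loss = torch.tensor(0.12)
generator_total_loss = mel_loss + adversarial_loss * 2  # no feature matching term anymore
print(round(float(generator_total_loss), 2))  # 0.69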
79 changes: 79 additions & 0 deletions Recipes/BigVGAN_e2e.py
@@ -0,0 +1,79 @@
import time

import torch
import wandb

from Modules.Vocoder.BigVGAN import BigVGAN
from Modules.Vocoder.HiFiGAN_Discriminators import AvocodoHiFiGANJointDiscriminator
from Modules.Vocoder.HiFiGAN_E2E_Dataset import HiFiGANDataset
from Modules.Vocoder.HiFiGAN_train_loop import train_loop
from Utility.path_to_transcript_dicts import *
from Utility.storage_config import MODELS_DIR


def run(gpu_id, resume_checkpoint, finetune, resume, model_dir, use_wandb, wandb_resume_id, gpu_count):
if gpu_id == "cpu":
device = torch.device("cpu")
else:
device = torch.device("cuda")

if gpu_count > 1:
print("Multi GPU training not supported for BigVGAN!")
import sys
sys.exit()

print("Preparing")
if model_dir is not None:
model_save_dir = model_dir
else:
model_save_dir = os.path.join(MODELS_DIR, "BigVGAN_e2e")
os.makedirs(model_save_dir, exist_ok=True)

# To prepare the data, have a look at Modules/Vocoder/run_end-to-end_data_creation

print("Collecting new data...")

file_lists_for_this_run_combined = list()
file_lists_for_this_run_combined_synthetic = list()

fl = list(build_path_to_transcript_libritts_all_clean().keys())
fisher_yates_shuffle(fl)
fisher_yates_shuffle(fl)
for i, f in enumerate(fl):
if os.path.exists(f.replace(".wav", "_synthetic_spec.pt")):
file_lists_for_this_run_combined.append(f)
file_lists_for_this_run_combined_synthetic.append(f.replace(".wav", "_synthetic_spec.pt"))
print("filepaths collected")

train_set = HiFiGANDataset(list_of_original_paths=file_lists_for_this_run_combined,
list_of_synthetic_paths=file_lists_for_this_run_combined_synthetic)

generator = BigVGAN()
discriminator = AvocodoHiFiGANJointDiscriminator()

print("Training model")
if use_wandb:
wandb.init(
name=f"{__name__.split('.')[-1]}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None,
id=wandb_resume_id, # this is None if not specified in the command line arguments.
resume="must" if wandb_resume_id is not None else None)
train_loop(batch_size=16,
epochs=5180000,
generator=generator,
discriminator=discriminator,
train_dataset=train_set,
device=device,
epochs_per_save=1,
model_save_dir=model_save_dir,
path_to_checkpoint=resume_checkpoint,
resume=resume,
use_wandb=use_wandb,
finetune=finetune)
if use_wandb:
wandb.finish()


def fisher_yates_shuffle(lst):
for i in range(len(lst) - 1, 0, -1):
j = random.randint(0, i)
lst[i], lst[j] = lst[j], lst[i]
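
For reference, a hedged sketch of invoking the new recipe directly with the arguments its signature expects; the values shown are illustrative defaults that run_training_pipeline.py would normally supply from its command-line arguments.

from Recipes.BigVGAN_e2e import run

# illustrative call; a single GPU is assumed, since gpu_count > 1 exits early
# and gpu_id == "cpu" would select the CPU device instead
run(gpu_id="0",
    resume_checkpoint=None,
    finetune=False,
    resume=False,
    model_dir=None,  # falls back to MODELS_DIR/BigVGAN_e2e
    use_wandb=False,
    wandb_resume_id=None,
    gpu_count=1)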
4 changes: 3 additions & 1 deletion run_training_pipeline.py
@@ -6,6 +6,7 @@
import torch

from Recipes.AlignerPipeline import run as aligner
from Recipes.BigVGAN_e2e import run as be2e
from Recipes.HiFiGAN_combined import run as HiFiGAN
from Recipes.HiFiGAN_e2e import run as e2e
from Recipes.ToucanTTS_IntegrationTest import run as tt_integration_test
@@ -37,7 +38,8 @@
"aligner" : aligner,
# vocoder training (not recommended, best to use provided checkpoint)
"hifigan" : HiFiGAN,
"e2e" : e2e
"e2e" : e2e,
"be2e": be2e
}

if __name__ == '__main__':
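
Assuming the script's existing argument parsing, the new dictionary entry makes the recipe selectable by name, e.g. python run_training_pipeline.py be2e --gpu_id 0, analogous to how the existing e2e pipeline is launched.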
