Skip to content

Commit

Permalink
update test code
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Nov 14, 2022
1 parent 1e72f27 commit c8af27d
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,4 @@ dmypy.json
.pyre/

/logs/
*.wav
54 changes: 54 additions & 0 deletions configs/16k.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
{
"trainer": {
"max_epochs": 20000,
"limit_val_batches": 1,
"accumulate_grad_batches": 1,
"default_root_dir": "./logs",
"val_check_interval": 1000
},
"train": {
"log_interval": 200,
"eval_interval": 1000,
"seed": 1234,
"max_epochs": 20000,
"learning_rate": 2e-4,
"betas": [0.8, 0.99],
"eps": 1e-9,
"batch_size": 64,
"fp16_run": true,
"lr_decay": 0.999875,
"segment_size": 16384,
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45
},
"data": {
"training_files": "filelists/example_audio_filelist_train.txt",
"validation_files": "filelists/example_audio_filelist_valid.txt",
"max_wav_value": 32768.0,
"sampling_rate": 16000,
"filter_length": 1024,
"hop_length": 256,
"win_length": 1024,
"n_mel_channels": 256,
"mel_fmin": 0.0,
"mel_fmax": null
},
"model": {
"inter_channels": 256,
"hidden_channels": 256,
"filter_channels": 768,
"kernel_size": 3,
"p_dropout": 0.1,
"resblock": "1",
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [
[1,3,5],
[1,3,5],
[1,3,5]
],
"upsample_rates": [8,8,4],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [16,16,8]
}
}
Empty file added configs/32k.json
Empty file.
Empty file added configs/44.1k.json
Empty file.
10 changes: 2 additions & 8 deletions configs/base.json → configs/48k.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"learning_rate": 2e-4,
"betas": [0.8, 0.99],
"eps": 1e-9,
"batch_size": 32,
"batch_size": 64,
"fp16_run": true,
"lr_decay": 0.999875,
"segment_size": 16384,
Expand All @@ -37,10 +37,7 @@
"model": {
"inter_channels": 256,
"hidden_channels": 256,
"hubert_channels": 1280,
"filter_channels": 768,
"n_heads": 4,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.1,
"resblock": "1",
Expand All @@ -52,9 +49,6 @@
],
"upsample_rates": [8,8,4,2],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [16,16,4,4],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 256
"upsample_kernel_sizes": [16,16,4,4]
}
}
2 changes: 2 additions & 0 deletions hifigan/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from ..utils import load_filepaths, load_wav_to_torch

resamplers = {}

def load_audio(filename: str, sr: Optional[int] = None):
global resamplers
audio, sampling_rate = load_wav_to_torch(filename)
Expand Down
4 changes: 2 additions & 2 deletions hifigan/model/generators/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
from ..commons import init_weights, get_padding


class Generator(torch.nn.Module):
class HifiGANGenerator(torch.nn.Module):
def __init__(self, initial_channel: int,
resblock: Union[str, ResBlock2],
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[int],
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int]):
super(Generator, self).__init__()
super(HifiGANGenerator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(in_channels=initial_channel, out_channels=upsample_initial_channel, kernel_size=7, stride=1, padding=3)
Expand Down
9 changes: 3 additions & 6 deletions hifigan/model/hifigan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .discriminators.multi_scale_discriminator import MultiScaleDiscriminator
from .discriminators.multi_period_discriminator import MultiPeriodDiscriminator
from .generators.generator import Generator
from .generators.generator import HifiGANGenerator

from ..mel_processing import spec_to_mel_torch, mel_spectrogram_torch, spectrogram_torch, spectrogram_torch_audio
from .losses import discriminator_loss, kl_loss,feature_loss, generator_loss
Expand All @@ -27,7 +27,7 @@ def __init__(self, **kwargs):
super().__init__()
self.save_hyperparameters(*[k for k in kwargs])

self.net_g = Generator(
self.net_g = HifiGANGenerator(
self.hparams.model.inter_channels,
self.hparams.model.resblock,
self.hparams.model.resblock_kernel_sizes,
Expand Down Expand Up @@ -163,11 +163,8 @@ def validation_step(self, batch, batch_idx):
y_spec_lengths = (y_wav_lengths / self.hparams.data.hop_length).long()

# remove else
y_spec = y_spec[:1]
y_spec_lengths = y_spec_lengths[:1]

y_hat = self.net_g(x_mel)
y_hat_lengths = y_spec_lengths
y_hat_lengths = torch.tensor([y_hat.shape[2]], dtype=torch.long)

mel = spec_to_mel_torch(
y_spec,
Expand Down
3 changes: 3 additions & 0 deletions hub_conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# torch.hub configuration for this repository.
# NOTE(review): torch.hub discovers entry points from a file named
# `hubconf.py` at the repository root; this file is added as `hub_conf.py`
# -- confirm that `torch.hub.load` can actually find it.

# Package names torch.hub verifies are importable before loading an entry point.
dependencies = ["torch", "torchaudio"]

# Imported at module level so the class name is visible as a hub entry point.
from hifigan.model.generators.generator import HifiGANGenerator
15 changes: 15 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Smoke test: push one 48 kHz recording through the hub-loaded HiFi-GAN generator."""
import torch
import torchaudio

from hifigan.mel_processing import mel_spectrogram_torch

# Sample rate the script requires of "test.wav"; also forwarded to the mel
# front end below.
EXPECTED_SAMPLE_RATE = 48000

# Prefer the GPU, but keep the script usable on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the generator entry point from the project's torch.hub repository.
# NOTE(review): HifiGANGenerator.__init__ declares required arguments
# (initial_channel, resblock, resblock_kernel_sizes, ...), and no checkpoint
# is loaded here -- confirm the hub entry point supplies both.
hifigan = torch.hub.load("vtuber-plan/hifi-gan:main", "HifiGANGenerator").to(device)

# Load audio
wav, sr = torchaudio.load("test.wav")
if sr != EXPECTED_SAMPLE_RATE:
    # Raise explicitly: an `assert` would be stripped under `python -O`.
    raise ValueError(f"expected {EXPECTED_SAMPLE_RATE} Hz audio, got {sr} Hz")
# Add a leading dimension before moving the waveform to the target device.
# NOTE(review): torchaudio.load returns (channels, frames); confirm that
# mel_spectrogram_torch accepts the resulting 3-D tensor.
wav = wav.unsqueeze(0).to(device)

# Positional arguments presumably follow the config key order
# (filter_length, n_mel_channels, sampling_rate, hop_length, win_length,
# mel_fmin, mel_fmax, center) -- TODO confirm against mel_processing.
mel = mel_spectrogram_torch(wav, 2048, 256, EXPECTED_SAMPLE_RATE, 512, 2048, 0, None, False)

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    # The name `units` is kept from the original script; the value is the
    # generator's output for `mel`.
    units = hifigan(mel)
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_hparams(config_path: str) -> HParams:

def main():
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default="./configs/base.json", help='JSON file for configuration')
parser.add_argument('-c', '--config', type=str, default="./configs/48k.json", help='JSON file for configuration')
parser.add_argument('-a', '--accelerator', type=str, default="gpu", help='training device')
parser.add_argument('-d', '--device', type=str, default="0", help='training device ids')
args = parser.parse_args()
Expand Down

0 comments on commit c8af27d

Please sign in to comment.