diff --git a/scripts/vits/__init__.py b/scripts/vits/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/scripts/vits/export-onnx-ljs.py b/scripts/vits/export-onnx-ljs.py new file mode 100755 index 000000000..916f12bb2 --- /dev/null +++ b/scripts/vits/export-onnx-ljs.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 + +""" +This script converts vits models trained using the LJ Speech dataset. + +Usage: + +(1) Download vits + +cd /Users/fangjun/open-source +git clone https://github.com/jaywalnut310/vits + +(2) Download pre-trained models from +https://huggingface.co/csukuangfj/vits-ljs/tree/main + +wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth + +(3) Run this file + +./export-onnx-ljs.py \ + --config ~/open-source//vits/configs/ljs_base.json \ + --checkpoint ~/open-source/icefall-models/vits-ljs/pretrained_ljs.pth + +It will generate the following two files: + +$ ls -lh *.onnx +-rw-r--r-- 1 fangjun staff 36M Oct 10 20:48 vits-ljs.int8.onnx +-rw-r--r-- 1 fangjun staff 109M Oct 10 20:48 vits-ljs.onnx +""" +import sys + +# Please change this line to point to the vits directory. +# You can download vits from +# https://github.com/jaywalnut310/vits +sys.path.insert(0, "/Users/fangjun/open-source/vits") + +import argparse +from pathlib import Path + +import commons +import torch +import utils +from models import SynthesizerTrn +from onnxruntime.quantization import QuantType, quantize_dynamic +from text import text_to_sequence +from text.symbols import symbols + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + type=str, + required=True, + help="""Path to ljs_base.json. + You can find it at + https://huggingface.co/csukuangfj/vits-ljs/resolve/main/ljs_base.json + """, + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="""Path to the checkpoint file. + You can find it at + https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth + + """, + ) + + return parser.parse_args() + + +class OnnxModel(torch.nn.Module): + def __init__(self, model: SynthesizerTrn): + super().__init__() + self.model = model + + def forward( + self, + x, + x_lengths, + noise_scale=1, + length_scale=1, + noise_scale_w=1.0, + sid=None, + max_len=None, + ): + return self.model.infer( + x=x, + x_lengths=x_lengths, + sid=sid, + noise_scale=noise_scale, + length_scale=length_scale, + noise_scale_w=noise_scale_w, + max_len=max_len, + )[0] + + +def get_text(text, hps): + text_norm = text_to_sequence(text, hps.data.text_cleaners) + if hps.data.add_blank: + text_norm = commons.intersperse(text_norm, 0) + text_norm = torch.LongTensor(text_norm) + return text_norm + + +def check_args(args): + assert Path(args.config).is_file(), args.config + assert Path(args.checkpoint).is_file(), args.checkpoint + + +@torch.no_grad() +def main(): + args = get_args() + hps = utils.get_hparams_from_file(args.config) + + net_g = SynthesizerTrn( + len(symbols), + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + **hps.model, + ) + _ = net_g.eval() + + _ = utils.load_checkpoint(args.checkpoint, net_g, None) + + x = get_text("Liliana is the most beautiful assistant", hps) + x = x.unsqueeze(0) + + x_length = torch.tensor([x.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + length_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_w = torch.tensor([1], dtype=torch.float32) + + model = OnnxModel(net_g) + + opset_version = 13 + + filename = "vits-ljs.onnx" + + torch.onnx.export( + model, + (x, x_length, noise_scale, length_scale, noise_scale_w), + filename, + opset_version=opset_version, + input_names=["x", "x_length", "noise_scale", "length_scale", "noise_scale_w"], + output_names=["y"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, # n_audio is also known as batch_size + "x_length": {0: "N"}, + "y": {0: "N", 2: "L"}, + }, + ) + + print("Generate int8 quantization models") + + filename_int8 = "vits-ljs.int8.onnx" + quantize_dynamic( + model_input=filename, + model_output=filename_int8, + weight_type=QuantType.QUInt8, + ) + + print(f"Saved to {filename} and {filename_int8}") + + +if __name__ == "__main__": + main()