Skip to content

Commit

Permalink
Add script to convert vits models
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Oct 10, 2023
1 parent 8455057 commit 534d40c
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 0 deletions.
Empty file added scripts/vits/__init__.py
Empty file.
170 changes: 170 additions & 0 deletions scripts/vits/export-onnx-ljs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#!/usr/bin/env python3

"""
This script converts vits models trained using the LJ Speech dataset.
Usage:
(1) Download vits
cd /Users/fangjun/open-source
git clone https://github.com/jaywalnut310/vits
(2) Download pre-trained models from
https://huggingface.co/csukuangfj/vits-ljs/tree/main
wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth
(3) Run this file
./export-onnx-ljs.py \
--config ~/open-source//vits/configs/ljs_base.json \
--checkpoint ~/open-source/icefall-models/vits-ljs/pretrained_ljs.pth
It will generate the following two files:
$ ls -lh *.onnx
-rw-r--r-- 1 fangjun staff 36M Oct 10 20:48 vits-ljs.int8.onnx
-rw-r--r-- 1 fangjun staff 109M Oct 10 20:48 vits-ljs.onnx
"""
import sys

# Please change this line to point to the vits directory.
# You can download vits from
# https://github.com/jaywalnut310/vits
sys.path.insert(0, "/Users/fangjun/open-source/vits")

import argparse
from pathlib import Path

import commons
import torch
import utils
from models import SynthesizerTrn
from onnxruntime.quantization import QuantType, quantize_dynamic
from text import text_to_sequence
from text.symbols import symbols


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
type=str,
required=True,
help="""Path to ljs_base.json.
You can find it at
https://huggingface.co/csukuangfj/vits-ljs/resolve/main/ljs_base.json
""",
)

parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="""Path to the checkpoint file.
You can find it at
https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth
""",
)

return parser.parse_args()


class OnnxModel(torch.nn.Module):
def __init__(self, model: SynthesizerTrn):
super().__init__()
self.model = model

def forward(
self,
x,
x_lengths,
noise_scale=1,
length_scale=1,
noise_scale_w=1.0,
sid=None,
max_len=None,
):
return self.model.infer(
x=x,
x_lengths=x_lengths,
sid=sid,
noise_scale=noise_scale,
length_scale=length_scale,
noise_scale_w=noise_scale_w,
max_len=max_len,
)[0]


def get_text(text, hps):
text_norm = text_to_sequence(text, hps.data.text_cleaners)
if hps.data.add_blank:
text_norm = commons.intersperse(text_norm, 0)
text_norm = torch.LongTensor(text_norm)
return text_norm


def check_args(args):
assert Path(args.config).is_file(), args.config
assert Path(args.checkpoint).is_file(), args.checkpoint


@torch.no_grad()
def main():
args = get_args()
hps = utils.get_hparams_from_file(args.config)

net_g = SynthesizerTrn(
len(symbols),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model,
)
_ = net_g.eval()

_ = utils.load_checkpoint(args.checkpoint, net_g, None)

x = get_text("Liliana is the most beautiful assistant", hps)
x = x.unsqueeze(0)

x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
length_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_w = torch.tensor([1], dtype=torch.float32)

model = OnnxModel(net_g)

opset_version = 13

filename = "vits-ljs.onnx"

torch.onnx.export(
model,
(x, x_length, noise_scale, length_scale, noise_scale_w),
filename,
opset_version=opset_version,
input_names=["x", "x_length", "noise_scale", "length_scale", "noise_scale_w"],
output_names=["y"],
dynamic_axes={
"x": {0: "N", 1: "L"}, # n_audio is also known as batch_size
"x_length": {0: "N"},
"y": {0: "N", 2: "L"},
},
)

print("Generate int8 quantization models")

filename_int8 = "vits-ljs.int8.onnx"
quantize_dynamic(
model_input=filename,
model_output=filename_int8,
weight_type=QuantType.QUInt8,
)

print(f"Saved to {filename} and {filename_int8}")


if __name__ == "__main__":
main()

0 comments on commit 534d40c

Please sign in to comment.