-
-
Notifications
You must be signed in to change notification settings - Fork 170
/
run_zero_shot_lang_emb_injection.py
31 lines (24 loc) · 1.88 KB
/
run_zero_shot_lang_emb_injection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import argparse
from huggingface_hub import hf_hub_download
from Preprocessing.multilinguality.create_distance_lookups import CacheCreator
from Preprocessing.multilinguality.create_lang_dist_dataset import LangDistDatasetCreator
from Preprocessing.multilinguality.generate_zero_shot_lang_embs import approximate_and_inject_language_embeddings
from Utility.storage_config import MODEL_DIR
if __name__ == "__main__":
    # Script entry point: parse the command-line options, make sure the cache files exist,
    # build the language distance dataset and hand it to approximate_and_inject_language_embeddings.

    # Resolve the ToucanTTS checkpoint once through the Hugging Face hub (cached under MODEL_DIR).
    # The resulting path is used both as the default for --model_path and for creating the cache
    # files below, so the hub does not have to be queried a second time.
    default_model_path = hf_hub_download(cache_dir=MODEL_DIR, repo_id="Flux9665/ToucanTTS", filename="ToucanTTS.pt")

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", "-m", type=str, default=default_model_path,
                        help="model path from which to obtain pretrained language embeddings")
    parser.add_argument("--distance_type", "-d", type=str, choices=["map", "tree", "asp", "learned", "combined"], default="learned",
                        help="which type of distance to use for finding nearest languages")
    parser.add_argument("--n_closest", "-k", type=int, default=50,
                        help="how many nearest languages to select for language embedding approximation")
    args = parser.parse_args()

    # make sure that cache files exist
    # NOTE(review): the cache files are always created from the default hub checkpoint, not from
    # --model_path. This matches the previous behaviour; confirm that it is intentional.
    cc = CacheCreator(cache_root="Preprocessing/multilinguality")
    cc.create_required_files(model_path=default_model_path)

    # create distance dataset
    dc = LangDistDatasetCreator(args.model_path, cache_root="Preprocessing/multilinguality")
    distance_dataset = dc.create_dataset(args.distance_type, n_closest=args.n_closest, zero_shot=True)

    # generate zero-shot lang embs and inject into pretrained model, then save to modified model path
    approximate_and_inject_language_embeddings(model_path=args.model_path,
                                               df=distance_dataset,
                                               iso_lookup=dc.iso_lookup)