-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_training_pipeline.py
83 lines (67 loc) · 3.45 KB
/
run_training_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
import os
import random
import sys
import torch
from TrainingInterfaces.TrainingPipelines.FastSpeech2_IntegrationTest import run as fs_integration_test
from TrainingInterfaces.TrainingPipelines.HiFiGAN_Avocodo import run as hifi_codo
from TrainingInterfaces.TrainingPipelines.JointEmbeddingFunction import run as embedding
from TrainingInterfaces.TrainingPipelines.PortaSpeech_IntegrationTest import run as ps_integration_test
from TrainingInterfaces.TrainingPipelines.PortaSpeech_MetaCheckpoint import run as meta
from TrainingInterfaces.TrainingPipelines.finetuning_example import run as fine_tuning_example
from TrainingInterfaces.TrainingPipelines.pretrain_aligner import run as aligner
from TrainingInterfaces.TrainingPipelines.finetune_crewchief import run as crewchief_jim
# Registry mapping each CLI-selectable pipeline name to its training entry
# point (the `run` function imported from the corresponding pipeline module).
# Insertion order is preserved and drives the order of argparse choices.
pipeline_dict = {
    "meta": meta,
    "embedding": embedding,
    "hifi_codo": hifi_codo,
    "aligner": aligner,
    "fine_ex": fine_tuning_example,
    "crewchief_jim": crewchief_jim,
    "fs_it": fs_integration_test,
    "ps_it": ps_integration_test,
}
if __name__ == '__main__':
    # Command-line entry point: parse which pipeline to train plus common
    # training options, validate the flag combination, then dispatch to the
    # selected pipeline's run() function.
    parser = argparse.ArgumentParser(description='IMS Speech Synthesis Toolkit - Call to Train')

    parser.add_argument('pipeline',
                        choices=list(pipeline_dict.keys()),
                        help="Select pipeline to train.")

    parser.add_argument('--gpu_id',
                        type=str,
                        help="Which GPU to run on. If not specified runs on CPU, but other than for integration tests that doesn't make much sense.",
                        default="cpu")

    parser.add_argument('--resume_checkpoint',
                        type=str,
                        help="Path to checkpoint to resume from.",
                        default=None)

    parser.add_argument('--resume',
                        action="store_true",
                        help="Automatically load the highest checkpoint and continue from there.",
                        default=False)

    parser.add_argument('--finetune',
                        action="store_true",
                        help="Whether to fine-tune from the specified checkpoint.",
                        default=False)

    parser.add_argument('--model_save_dir',
                        type=str,
                        help="Directory where the checkpoints should be saved to.",
                        default=None)

    parser.add_argument('--wandb',
                        action="store_true",
                        help="Whether to use weights and biases to track training runs. Requires you to run wandb login and place your auth key before.",
                        default=False)

    parser.add_argument('--wandb_resume_id',
                        type=str,
                        help="ID of a stopped wandb run to continue tracking",
                        default=None)

    args = parser.parse_args()

    # --finetune needs a checkpoint to start from: either an explicit
    # --resume_checkpoint path or --resume (which auto-locates the latest).
    # Report the error on stderr and exit non-zero so shells/CI detect failure
    # (the original used bare sys.exit(), which exits with status 0).
    if args.finetune and args.resume_checkpoint is None and not args.resume:
        print("Need to provide path to checkpoint to fine-tune from!", file=sys.stderr)
        sys.exit(1)

    # Dispatch to the chosen pipeline; all run() implementations share this
    # keyword interface.
    pipeline_dict[args.pipeline](gpu_id=args.gpu_id,
                                 resume_checkpoint=args.resume_checkpoint,
                                 resume=args.resume,
                                 finetune=args.finetune,
                                 model_dir=args.model_save_dir,
                                 use_wandb=args.wandb,
                                 wandb_resume_id=args.wandb_resume_id)