diff --git a/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_meta_train_loop.py b/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_meta_train_loop.py index 1ad687cb..0fe56ae5 100644 --- a/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_meta_train_loop.py +++ b/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_meta_train_loop.py @@ -54,6 +54,11 @@ def train_loop(net, """ net = net.to(device) + if steps % steps_per_checkpoint == 0: + steps = steps + 1 + else: + steps = steps + ((steps_per_checkpoint + 1) - (steps % steps_per_checkpoint)) # making sure to stop at the closest point that makes sense to the specified stopping point + style_embedding_function = StyleEmbedding().to(device) check_dict = torch.load(path_to_embed_model, map_location=device) style_embedding_function.load_state_dict(check_dict["style_emb_func"])