From 50d121e19e4de3991d7682e452d53b4b35db1f0f Mon Sep 17 00:00:00 2001 From: Tyler <103609620+spicytigermeat@users.noreply.github.com> Date: Mon, 1 Jan 2024 14:01:47 -0600 Subject: [PATCH] update sofa_gui.py Updates include: - workaround to prevent trying to load an icon if running on linux - add support for tgm_sofa_v004 - run in "forced" mode and not "match" mode (idk what they actually do but match is worse) --- sofa_gui.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/sofa_gui.py b/sofa_gui.py index 0473291..20f4a7a 100644 --- a/sofa_gui.py +++ b/sofa_gui.py @@ -1,12 +1,15 @@ -# GUI specific imports +# general stuff import os, sys, re, glob +# GUI stuff import customtkinter as ctk import tkinter as tk -from ftfy import fix_text as fxy +from ftfy import fix_text as fxy # unicode text all around fix import threading +# function stuff import yaml import pathlib import warnings +#import pykakasi # japanese language processing warnings.filterwarnings("ignore") sys.path.append('.') @@ -28,10 +31,22 @@ def __init__(self, lang, wh_model): self.tokenizer = get_tokenizer(multilingual=True) self.number_tokens = [i for i in range(self.tokenizer.eot) if all(c in "0123456789" for c in self.tokenizer.decode([i]))] - def run_transcription(self, audio): + def conv_kana2roma(self, string): + # i have no idea how i want to implement this yet lol + #kks = pykakasi.kakasi(): + #kks.setMode("J", "H") + #kks.setMode("K", "H") + #jpconv = kks.getConverter() + return 0 + + def run_transcription(self, audio, lang): answer = self.model.transcribe(audio, suppress_tokens=[-1] + self.number_tokens) + if lang == "JP": + trns_str = fxy(self.conv_kana2roma(answer['text'])) + else: + trns_str = fxy(answer['text']) print(f"Wrote transcription for {audio} in corpus.") - return answer + return trns_str class App(ctk.CTk): def __init__(self): @@ -44,11 +59,11 @@ def whisper_function(self): self.prog_bar.start() trnsr = Transcriber(self.trans_lang_choice.get(), self.inf_wh_model.get()) try: - for file in glob.glob('corpus/**/*.wav'): + for file in glob.glob('corpus/**/*.wav', recursive=True): answer = '' out_name = file[:-4] + '.lab' - answer = trnsr.run_transcription(file) - output = answer['text'].lower() + trns_str = trnsr.run_transcription(file) + output = fxy(trns_str.lower()) final_op = re.sub(r"[.,!?]", "", output) with open(out_name, 'w+', encoding='utf-8') as whis: whis.write(final_op) @@ -96,7 +111,7 @@ def infer_sofa(self, ckpt, dictionary, op_format): # load model torch.set_grad_enabled(False) model = LitForcedAlignmentTask.load_from_checkpoint(ckpt) - model.set_inference_mode('match') + model.set_inference_mode('force') trainer = pl.Trainer(logger=False) # run predictions @@ -194,7 +209,10 @@ def update_wh_model(self): self.title(fxy(self._l[self.clang.get()]['app_ttl'])) self.geometry(f"{350}x{250}") self.resizable(height=False, width=False) - self.wm_iconbitmap("assets/tgm.ico") + + # including this to prevent error when running on Linux + if sys.platform == 'win32': + self.wm_iconbitmap("assets/tgm.ico") # grid configs self.grid_columnconfigure(0, weight=1)