Skip to content

Commit

Permalink
update sofa_gui.py
Browse files Browse the repository at this point in the history
Updates include:
- workaround to prevent trying to load an icon if running on linux
- add support for tgm_sofa_v004
- run in "forced" mode and not "match" mode (idk what they actually do but match is worse)
  • Loading branch information
spicytigermeat authored Jan 1, 2024
1 parent cc56ca5 commit 50d121e
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions sofa_gui.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# GUI specific imports
# general stuff
import os, sys, re, glob
# GUI stuff
import customtkinter as ctk
import tkinter as tk
from ftfy import fix_text as fxy
from ftfy import fix_text as fxy # unicode text all around fix
import threading
# function stuff
import yaml
import pathlib
import warnings
#import pykakasi # japanese language processing

warnings.filterwarnings("ignore")
sys.path.append('.')
Expand All @@ -28,10 +31,22 @@ def __init__(self, lang, wh_model):
self.tokenizer = get_tokenizer(multilingual=True)
self.number_tokens = [i for i in range(self.tokenizer.eot) if all(c in "0123456789" for c in self.tokenizer.decode([i]))]

def run_transcription(self, audio):
def conv_kana2roma(self, string):
# i have no idea how i want to implement this yet lol
#kks = pykakasi.kakasi():
#kks.setMode("J", "H")
#kks.setMode("K", "H")
#jpconv = kks.getConverter()
return 0

def run_transcription(self, audio, lang):
answer = self.model.transcribe(audio, suppress_tokens=[-1] + self.number_tokens)
if lang == "JP":
trns_str = fxy(self.conv_kana2roma(answer['text']))
else:
trns_str = fxy(answer['text'])
print(f"Wrote transcription for {audio} in corpus.")
return answer
return trns_str

class App(ctk.CTk):
def __init__(self):
Expand All @@ -44,11 +59,11 @@ def whisper_function(self):
self.prog_bar.start()
trnsr = Transcriber(self.trans_lang_choice.get(), self.inf_wh_model.get())
try:
for file in glob.glob('corpus/**/*.wav'):
for file in glob.glob('corpus/**/*.wav', recursive=True):
answer = ''
out_name = file[:-4] + '.lab'
answer = trnsr.run_transcription(file)
output = answer['text'].lower()
trns_str = trnsr.run_transcription(file)
output = fxy(trns_str.lower())
final_op = re.sub(r"[.,!?]", "", output)
with open(out_name, 'w+', encoding='utf-8') as whis:
whis.write(final_op)
Expand Down Expand Up @@ -96,7 +111,7 @@ def infer_sofa(self, ckpt, dictionary, op_format):
# load model
torch.set_grad_enabled(False)
model = LitForcedAlignmentTask.load_from_checkpoint(ckpt)
model.set_inference_mode('match')
model.set_inference_mode('force')
trainer = pl.Trainer(logger=False)

# run predictions
Expand Down Expand Up @@ -194,7 +209,10 @@ def update_wh_model(self):
self.title(fxy(self._l[self.clang.get()]['app_ttl']))
self.geometry(f"{350}x{250}")
self.resizable(height=False, width=False)
self.wm_iconbitmap("assets/tgm.ico")

# including this to prevent error when running on Linux
if sys.platform == 'win32':
self.wm_iconbitmap("assets/tgm.ico")

# grid configs
self.grid_columnconfigure(0, weight=1)
Expand Down

0 comments on commit 50d121e

Please sign in to comment.