diff --git a/gentle/multipass.py b/gentle/multipass.py
new file mode 100644
index 00000000..2c6d14dc
--- /dev/null
+++ b/gentle/multipass.py
@@ -0,0 +1,116 @@
+import logging
+from multiprocessing.pool import ThreadPool as Pool
+import os
+import wave
+
+from gentle import standard_kaldi
+from gentle import metasentence
+from gentle import language_model
+from gentle.paths import get_resource
+from gentle import diff_align
+
+# XXX: refactor out somewhere
+proto_langdir = get_resource('PROTO_LANGDIR')
+vocab_path = os.path.join(proto_langdir, "graphdir/words.txt")
+with open(vocab_path) as f:
+    vocab = metasentence.load_vocabulary(f)
+
+def prepare_multipass(alignment):
+    to_realign = []
+    last_aligned_word = None
+    cur_unaligned_words = []
+
+    for wd_idx,wd in enumerate(alignment):
+        if wd['case'] == 'not-found-in-audio':
+            cur_unaligned_words.append(wd)
+        elif wd['case'] == 'success':
+            if len(cur_unaligned_words) > 0:
+                to_realign.append({
+                    "start": last_aligned_word,
+                    "end": wd,
+                    "words": cur_unaligned_words})
+                cur_unaligned_words = []
+
+            last_aligned_word = wd
+
+    if len(cur_unaligned_words) > 0:
+        to_realign.append({
+            "start": last_aligned_word,
+            "end": None,
+            "words": cur_unaligned_words})
+
+    return to_realign
+
+def realign(wavfile, alignment, ms, nthreads=4, progress_cb=None):
+    to_realign = prepare_multipass(alignment)
+    realignments = []
+
+    def realign(chunk):
+        wav_obj = wave.open(wavfile, 'r')
+
+        start_t = (chunk["start"] or {"end": 0})["end"]
+        end_t = chunk["end"]
+        if end_t is None:
+            end_t = wav_obj.getnframes() / float(wav_obj.getframerate())
+        else:
+            end_t = end_t["start"]
+
+        duration = end_t - start_t
+        if duration < 0.01 or duration > 60:
+            logging.debug("cannot realign %d words with duration %f" % (len(chunk['words']), duration))
+            return
+
+        # Create a language model
+        offset_offset = chunk['words'][0]['startOffset']
+        chunk_len = chunk['words'][-1]['endOffset'] - offset_offset
+        chunk_transcript = ms.raw_sentence[offset_offset:offset_offset+chunk_len].encode("utf-8")
+        chunk_ms = metasentence.MetaSentence(chunk_transcript, vocab)
+        chunk_ks = chunk_ms.get_kaldi_sequence()
+
+        chunk_gen_hclg_filename = language_model.make_bigram_language_model(chunk_ks, proto_langdir)
+        k = standard_kaldi.Kaldi(
+            get_resource('data/nnet_a_gpu_online'),
+            chunk_gen_hclg_filename,
+            proto_langdir)
+
+        wav_obj = wave.open(wavfile, 'r')
+        wav_obj.setpos(int(start_t * wav_obj.getframerate()))
+        buf = wav_obj.readframes(int(duration * wav_obj.getframerate()))
+
+        k.push_chunk(buf)
+        ret = k.get_final()
+        k.stop()
+
+        word_alignment = diff_align.align(ret, chunk_ms)
+
+        # Adjust startOffset, endOffset, and timing to match originals
+        for wd in word_alignment:
+            if wd.get("end"):
+                # Apply timing offset
+                wd['start'] += start_t
+                wd['end'] += start_t
+
+            if wd.get("endOffset"):
+                wd['startOffset'] += offset_offset
+                wd['endOffset'] += offset_offset
+
+        # "chunk" should be replaced by "words"
+        realignments.append({"chunk": chunk, "words": word_alignment})
+
+        if progress_cb is not None:
+            progress_cb({"percent": len(realignments) / float(len(to_realign))})
+
+    pool = Pool(nthreads)
+    pool.map(realign, to_realign)
+    pool.close()
+
+    # Sub in the replacements
+    o_words = alignment
+    for ret in realignments:
+        st_idx = o_words.index(ret["chunk"]["words"][0])
+        end_idx = o_words.index(ret["chunk"]["words"][-1])+1
+        logging.debug('splice in: "%s"' % (str(ret["words"])))
+        logging.debug('splice out: "%s"' % (str(o_words[st_idx:end_idx])))
+        o_words = o_words[:st_idx] + ret["words"] + o_words[end_idx:]
+
+    return o_words
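A note on the new module: prepare_multipass walks the first-pass alignment and collects each run of not-found-in-audio words together with the nearest successfully aligned word on either side, so that realign can cut the audio and transcript down to just that window and decode it against a chunk-specific bigram language model. A minimal sketch of the chunking behavior, with invented word dicts (real entries also carry startOffset/endOffset; importing the module assumes Gentle's resources are installed, since it loads PROTO_LANGDIR at import time):

    from gentle.multipass import prepare_multipass

    alignment = [
        {"case": "success", "word": "hello", "start": 0.0, "end": 0.4},
        {"case": "not-found-in-audio", "word": "there"},
        {"case": "not-found-in-audio", "word": "friend"},
        {"case": "success", "word": "world", "start": 1.2, "end": 1.6},
    ]

    chunks = prepare_multipass(alignment)
    # One chunk, bounded by its aligned neighbors:
    # [{"start": <hello dict>, "end": <world dict>,
    #   "words": [<there dict>, <friend dict>]}]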
diff --git a/serve.py b/serve.py
index 97a06972..e574b3db 100644
--- a/serve.py
+++ b/serve.py
@@ -24,6 +24,7 @@
 from gentle import diff_align
 from gentle import language_model
 from gentle import metasentence
+from gentle import multipass
 from gentle import standard_kaldi
 
 import gentle
@@ -37,15 +38,32 @@ def render_GET(self, req):
         return json.dumps(self.status_dict)
 
 class Transcriber():
-    def __init__(self, data_dir, nthreads=4):
+    def __init__(self, data_dir, nthreads=4, ntranscriptionthreads=2):
         self.data_dir = data_dir
         self.nthreads = nthreads
+        self.ntranscriptionthreads = ntranscriptionthreads
 
         proto_langdir = get_resource('PROTO_LANGDIR')
         vocab_path = os.path.join(proto_langdir, "graphdir/words.txt")
         with open(vocab_path) as f:
             self.vocab = metasentence.load_vocabulary(f)
 
+        # load kaldi instances for full transcription
+        gen_hclg_filename = get_resource('data/graph/HCLG.fst')
+
+        if os.path.exists(gen_hclg_filename) and self.ntranscriptionthreads > 0:
+            proto_langdir = get_resource('PROTO_LANGDIR')
+            nnet_gpu_path = get_resource('data/nnet_a_gpu_online')
+
+            kaldi_queue = Queue()
+            for i in range(self.ntranscriptionthreads):
+                kaldi_queue.put(standard_kaldi.Kaldi(
+                    nnet_gpu_path,
+                    gen_hclg_filename,
+                    proto_langdir)
+                )
+            self.full_transcriber = MultiThreadedTranscriber(kaldi_queue, nthreads=self.ntranscriptionthreads)
+
         self._status_dicts = {}
 
     def get_status(self, uid):
@@ -99,41 +117,40 @@ def transcribe(self, uid, transcript, audio, async):
             status['duration'] = wav_obj.getnframes() / float(wav_obj.getframerate())
         status['status'] = 'TRANSCRIBING'
 
+        def on_progress(p):
+            for k,v in p.items():
+                status[k] = v
+
         if len(transcript.strip()) > 0:
             ms = metasentence.MetaSentence(transcript, self.vocab)
 
             ks = ms.get_kaldi_sequence()
             gen_hclg_filename = language_model.make_bigram_language_model(ks, proto_langdir)
-        else:
-            # TODO: We shouldn't load full language models every time;
-            # these should stay in-memory.
-            gen_hclg_filename = get_resource('data/graph/HCLG.fst')
-
-            if not os.path.exists(gen_hclg_filename):
-                status["status"] = "ERROR"
-                status["error"] = 'No transcript provided'
-                return
-
-        kaldi_queue = Queue()
-        for i in range(self.nthreads):
-            kaldi_queue.put(standard_kaldi.Kaldi(
-                get_resource('data/nnet_a_gpu_online'),
-                gen_hclg_filename,
-                proto_langdir)
-            )
-
-        def on_progress(p):
-            for k,v in p.items():
-                status[k] = v
+            kaldi_queue = Queue()
+            for i in range(self.nthreads):
+                kaldi_queue.put(standard_kaldi.Kaldi(
+                    get_resource('data/nnet_a_gpu_online'),
+                    gen_hclg_filename,
+                    proto_langdir)
+                )
 
-        mtt = MultiThreadedTranscriber(kaldi_queue, nthreads=self.nthreads)
-        words = mtt.transcribe(wavfile, progress_cb=on_progress)
+            mtt = MultiThreadedTranscriber(kaldi_queue, nthreads=self.nthreads)
+        elif hasattr(self, 'full_transcriber'):
+            mtt = self.full_transcriber
+        else:
+            status['status'] = 'ERROR'
+            status['error'] = 'No transcript provided and no language model for full transcription'
+            return
 
-        # Clear queue
-        for i in range(self.nthreads):
-            k = kaldi_queue.get()
-            k.stop()
+        words = mtt.transcribe(wavfile, progress_cb=on_progress)
 
         output = {}
         if len(transcript.strip()) > 0:
+            # Clear queue (would this be gc'ed?)
+            for i in range(self.nthreads):
+                k = kaldi_queue.get()
+                k.stop()
+
             # Align words
             output['words'] = diff_align.align(words, ms)
             output['transcript'] = transcript
@@ -141,91 +158,9 @@ def on_progress(p):
 
             # Perform a second-pass with unaligned words
             logging.info("%d unaligned words (of %d)" % (len([X for X in output['words'] if X.get("case") == "not-found-in-audio"]), len(output['words'])))
 
-            to_realign = []
-            last_aligned_word = None
-            cur_unaligned_words = []
-
-            for wd_idx,wd in enumerate(output['words']):
-                if wd['case'] == 'not-found-in-audio':
-                    cur_unaligned_words.append(wd)
-                elif wd['case'] == 'success':
-                    if len(cur_unaligned_words) > 0:
-                        to_realign.append({
-                            "start": last_aligned_word,
-                            "end": wd,
-                            "words": cur_unaligned_words})
-                        cur_unaligned_words = []
-
-                    last_aligned_word = wd
-
-            if len(cur_unaligned_words) > 0:
-                to_realign.append({
-                    "start": last_aligned_word,
-                    "end": None,
-                    "words": cur_unaligned_words})
-
-            realignments = []
-
-            def realign(chunk):
-                start_t = (chunk["start"] or {"end": 0})["end"]
-                end_t = (chunk["end"] or {"start": status["duration"]})["start"]
-                duration = end_t - start_t
-                if duration < 0.01 or duration > 60:
-                    logging.info("cannot realign %d words with duration %f" % (len(chunk['words']), duration))
-                    return
-
-                # Create a language model
-                offset_offset = chunk['words'][0]['startOffset']
-                chunk_len = chunk['words'][-1]['endOffset'] - offset_offset
-                chunk_transcript = ms.raw_sentence[offset_offset:offset_offset+chunk_len].encode("utf-8")
-                chunk_ms = metasentence.MetaSentence(chunk_transcript, self.vocab)
-                chunk_ks = chunk_ms.get_kaldi_sequence()
-
-                chunk_gen_hclg_filename = language_model.make_bigram_language_model(chunk_ks, proto_langdir)
-
-                k = standard_kaldi.Kaldi(
-                    get_resource('data/nnet_a_gpu_online'),
-                    chunk_gen_hclg_filename,
-                    proto_langdir)
-
-                wav_obj = wave.open(wavfile, 'r')
-                wav_obj.setpos(int(start_t * wav_obj.getframerate()))
-                buf = wav_obj.readframes(int(duration * wav_obj.getframerate()))
-
-                k.push_chunk(buf)
-                ret = k.get_final()
-                k.stop()
+            status['status'] = 'ALIGNING'
+
-                word_alignment = diff_align.align(ret, chunk_ms)
-
-                # Adjust startOffset, endOffset, and timing to match originals
-                for wd in word_alignment:
-                    if wd.get("end"):
-                        # Apply timing offset
-                        wd['start'] += start_t
-                        wd['end'] += start_t
-
-                    if wd.get("endOffset"):
-                        wd['startOffset'] += offset_offset
-                        wd['endOffset'] += offset_offset
-
-                # "chunk" should be replaced by "words"
-                realignments.append({"chunk": chunk, "words": word_alignment})
-
-            pool = Pool(self.nthreads)
-            pool.map(realign, to_realign)
-            pool.close()
-
-            # Sub in the replacements
-            o_words = output['words']
-            for ret in realignments:
-                st_idx = o_words.index(ret["chunk"]["words"][0])
-                end_idx= o_words.index(ret["chunk"]["words"][-1])+1
-                logging.debug('splice in: "%s' % (str(ret["words"])))
-                logging.debug('splice out: "%s' % (str(o_words[st_idx:end_idx])))
-                o_words = o_words[:st_idx] + ret["words"] + o_words[end_idx:]
-
-            output['words'] = o_words
+            output['words'] = multipass.realign(wavfile, output['words'], ms, nthreads=self.nthreads, progress_cb=on_progress)
 
             logging.info("after 2nd pass: %d unaligned words (of %d)" % (len([X for X in output['words'] if X.get("case") == "not-found-in-audio"]), len(output['words'])))
@@ -361,7 +296,7 @@ def make_transcription_alignment(trans):
     trans["words"] = words
     return trans
 
-def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, data_dir=get_datadir('webdata')):
+def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, ntranscriptionthreads=2, data_dir=get_datadir('webdata')):
     logging.info("SERVE %d, %s, %d", port, interface, installSignalHandlers)
 
     if not os.path.exists(data_dir):
@@ -377,7 +312,7 @@
     f.putChild('status.html', File(get_resource('www/status.html')))
     f.putChild('preloader.gif', File(get_resource('www/preloader.gif')))
 
-    trans = Transcriber(data_dir, nthreads=nthreads)
+    trans = Transcriber(data_dir, nthreads=nthreads, ntranscriptionthreads=ntranscriptionthreads)
     trans_ctrl = TranscriptionsController(trans)
     f.putChild('transcriptions', trans_ctrl)
@@ -403,6 +338,8 @@
                         help='port number to run http server on')
     parser.add_argument('--nthreads', default=multiprocessing.cpu_count(), type=int,
                         help='number of alignment threads')
+    parser.add_argument('--ntranscriptionthreads', default=2, type=int,
+                        help='number of full-transcription threads (memory intensive)')
     parser.add_argument('--log', default="INFO",
                         help='the log level (DEBUG, INFO, WARNING, ERROR, or CRITICAL)')
@@ -414,4 +351,4 @@
     logging.info('gentle %s' % (gentle.__version__))
     logging.info('listening at %s:%d\n' % (args.host, args.port))
 
-    serve(args.port, args.host, nthreads=args.nthreads, installSignalHandlers=1)
+    serve(args.port, args.host, nthreads=args.nthreads, ntranscriptionthreads=args.ntranscriptionthreads, installSignalHandlers=1)
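The upshot of the serve.py changes: the memory-intensive full-transcription Kaldi instances are now built once in Transcriber.__init__ (only when data/graph/HCLG.fst exists and ntranscriptionthreads > 0) and reused as self.full_transcriber, while transcript-guided alignment still builds its own per-request queue; with no transcript and no preloaded models, the request now fails early with an explicit error. A sketch of starting the server with the new knob (parameter values for illustration only):

    from serve import serve

    # ntranscriptionthreads=0 skips loading the full-transcription models
    # entirely; transcript-guided alignment keeps working as before.
    serve(port=8765, interface='0.0.0.0', installSignalHandlers=1,
          nthreads=4, ntranscriptionthreads=0)

The same knob is exposed on the command line as --ntranscriptionthreads (default 2).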

diff --git a/www/view_alignment.html b/www/view_alignment.html
index 9d634261..f205ead1 100644
--- a/www/view_alignment.html
+++ b/www/view_alignment.html
@@ -333,8 +333,12 @@
             status_init = true;
         }
-
-        if(ret.percent && (status_log.length == 0 || status_log[status_log.length-1].percent+0.0001 < ret.percent)) {
+        if(ret.status !== "TRANSCRIBING") {
+            if(ret.percent) {
+                $status_pro.value = (100*ret.percent);
+            }
+        }
+        else if(ret.percent && (status_log.length == 0 || status_log[status_log.length-1].percent+0.0001 < ret.percent)) {
             // New entry
             var $entry = document.createElement("div");
             $entry.className = "entry";

@@ -369,7 +373,7 @@
         if (ret.status == 'ERROR') {
             $preloader.style.visibility = 'hidden';
             $trans.innerHTML = '<b>' + ret.status + '</b>: ' + ret.error + '';
-        } else if (ret.status == 'TRANSCRIBING') {
+        } else if (ret.status == 'TRANSCRIBING' || ret.status == 'ALIGNING') {
             $preloader.style.visibility = 'visible';
             render_status(ret);
             setTimeout(update, 2000);
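On the client side, ALIGNING now keeps the polling loop and preloader alive just like TRANSCRIBING, but because its status is not "TRANSCRIBING", percent updates are routed to the $status_pro progress bar instead of appending entries to the status log. A toy sequence of polled status payloads, written as Python dicts limited to the fields this diff touches (values invented for illustration):

    polled = [
        {"status": "TRANSCRIBING", "percent": 0.15},  # first pass: new status-log entries
        {"status": "ALIGNING", "percent": 0.60},      # second pass: progress bar only
    ]
    # Server side, the ALIGNING percent comes from multipass.realign's progress_cb,
    # which reports len(realignments) / len(to_realign) after each realigned chunk.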