Skip to content

Commit

Permalink
refactor multipass
Browse files Browse the repository at this point in the history
  • Loading branch information
Robert M Ochshorn authored and Robert M Ochshorn committed May 26, 2016
1 parent e40bfbb commit de26d56
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 116 deletions.
116 changes: 116 additions & 0 deletions gentle/multipass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import logging
from multiprocessing.pool import ThreadPool as Pool
import os
import wave

from gentle import standard_kaldi
from gentle import metasentence
from gentle import language_model
from gentle.paths import get_resource
from gentle import diff_align

# XXX: refactor out somewhere
# Module-level state shared by every realignment pass: the prototype
# language directory and the decoding-graph word list, loaded once at
# import time (so each chunk realignment can reuse them).
proto_langdir = get_resource('PROTO_LANGDIR')
vocab_path = os.path.join(proto_langdir, "graphdir/words.txt")
with open(vocab_path) as f:
    vocab = metasentence.load_vocabulary(f)

def prepare_multipass(alignment):
    """Group consecutive unaligned words into chunks for a second pass.

    Walks the first-pass `alignment` (a list of word dicts carrying a
    'case' key) and collects each maximal run of 'not-found-in-audio'
    words together with the successfully aligned words that bracket it.
    Words with any other 'case' value are ignored, matching the first
    pass's output conventions.

    Returns a list of dicts with keys:
      "start" -- last aligned word before the run (None at file start)
      "end"   -- first aligned word after the run (None at file end)
      "words" -- the unaligned word dicts in the run, in order
    """
    to_realign = []
    last_aligned_word = None
    cur_unaligned_words = []

    for wd in alignment:
        if wd['case'] == 'not-found-in-audio':
            cur_unaligned_words.append(wd)
        elif wd['case'] == 'success':
            if len(cur_unaligned_words) > 0:
                to_realign.append({
                    "start": last_aligned_word,
                    "end": wd,
                    "words": cur_unaligned_words})
                cur_unaligned_words = []

            last_aligned_word = wd

    # Flush a trailing run that extends to the end of the audio.
    if len(cur_unaligned_words) > 0:
        to_realign.append({
            "start": last_aligned_word,
            "end": None,
            "words": cur_unaligned_words})

    return to_realign

def realign(wavfile, alignment, ms, nthreads=4, progress_cb=None):
    """Run a second alignment pass over the unaligned stretches of `alignment`.

    For each run of unaligned words (see prepare_multipass), builds a
    bigram language model restricted to that slice of the transcript,
    decodes the corresponding span of `wavfile` with a fresh Kaldi
    instance, and splices the realigned words back into the word list.

    wavfile     -- path to the source WAV file
    alignment   -- first-pass word list (dicts with 'case', offsets, timing)
    ms          -- MetaSentence for the full transcript
    nthreads    -- number of realignment worker threads
    progress_cb -- optional callback, invoked with {"percent": float}

    Returns a new word list with the realigned chunks substituted in;
    the input `alignment` list is not mutated.
    """
    to_realign = prepare_multipass(alignment)
    realignments = []

    # Named distinctly so it does not shadow the enclosing realign().
    def _realign_chunk(chunk):
        # The chunk spans from the end of the preceding aligned word
        # (or 0 at the start of the file) to the start of the next one.
        start_t = (chunk["start"] or {"end": 0})["end"]
        end_t = chunk["end"]
        if end_t is None:
            # No aligned word follows: realign through the end of the audio.
            wav_obj = wave.open(wavfile, 'r')
            try:
                end_t = wav_obj.getnframes() / float(wav_obj.getframerate())
            finally:
                wav_obj.close()
        else:
            end_t = end_t["start"]

        duration = end_t - start_t
        # Skip degenerate or overly long spans (>60s is too expensive to
        # decode as a single chunk; <10ms carries no usable audio).
        if duration < 0.01 or duration > 60:
            logging.debug("cannot realign %d words with duration %f" % (len(chunk['words']), duration))
            return

        # Create a language model covering only this chunk's transcript.
        offset_offset = chunk['words'][0]['startOffset']
        chunk_len = chunk['words'][-1]['endOffset'] - offset_offset
        chunk_transcript = ms.raw_sentence[offset_offset:offset_offset+chunk_len].encode("utf-8")
        chunk_ms = metasentence.MetaSentence(chunk_transcript, vocab)
        chunk_ks = chunk_ms.get_kaldi_sequence()

        chunk_gen_hclg_filename = language_model.make_bigram_language_model(chunk_ks, proto_langdir)
        k = standard_kaldi.Kaldi(
            get_resource('data/nnet_a_gpu_online'),
            chunk_gen_hclg_filename,
            proto_langdir)

        # Each worker thread opens its own handle; close it once the
        # chunk's audio has been read.
        wav_obj = wave.open(wavfile, 'r')
        try:
            wav_obj.setpos(int(start_t * wav_obj.getframerate()))
            buf = wav_obj.readframes(int(duration * wav_obj.getframerate()))
        finally:
            wav_obj.close()

        k.push_chunk(buf)
        ret = k.get_final()
        k.stop()

        word_alignment = diff_align.align(ret, chunk_ms)

        # Adjust startOffset, endOffset, and timing to match originals
        for wd in word_alignment:
            if wd.get("end"):
                # Apply timing offset
                wd['start'] += start_t
                wd['end'] += start_t

            if wd.get("endOffset"):
                wd['startOffset'] += offset_offset
                wd['endOffset'] += offset_offset

        # "chunk" should be replaced by "words"
        realignments.append({"chunk": chunk, "words": word_alignment})

        if progress_cb is not None:
            progress_cb({"percent": len(realignments) / float(len(to_realign))})

    pool = Pool(nthreads)
    pool.map(_realign_chunk, to_realign)
    pool.close()

    # Splice the replacements into a copy of the original word list.
    o_words = alignment
    for ret in realignments:
        st_idx = o_words.index(ret["chunk"]["words"][0])
        end_idx= o_words.index(ret["chunk"]["words"][-1])+1
        logging.debug('splice in: "%s' % (str(ret["words"])))
        logging.debug('splice out: "%s' % (str(o_words[st_idx:end_idx])))
        o_words = o_words[:st_idx] + ret["words"] + o_words[end_idx:]

    return o_words
163 changes: 50 additions & 113 deletions serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from gentle import diff_align
from gentle import language_model
from gentle import metasentence
from gentle import multipass
from gentle import standard_kaldi
import gentle

Expand All @@ -37,15 +38,32 @@ def render_GET(self, req):
return json.dumps(self.status_dict)

class Transcriber():
def __init__(self, data_dir, nthreads=4):
def __init__(self, data_dir, nthreads=4, ntranscriptionthreads=2):
self.data_dir = data_dir
self.nthreads = nthreads
self.ntranscriptionthreads = ntranscriptionthreads

proto_langdir = get_resource('PROTO_LANGDIR')
vocab_path = os.path.join(proto_langdir, "graphdir/words.txt")
with open(vocab_path) as f:
self.vocab = metasentence.load_vocabulary(f)

# load kaldi instances for full transcription
gen_hclg_filename = get_resource('data/graph/HCLG.fst')

if os.path.exists(gen_hclg_filename) and self.ntranscriptionthreads > 0:
proto_langdir = get_resource('PROTO_LANGDIR')
nnet_gpu_path = get_resource('data/nnet_a_gpu_online')

kaldi_queue = Queue()
for i in range(self.ntranscriptionthreads):
kaldi_queue.put(standard_kaldi.Kaldi(
nnet_gpu_path,
gen_hclg_filename,
proto_langdir)
)
self.full_transcriber = MultiThreadedTranscriber(kaldi_queue, nthreads=self.ntranscriptionthreads)

self._status_dicts = {}

def get_status(self, uid):
Expand Down Expand Up @@ -99,133 +117,50 @@ def transcribe(self, uid, transcript, audio, async):
status['duration'] = wav_obj.getnframes() / float(wav_obj.getframerate())
status['status'] = 'TRANSCRIBING'

def on_progress(p):
for k,v in p.items():
status[k] = v

if len(transcript.strip()) > 0:
ms = metasentence.MetaSentence(transcript, self.vocab)
ks = ms.get_kaldi_sequence()
gen_hclg_filename = language_model.make_bigram_language_model(ks, proto_langdir)
else:
# TODO: We shouldn't load full language models every time;
# these should stay in-memory.
gen_hclg_filename = get_resource('data/graph/HCLG.fst')
if not os.path.exists(gen_hclg_filename):
status["status"] = "ERROR"
status["error"] = 'No transcript provided'
return

kaldi_queue = Queue()
for i in range(self.nthreads):
kaldi_queue.put(standard_kaldi.Kaldi(
get_resource('data/nnet_a_gpu_online'),
gen_hclg_filename,
proto_langdir)
)

def on_progress(p):
for k,v in p.items():
status[k] = v
kaldi_queue = Queue()
for i in range(self.nthreads):
kaldi_queue.put(standard_kaldi.Kaldi(
get_resource('data/nnet_a_gpu_online'),
gen_hclg_filename,
proto_langdir)
)

mtt = MultiThreadedTranscriber(kaldi_queue, nthreads=self.nthreads)
words = mtt.transcribe(wavfile, progress_cb=on_progress)
mtt = MultiThreadedTranscriber(kaldi_queue, nthreads=self.nthreads)
elif hasattr(self, 'full_transcriber'):
mtt = self.full_transcriber
else:
status['status'] = 'ERROR'
status['error'] = 'No transcript provided and no language model for full transcription'
return

# Clear queue
for i in range(self.nthreads):
k = kaldi_queue.get()
k.stop()
words = mtt.transcribe(wavfile, progress_cb=on_progress)

output = {}
if len(transcript.strip()) > 0:
# Clear queue (would this be gc'ed?)
for i in range(self.nthreads):
k = kaldi_queue.get()
k.stop()

# Align words
output['words'] = diff_align.align(words, ms)
output['transcript'] = transcript

# Perform a second-pass with unaligned words
logging.info("%d unaligned words (of %d)" % (len([X for X in output['words'] if X.get("case") == "not-found-in-audio"]), len(output['words'])))

to_realign = []
last_aligned_word = None
cur_unaligned_words = []

for wd_idx,wd in enumerate(output['words']):
if wd['case'] == 'not-found-in-audio':
cur_unaligned_words.append(wd)
elif wd['case'] == 'success':
if len(cur_unaligned_words) > 0:
to_realign.append({
"start": last_aligned_word,
"end": wd,
"words": cur_unaligned_words})
cur_unaligned_words = []

last_aligned_word = wd

if len(cur_unaligned_words) > 0:
to_realign.append({
"start": last_aligned_word,
"end": None,
"words": cur_unaligned_words})

realignments = []

def realign(chunk):
start_t = (chunk["start"] or {"end": 0})["end"]
end_t = (chunk["end"] or {"start": status["duration"]})["start"]
duration = end_t - start_t
if duration < 0.01 or duration > 60:
logging.info("cannot realign %d words with duration %f" % (len(chunk['words']), duration))
return

# Create a language model
offset_offset = chunk['words'][0]['startOffset']
chunk_len = chunk['words'][-1]['endOffset'] - offset_offset
chunk_transcript = ms.raw_sentence[offset_offset:offset_offset+chunk_len].encode("utf-8")
chunk_ms = metasentence.MetaSentence(chunk_transcript, self.vocab)
chunk_ks = chunk_ms.get_kaldi_sequence()

chunk_gen_hclg_filename = language_model.make_bigram_language_model(chunk_ks, proto_langdir)

k = standard_kaldi.Kaldi(
get_resource('data/nnet_a_gpu_online'),
chunk_gen_hclg_filename,
proto_langdir)

wav_obj = wave.open(wavfile, 'r')
wav_obj.setpos(int(start_t * wav_obj.getframerate()))
buf = wav_obj.readframes(int(duration * wav_obj.getframerate()))

k.push_chunk(buf)
ret = k.get_final()
k.stop()
status['status'] = 'ALIGNING'

word_alignment = diff_align.align(ret, chunk_ms)

# Adjust startOffset, endOffset, and timing to match originals
for wd in word_alignment:
if wd.get("end"):
# Apply timing offset
wd['start'] += start_t
wd['end'] += start_t

if wd.get("endOffset"):
wd['startOffset'] += offset_offset
wd['endOffset'] += offset_offset

# "chunk" should be replaced by "words"
realignments.append({"chunk": chunk, "words": word_alignment})

pool = Pool(self.nthreads)
pool.map(realign, to_realign)
pool.close()

# Sub in the replacements
o_words = output['words']
for ret in realignments:
st_idx = o_words.index(ret["chunk"]["words"][0])
end_idx= o_words.index(ret["chunk"]["words"][-1])+1
logging.debug('splice in: "%s' % (str(ret["words"])))
logging.debug('splice out: "%s' % (str(o_words[st_idx:end_idx])))
o_words = o_words[:st_idx] + ret["words"] + o_words[end_idx:]

output['words'] = o_words
output['words'] = multipass.realign(wavfile, output['words'], ms, nthreads=self.nthreads, progress_cb=on_progress)

logging.info("after 2nd pass: %d unaligned words (of %d)" % (len([X for X in output['words'] if X.get("case") == "not-found-in-audio"]), len(output['words'])))

Expand Down Expand Up @@ -361,7 +296,7 @@ def make_transcription_alignment(trans):
trans["words"] = words
return trans

def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, data_dir=get_datadir('webdata')):
def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, ntranscriptionthreads=2, data_dir=get_datadir('webdata')):
logging.info("SERVE %d, %s, %d", port, interface, installSignalHandlers)

if not os.path.exists(data_dir):
Expand All @@ -377,7 +312,7 @@ def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, d
f.putChild('status.html', File(get_resource('www/status.html')))
f.putChild('preloader.gif', File(get_resource('www/preloader.gif')))

trans = Transcriber(data_dir, nthreads=nthreads)
trans = Transcriber(data_dir, nthreads=nthreads, ntranscriptionthreads=ntranscriptionthreads)
trans_ctrl = TranscriptionsController(trans)
f.putChild('transcriptions', trans_ctrl)

Expand All @@ -403,6 +338,8 @@ def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, d
help='port number to run http server on')
parser.add_argument('--nthreads', default=multiprocessing.cpu_count(), type=int,
help='number of alignment threads')
parser.add_argument('--ntranscriptionthreads', default=2, type=int,
help='number of full-transcription threads (memory intensive)')
parser.add_argument('--log', default="INFO",
help='the log level (DEBUG, INFO, WARNING, ERROR, or CRITICAL)')

Expand All @@ -414,4 +351,4 @@ def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, d
logging.info('gentle %s' % (gentle.__version__))
logging.info('listening at %s:%d\n' % (args.host, args.port))

serve(args.port, args.host, nthreads=args.nthreads, installSignalHandlers=1)
serve(args.port, args.host, nthreads=args.nthreads, ntranscriptionthreads=args.ntranscriptionthreads, installSignalHandlers=1)
10 changes: 7 additions & 3 deletions www/view_alignment.html
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,12 @@ <h1 class="home"><a href="/">Gentle</a></h1>

status_init = true;
}

if(ret.percent && (status_log.length == 0 || status_log[status_log.length-1].percent+0.0001 < ret.percent)) {
if(ret.status !== "TRANSCRIBING") {
if(ret.percent) {
$status_pro.value = (100*ret.percent);
}
}
else if(ret.percent && (status_log.length == 0 || status_log[status_log.length-1].percent+0.0001 < ret.percent)) {
// New entry
var $entry = document.createElement("div");
$entry.className = "entry";
Expand Down Expand Up @@ -369,7 +373,7 @@ <h1 class="home"><a href="/">Gentle</a></h1>
if (ret.status == 'ERROR') {
$preloader.style.visibility = 'hidden';
$trans.innerHTML = '<b>' + ret.status + ': ' + ret.error + '</b>';
} else if (ret.status == 'TRANSCRIBING') {
} else if (ret.status == 'TRANSCRIBING' || ret.status == 'ALIGNING') {
$preloader.style.visibility = 'visible';
render_status(ret);
setTimeout(update, 2000);
Expand Down

0 comments on commit de26d56

Please sign in to comment.