Skip to content

Commit

Permalink
Merge pull request #38 from maxhawkins/dupe
Browse files Browse the repository at this point in the history
clean up language model generation
  • Loading branch information
strob committed Nov 26, 2015
2 parents 06dbc62 + 50c2f24 commit abfa0db
Show file tree
Hide file tree
Showing 5 changed files with 3,247 additions and 3,293 deletions.
79 changes: 0 additions & 79 deletions gentle/generate_wp.py

This file was deleted.

94 changes: 63 additions & 31 deletions gentle/language_model.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,84 @@
import logging
import math
import os
import shutil
import subprocess
import sys
import tempfile

from paths import get_binary
from generate_wp import language_model_from_word_sequence
from metasentence import MetaSentence

MKGRAPH_PATH = get_binary("mkgraph")

def get_language_model(kaldi_seq, proto_langdir='PROTO_LANGDIR'):
"""Generates a language model to fit the text
def make_bigram_lm_fst(word_sequence):
'''
Use the given token sequence to make a bigram language model
in OpenFST plain text format.
'''
word_sequence = ['[oov]', '[oov]'] + word_sequence + ['[oov]']

bigrams = {}
prev_word = word_sequence[0]
for word in word_sequence[1:]:
bigrams.setdefault(prev_word, set()).add(word)
prev_word = word

node_ids = {}
def get_node_id(word):
node_id = node_ids.get(word, len(node_ids) + 1)
node_ids[word] = node_id
return node_id

output = ""
for from_word in sorted(bigrams.keys()):
from_id = get_node_id(from_word)

successors = bigrams[from_word]
if len(successors) > 0:
weight = -math.log(1.0 / len(successors))
else:
weight = 0

for to_word in sorted(successors):
to_id = get_node_id(to_word)
output += '%d %d %s %s %f' % (from_id, to_id, to_word, to_word, weight)
output += "\n"

output += "%d 0\n" % (len(node_ids))

return output

def make_bigram_language_model(kaldi_seq, proto_langdir='PROTO_LANGDIR'):
"""Generates a language model to fit the text.
Returns the filename of the generated language model FST.
The caller is resposible for removing the generated file.
`proto_langdir` is a path to a directory containing prototype model data
`kaldi_seq` is a list of words within kaldi's vocabulary.
"""

# Create a language model directory
lang_model_dir = tempfile.mkdtemp()
logging.info('saving language model to %s', lang_model_dir)

# Symlink in necessary files from the prototype directory
for dirpath, dirnames, filenames in os.walk(proto_langdir, followlinks=True):
for dirname in dirnames:
relpath = os.path.relpath(os.path.join(dirpath, dirname), proto_langdir)
os.makedirs(os.path.join(lang_model_dir, relpath))
for filename in filenames:
abspath = os.path.abspath(os.path.join(dirpath, filename))
relpath = os.path.relpath(os.path.join(dirpath, filename), proto_langdir)
dstpath = os.path.join(lang_model_dir, relpath)
os.symlink(abspath, dstpath)

# Generate a textual FST
txt_fst = language_model_from_word_sequence(kaldi_seq)
txt_fst_file = os.path.join(lang_model_dir, 'G.txt')
open(txt_fst_file, 'w').write(txt_fst)
txt_fst = make_bigram_lm_fst(kaldi_seq)
txt_fst_file = tempfile.NamedTemporaryFile(delete=False)
txt_fst_file.write(txt_fst)
txt_fst_file.close()

words_file = os.path.join(proto_langdir, "graphdir/words.txt")
subprocess.check_output([MKGRAPH_PATH,
os.path.join(lang_model_dir, 'langdir'),
os.path.join(lang_model_dir, 'modeldir'),
txt_fst_file,
words_file,
os.path.join(lang_model_dir, 'graphdir', 'HCLG.fst')])
hclg_filename = tempfile.mktemp(suffix='_HCLG.fst')
try:
subprocess.check_output([MKGRAPH_PATH,
proto_langdir,
txt_fst_file.name,
hclg_filename])
except Exception, e:
os.unlink(hclg_filename)
raise e
finally:
os.unlink(txt_fst_file.name)

# Return the language model directory
return lang_model_dir
return hclg_filename

if __name__=='__main__':
import sys
get_language_model(open(sys.argv[1]).read())
make_bigram_language_model(open(sys.argv[1]).read())
19 changes: 9 additions & 10 deletions gentle/language_model_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,17 @@ def lm_transcribe(audio_f, transcript, proto_langdir, nnet_dir,

ks = ms.get_kaldi_sequence()

gen_model_dir = language_model.get_language_model(ks, proto_langdir)
gen_hclg_filename = language_model.make_bigram_language_model(ks, proto_langdir)
try:
k = standard_kaldi.Kaldi(nnet_dir, gen_hclg_filename, proto_langdir)

gen_hclg_path = os.path.join(gen_model_dir, 'graphdir', 'HCLG.fst')
k = standard_kaldi.Kaldi(nnet_dir, gen_hclg_path, proto_langdir)
trans = standard_kaldi.transcribe(k, audio_f,
partial_results_cb=partial_cb,
partial_results_kwargs=partial_kwargs)

trans = standard_kaldi.transcribe(k, audio_f,
partial_results_cb=partial_cb,
partial_results_kwargs=partial_kwargs)

ret = diff_align.align(trans["words"], ms)

shutil.rmtree(gen_model_dir)
ret = diff_align.align(trans["words"], ms)
finally:
os.unlink(gen_hclg_filename)

return {
"transcript": transcript,
Expand Down
25 changes: 12 additions & 13 deletions mkgraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ int main(int argc, char *argv[]) {
using namespace fst;
using fst::script::ArcSort;
try {
const char *usage = "Usage: ./mkgraph [options] <lang-dir> <model-dir> <grammar-fst> <words-txt> <out-fst>\n";
const char *usage = "Usage: ./mkgraph [options] <proto-dir> <grammar-fst> <out-fst>\n";

ParseOptions po(usage);
po.Read(argc, argv);
if (po.NumArgs() != 5) {
if (po.NumArgs() != 3) {
po.PrintUsage();
return 1;
}
Expand All @@ -27,17 +27,16 @@ int main(int argc, char *argv[]) {
float self_loop_scale = 0.1;
bool reverse = false;

std::string lang_dir = po.GetArg(1),
model_dir = po.GetArg(2),
grammar_fst_filename = po.GetArg(3),
words_filename = po.GetArg(4),
out_filename = po.GetArg(5);

std::string lang_fst_filename = lang_dir + "/L.fst",
lang_disambig_fst_filename = lang_dir + "/L_disambig.fst",
disambig_phones_filename = lang_dir + "/phones/disambig.int",
model_filename = model_dir + "/final.mdl",
tree_filename = model_dir + "/tree";
std::string proto_dir = po.GetArg(1),
grammar_fst_filename = po.GetArg(2),
out_filename = po.GetArg(3);

std::string lang_fst_filename = proto_dir + "/langdir/L.fst",
lang_disambig_fst_filename = proto_dir + "/langdir/L_disambig.fst",
disambig_phones_filename = proto_dir + "/langdir/phones/disambig.int",
model_filename = proto_dir + "/modeldir/final.mdl",
tree_filename = proto_dir + "/modeldir/tree",
words_filename = proto_dir + "/graphdir/words.txt";

if (!std::ifstream(lang_fst_filename.c_str())) {
std::cerr << "expected " << lang_fst_filename << " to exist" << std::endl;
Expand Down
Loading

0 comments on commit abfa0db

Please sign in to comment.