Add a simple language model maker script #405

Open · wants to merge 2 commits into master
353 changes: 353 additions & 0 deletions cython/pocketsphinx/lm.py
@@ -0,0 +1,353 @@
#!/usr/bin/env python

import argparse
import sys
from math import log
import re
from collections import defaultdict
from datetime import date
import unicodedata as ud
from io import StringIO
from typing import Optional, Dict, TextIO, Any

# Author: Kevin Lenzo
# Based on a Perl script by Alex Rudnicky

class ArpaBoLM:
"""
A simple ARPA model builder
"""
    log10 = log(10.0)  # for converting natural logs to base 10 (ARPA uses log10)
    # Unicode major categories stripped by norm_token:
    # P = punctuation, S = symbol, C = control/other, M = mark, Z = separator
    norm_exclude_categories = {'P', 'S', 'C', 'M', 'Z'}

def __init__(
self,
sentfile: Optional[str] = None,
text: Optional[str] = None,
add_start: bool = False,
word_file: Optional[str] = None,
word_file_count: int = 1,
discount_mass: float = 0.5,
case: Optional[str] = None, # lower, upper
norm: bool = False,
verbose: bool = False,
):
self.sentfile = sentfile
self.text = text
self.add_start = add_start
self.word_file = word_file
self.word_file_count = word_file_count
self.discount_mass = discount_mass
self.case = case
self.norm = norm
self.verbose = verbose

        self.logfile = sys.stderr  # keep stdout free for model output (see -v help)

if self.verbose:
print('Started', date.today(),
file=self.logfile)

        if discount_mass is None:  # TODO: add other smoothing methods
            self.discount_mass = 0.5
        elif not 0.0 < discount_mass < 1.0:
            raise ValueError(f'Discount mass ({discount_mass}) out of range (0.0, 1.0)')

        # Fraction of each raw probability estimate that is kept; the
        # remaining discount_mass is redistributed to unseen events via backoff.
        self.deflator: float = 1.0 - self.discount_mass

self.sent_count = 0

self.grams_1: Any = defaultdict(int)
self.grams_2: Any = defaultdict(lambda: defaultdict(int))
self.grams_3: Any = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

self.sum_1: int = 0
self.count_1: int = 0
self.count_2: int = 0
self.count_3: int = 0

self.prob_1: Dict[str, float] = {}
self.alpha_1: Dict[str, float] = {}
self.prob_2: Any = defaultdict(lambda: defaultdict(float))
self.alpha_2: Any = defaultdict(lambda: defaultdict(float))

if self.sentfile is not None:
with open(str(sentfile)) as infile:
self.read_corpus(infile)
if self.text is not None:
self.read_corpus(StringIO(text))

if self.word_file is not None:
self.read_word_file(self.word_file)

def read_word_file(self, path: str, count: Optional[int] = None) -> bool:
"""
        Read a file of words, one per line; any word not already in
        the model is added with the given count (default: word_file_count)
"""
if self.verbose:
print('Reading word file:', path, file=self.logfile)

if count is None:
count = self.word_file_count

new_word_count = token_count = 0
with open(path) as words_file:
for token in words_file:
token = token.strip()
if not token:
continue
if self.case == 'lower':
token = token.lower()
elif self.case == 'upper':
token = token.upper()
if self.norm:
token = self.norm_token(token)
token_count += 1
                # We could bump the count of every word, or add only the
                # missing words with the given count; we do the latter.
if token not in self.grams_1:
self.grams_1[token] = count
new_word_count += 1

if self.verbose:
print(
f'{new_word_count} new unique words',
f'from {token_count} tokens,',
f'each with count {count}',
file=self.logfile,
)
return True

def norm_token(self, token: str) -> str:
"""
Remove excluded leading and trailing character categories from a token
"""
while len(token) and ud.category(token[0])[0] in ArpaBoLM.norm_exclude_categories:
token = token[1:]
while len(token) and ud.category(token[-1])[0] in ArpaBoLM.norm_exclude_categories:
token = token[:-1]
return token

def read_corpus(self, infile):
"""
Read in a text training corpus from a file handle
"""
if self.verbose:
            print('Reading corpus file, one sentence per line.', file=self.logfile)

sent_count = 0
for line in infile:
if self.case == 'lower':
line = line.lower()
elif self.case == 'upper':
line = line.upper()
line = line.strip()
line = re.sub(r'(.+)\(.+\)$', r'\1', line) # trailing file name in transcripts

words = line.split()
if self.add_start:
words = ['<s>'] + words + ['</s>']
if self.norm:
words = [self.norm_token(w) for w in words]
words = [w for w in words if len(w)]
if not words:
continue
sent_count += 1
wc = len(words)
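            # Slide a window over the sentence; e.g. ['<s>', 'hi', '</s>']
            # yields unigrams <s>, hi, </s>; bigrams (<s>, hi), (hi, </s>);
            # and the trigram (<s>, hi, </s>).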
for j in range(wc):
w1 = words[j]
self.grams_1[w1] += 1
if j + 1 < wc:
w2 = words[j + 1]
self.grams_2[w1][w2] += 1
if j + 2 < wc:
w3 = words[j + 2]
self.grams_3[w1][w2][w3] += 1

        self.sent_count += sent_count
        if self.verbose:
            print(f'{sent_count} sentences', file=self.logfile)

def compute(self) -> bool:
"""
        Compute the derived values: n-gram probabilities and back-off weights.

If an n-gram is not present, the back-off is

P( word_N | word_{N-1}, word_{N-2}, ...., word_1 ) =
P( word_N | word_{N-1}, word_{N-2}, ...., word_2 )
* backoff-weight( word_{N-1} | word_{N-2}, ...., word_1 )

If the sequence

( word_{N-1}, word_{N-2}, ...., word_1 )

is also not listed, then the term

backoff-weight( word_{N-1} | word_{N-2}, ...., word_1 )

gets replaced with 1.0 and the recursion continues.

"""
        if not self.grams_1:
            sys.exit('No input text; nothing to compute')

# token counts
self.sum_1 = sum(self.grams_1.values())

# type counts
self.count_1 = len(self.grams_1)
for w1, gram2 in self.grams_2.items():
self.count_2 += len(gram2)
for w2 in gram2:
self.count_3 += len(self.grams_3[w1][w2])

# unigram probabilities
for gram1, count in self.grams_1.items():
self.prob_1[gram1] = count * self.deflator / self.sum_1

# unigram alphas
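        # Each alpha spreads the held-out discount mass over the words
        # *not* observed after w1:
        #   alpha(w1) = discount_mass / (1 - sum of P(w2) over observed w2)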
        for w1 in self.grams_1:
            sum_denom = sum(self.prob_1[w2] for w2 in self.grams_2[w1])
            self.alpha_1[w1] = self.discount_mass / (1.0 - sum_denom)

# bigram probabilities
for w1, grams2 in self.grams_2.items():
for w2, count in grams2.items():
self.prob_2[w1][w2] = count * self.deflator / self.grams_1[w1]

# bigram alphas
for w1, grams2 in self.grams_2.items():
            for w2 in grams2:
                sum_denom = sum(self.prob_2[w2][w3] for w3 in self.grams_3[w1][w2])
                self.alpha_2[w1][w2] = self.discount_mass / (1.0 - sum_denom)
return True

def write_file(self, out_path: str) -> bool:
"""
Write out the ARPAbo model to a file path
"""
try:
with open(out_path, 'w') as outfile:
self.write(outfile)
        except OSError as e:
            print(f'Could not write {out_path}: {e}', file=sys.stderr)
            return False
return True

def write(self, outfile: TextIO) -> bool:
"""
Write the ARPAbo model to a file handle
"""
if self.verbose:
print('Writing output file', file=self.logfile)

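        # Output follows the ARPA layout written below: a header comment,
        # a \data\ section with n-gram counts, then \1-grams:, \2-grams:,
        # and \3-grams: sections of
        #   <log10 prob> <words...> <log10 backoff>
        # (trigrams carry no backoff weight), closed by \end\.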
print(
'Corpus:',
f'{self.sent_count} sentences;',
f'{self.sum_1} words,',
f'{self.count_1} 1-grams,',
f'{self.count_2} 2-grams,',
f'{self.count_3} 3-grams,',
f'with fixed discount mass {self.discount_mass}',
'with simple normalization' if self.norm else '',
file=outfile,
)

print(file=outfile)
print('\\data\\', file=outfile)

print(f'ngram 1={self.count_1}', file=outfile)
if self.count_2:
print(f'ngram 2={self.count_2}', file=outfile)
if self.count_3:
print(f'ngram 3={self.count_3}', file=outfile)
print(file=outfile)

print('\\1-grams:', file=outfile)
for w1, prob in sorted(self.prob_1.items()):
log_prob = log(prob) / ArpaBoLM.log10
log_alpha = log(self.alpha_1[w1]) / ArpaBoLM.log10
print(f'{log_prob:6.4f} {w1} {log_alpha:6.4f}', file=outfile)

if self.count_2:
print(file=outfile)
print('\\2-grams:', file=outfile)
for w1, grams2 in sorted(self.prob_2.items()):
for w2, prob in sorted(grams2.items()):
log_prob = log(prob) / ArpaBoLM.log10
log_alpha = log(self.alpha_2[w1][w2]) / ArpaBoLM.log10
print(f'{log_prob:6.4f} {w1} {w2} {log_alpha:6.4f}',
file=outfile)
if self.count_3:
print(file=outfile)
print('\\3-grams:', file=outfile)
for w1, grams2 in sorted(self.grams_3.items()):
for w2, grams3 in sorted(grams2.items()):
for w3, count in sorted(grams3.items()): # type: ignore
prob = count * self.deflator / self.grams_2[w1][w2]
log_prob = log(prob) / ArpaBoLM.log10
print(f"{log_prob:6.4f} {w1} {w2} {w3}",
file=outfile)

print(file=outfile)
print('\\end\\', file=outfile)
if self.verbose:
print('Finished', date.today(), file=self.logfile)

return True

def main() -> None:
parser = argparse.ArgumentParser(description='Create a fixed-backoff ARPA LM')
    parser.add_argument('-s', '--sentfile', type=str,
                        help='sentence transcripts in sphinxtrain style, or one sentence per line')
    parser.add_argument('-t', '--text', type=str,
                        help='training text given directly on the command line')
    parser.add_argument('-w', '--word-file', type=str,
                        help='add missing words (one per line) from this file with count -C')
parser.add_argument('-C', '--word-file-count', type=int, default=1,
help='word count set for each word in --word-file (default 1)')
    parser.add_argument('-d', '--discount-mass', type=float,
                        help='fixed discount mass in (0.0, 1.0) (default 0.5)')
parser.add_argument('-c', '--case', type=str,
help='fold case (values: lower, upper)')
    parser.add_argument('-a', '--add-start', action='store_true',
                        help='add <s> at the start and </s> at the end of each line for -s or -t')
parser.add_argument('-n', '--norm', action='store_true',
help='do rudimentary token normalization / remove punctuation')
parser.add_argument('-o', '--output', type=str,
help='output to this file (default stdout)')
parser.add_argument('-v', '--verbose', action='store_true',
help='extra log info (to stderr)')

args = parser.parse_args()

if args.case and args.case not in ['lower', 'upper']:
sys.exit('--case must be lower or upper (if given)')

lm = ArpaBoLM(
sentfile=args.sentfile,
text=args.text,
word_file=args.word_file,
word_file_count=args.word_file_count,
discount_mass=args.discount_mass,
case=args.case,
add_start=args.add_start,
norm=args.norm,
verbose=args.verbose,
)
lm.compute()

    if args.output:
        with open(args.output, 'w') as outfile:
            lm.write(outfile)
    else:
        lm.write(sys.stdout)


if __name__ == '__main__':
main()
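
For reference, a minimal usage sketch (illustrative, not part of the diff; it assumes the module is importable as pocketsphinx.lm once installed):

    from pocketsphinx.lm import ArpaBoLM

    # Build a small model from inline text and write it in ARPA format.
    lm = ArpaBoLM(text='hello world\nhello there', add_start=True)
    lm.compute()
    lm.write_file('hello.lm')

Or, via the pocketsphinx_lm console script registered in pyproject.toml below:

    pocketsphinx_lm -t 'hello world' -a -o hello.lm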
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Topic :: Multimedia :: Sound/Audio :: Speech",
@@ -36,6 +37,9 @@ Documentation = "https://pocketsphinx.readthedocs.io/en/latest/"
Repository = "https://github.com/cmusphinx/pocketsphinx.git"
Issues = "https://github.com/cmusphinx/pocketsphinx/issues"

[project.scripts]
pocketsphinx_lm = "pocketsphinx.lm:main"

[tool.cibuildwheel]
# Build a reduced selection of binaries as there are tons of them
build = [
1 change: 1 addition & 0 deletions tox.ini
@@ -14,6 +14,7 @@ commands =

[gh]
python =
3.13 = py313
3.12 = py312
3.11 = py311
3.10 = py310