Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added backtracking/backreference #10

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
vosk==0.3.45
wyoming==1.5.2
deepmerge==2.0
41 changes: 41 additions & 0 deletions tests/sentences_dir/en.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
sentences:
- testword1 testword2
# testword1 testword2

- in: testword1 {list} testword2
out: testword1 {list}
# testword1 listfirstitem
# testword1 listseconditem

- testword1 {list}
# testword1 listfirstitem
# testword1 listseconditem

- in: testword1 {list} <er2> testword2
out: testword1 {list} <er2>
# testword1 listfirstitem er2_1
# testword1 listfirstitem er2_2
# testword1 listfirstitem er2_3
# testword1 listseconditem er2_1
# testword1 listseconditem er2_2
# testword1 listseconditem er2_3
- in:
- testword1 {list} <er1>
- testword2 <er1> {list}
out: testword3 {list} <er1>
# testword3 listfirstitem er1_1
# testword3 listfirstitem er1_2
# testword3 listfirstitem er1_3
# testword3 listseconditem er1_1
# testword3 listseconditem er1_2
# testword3 listseconditem er1_3

lists:
list:
values:
- listfirstitem
- in: listseconditem
out: listseconditem
expansion_rules:
er1: "[er1_1|er1_2]"
er2: "(er2_1|er2_2|er2_3)"
67 changes: 58 additions & 9 deletions wyoming_vosk/sentences.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
import argparse
import functools
import itertools
import logging
import re
import sqlite3
import time
from collections import abc
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union

from deepmerge import always_merger

if TYPE_CHECKING:
from hassil.expression import Expression, Sentence
from hassil.intents import SlotList

_LOGGER = logging.getLogger()

LISTS_KEY = "lists"
EXP_RULES_KEY = "exp_rules"
CURR_EXP_RULE_KEY = "_curr_exp_rule"


@dataclass
class LanguageConfig:
Expand Down Expand Up @@ -202,14 +210,22 @@ def generate_sentences(sentences_yaml: Dict[str, Any], db_conn: sqlite3.Connecti
input_expression = hassil.parse_expression.parse_sentence(
input_template
)
for input_text, maybe_output_text in sample_expression_with_output(
for (
input_text,
maybe_output_text,
used_substitutions,
) in sample_expression_with_output(
input_expression,
slot_lists=slot_lists,
expansion_rules=expansion_rules,
):
substituted_output_text = __substitute(
output_text or maybe_output_text or input_text,
used_substitutions,
)
db_conn.execute(
"INSERT INTO sentences (input_text, output_text) VALUES (?, ?)",
(input_text, output_text or maybe_output_text or input_text),
(input_text, substituted_output_text),
)
words.update(w.strip() for w in input_text.split())
num_sentences += 1
Expand Down Expand Up @@ -246,7 +262,8 @@ def sample_expression_with_output(
expression: "Expression",
slot_lists: "Optional[Dict[str, SlotList]]" = None,
expansion_rules: "Optional[Dict[str, Sentence]]" = None,
) -> Iterable[Tuple[str, Optional[str]]]:
used_substitutions={LISTS_KEY: {}, EXP_RULES_KEY: {}},
) -> Iterable[Tuple[str, Optional[str], dict]]:
"""Sample possible text strings from an expression."""
from hassil.expression import (
ListReference,
Expand All @@ -259,24 +276,31 @@ def sample_expression_with_output(
from hassil.recognize import MissingListError, MissingRuleError
from hassil.util import normalize_whitespace

used_substitutions = deepcopy(used_substitutions)
if isinstance(expression, TextChunk):
chunk: TextChunk = expression
yield (chunk.original_text, chunk.original_text)
if CURR_EXP_RULE_KEY in used_substitutions:
curr_exp_rule_name = used_substitutions[CURR_EXP_RULE_KEY]
used_substitutions[EXP_RULES_KEY][curr_exp_rule_name] = chunk.original_text
yield (chunk.original_text, chunk.original_text, used_substitutions)
elif isinstance(expression, Sequence):
seq: Sequence = expression

if seq.type == SequenceType.ALTERNATIVE:
for item in seq.items:
yield from sample_expression_with_output(
item,
slot_lists,
expansion_rules,
used_substitutions,
)
elif seq.type == SequenceType.GROUP:
seq_sentences = map(
partial(
sample_expression_with_output,
slot_lists=slot_lists,
expansion_rules=expansion_rules,
used_substitutions=used_substitutions,
),
seq.items,
)
Expand All @@ -287,6 +311,9 @@ def sample_expression_with_output(
normalize_whitespace(
"".join(w[1] for w in sentence_words if w[1] is not None)
),
functools.reduce(
always_merger.merge, [w[-1] for w in sentence_words]
),
)
else:
raise ValueError(f"Unexpected sequence type: {seq}")
Expand All @@ -307,10 +334,15 @@ def sample_expression_with_output(
for text_value in text_list.values:
if text_value.value_out:
is_first_text = True
for input_text, output_text in sample_expression_with_output(
for (
input_text,
output_text,
used_substitutions,
) in sample_expression_with_output(
text_value.text_in,
slot_lists,
expansion_rules,
used_substitutions,
):
if is_first_text:
output_text = (
Expand All @@ -322,12 +354,19 @@ def sample_expression_with_output(
else:
output_text = None

yield (input_text, output_text)
used_substitutions[LISTS_KEY][
list_ref.list_name
] = text_value.value_out
yield (input_text, output_text, used_substitutions)
else:
used_substitutions[LISTS_KEY][
list_ref.list_name
] = text_value.value_out
yield from sample_expression_with_output(
text_value.text_in,
slot_lists,
expansion_rules,
used_substitutions,
)
else:
raise ValueError(f"Unexpected slot list type: {slot_list}")
Expand All @@ -338,15 +377,25 @@ def sample_expression_with_output(
raise MissingRuleError(f"Missing expansion rule <{rule_ref.rule_name}>")

rule_body = expansion_rules[rule_ref.rule_name]
used_substitutions.update({CURR_EXP_RULE_KEY: rule_ref.rule_name})
yield from sample_expression_with_output(
rule_body,
slot_lists,
expansion_rules,
rule_body, slot_lists, expansion_rules, used_substitutions
)
else:
raise ValueError(f"Unexpected expression: {expression}")


def __substitute(
out_sentence: str, used_substitutions: dict[str, dict[str, str]]
) -> str:
"""Substitutes templates with text used to generate input text"""
for list_name, list_item in used_substitutions[LISTS_KEY].items():
out_sentence = out_sentence.replace("{" + list_name + "}", list_item, 1)
for exp_name, exp_item in used_substitutions[EXP_RULES_KEY].items():
out_sentence = out_sentence.replace("<" + exp_name + ">", exp_item, 1)
return out_sentence


def correct_sentence(
text: str, config: LanguageConfig, score_cutoff: float = 0.0
) -> str:
Expand Down