Skip to content

Commit

Permalink
Merge pull request #78 from home-assistant/synesthesiam-20231031-rbnf
Browse files Browse the repository at this point in the history
Add number to word generation for range lists
  • Loading branch information
synesthesiam authored Oct 31, 2023
2 parents 86e4c9b + 1409ea4 commit f47b6f6
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 7 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Changelog

## 1.3.0

- Add number to word generation using [unicode-rbnf](https://github.com/rhasspy/unicode-rbnf) for range lists

## 1.2.5

- Fix degenerate wildcard case
2 changes: 1 addition & 1 deletion hassil/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.2.5
1.3.0
5 changes: 5 additions & 0 deletions hassil/intents.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,16 @@ class RangeSlotList(SlotList):
stop: int
step: int = 1
type: RangeType = RangeType.NUMBER
digits: bool = True
words: bool = True
words_language: Optional[str] = None
words_ruleset: Optional[str] = None

def __post_init__(self):
"""Validate number range"""
assert self.start < self.stop, "start must be less than stop"
assert self.step > 0, "step must be positive"
assert self.digits or self.words, "must have digits, words, or both"


@dataclass
Expand Down
91 changes: 89 additions & 2 deletions hassil/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import collections.abc
import itertools
import logging
import re
from abc import ABC
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional

from unicode_rbnf import RbnfEngine

from .expression import (
Expression,
ListReference,
Expand All @@ -28,11 +31,17 @@
from .util import normalize_text, normalize_whitespace

NUMBER_START = re.compile(r"^(\s*-?[0-9]+)")
BREAK_WORDS_TABLE = str.maketrans("-_", " ")
PUNCTUATION = re.compile(r"[.。,,?¿?؟!¡!;;::]+")
WHITESPACE = re.compile(r"\s+")

MISSING_ENTITY = "<missing>"

_LOGGER = logging.getLogger()

# lang -> engine
_ENGINE_CACHE: Dict[str, RbnfEngine] = {}


class HassilError(Exception):
"""Base class for hassil errors"""
Expand Down Expand Up @@ -112,6 +121,10 @@ class MatchSettings:
"""True if whitespace should be ignored during matching."""

allow_unmatched_entities: bool = False
"""True if unmatched entities are kept for better error messages (slower)."""

language: Optional[str] = None
"""Optional language to use when converting digits to words."""


@dataclass
Expand Down Expand Up @@ -231,6 +244,7 @@ def recognize(
intent_context: Optional[Dict[str, Any]] = None,
default_response: Optional[str] = "default",
allow_unmatched_entities: bool = False,
language: Optional[str] = None,
) -> Optional[RecognizeResult]:
"""Return the first match of input text/words against a collection of intents.
Expand All @@ -242,6 +256,7 @@ def recognize(
intent_context: Slot values to use when not found in text
default_response: Response key to use if not set in intent
allow_unmatched_entities: True if entity values outside slot lists are allowed (slower)
language: Optional language to use when converting digits to words
Returns the first result.
If allow_unmatched_entities is True, you should check for unmatched entities.
Expand All @@ -255,6 +270,7 @@ def recognize(
intent_context=intent_context,
default_response=default_response,
allow_unmatched_entities=allow_unmatched_entities,
language=language,
):
return result

Expand All @@ -270,6 +286,7 @@ def recognize_all(
intent_context: Optional[Dict[str, Any]] = None,
default_response: Optional[str] = "default",
allow_unmatched_entities: bool = False,
language: Optional[str] = None,
) -> Iterable[RecognizeResult]:
"""Return all matches for input text/words against a collection of intents.
Expand Down Expand Up @@ -325,6 +342,7 @@ def recognize_all(
expansion_rules=expansion_rules,
ignore_whitespace=intents.settings.ignore_whitespace,
allow_unmatched_entities=allow_unmatched_entities,
language=language,
)

# Check sentence against each intent.
Expand Down Expand Up @@ -391,6 +409,7 @@ def recognize_all(
},
ignore_whitespace=settings.ignore_whitespace,
allow_unmatched_entities=allow_unmatched_entities,
language=language,
)

# Check each sentence template
Expand Down Expand Up @@ -560,6 +579,7 @@ def is_match(
intent_context: Optional[Dict[str, Any]] = None,
ignore_whitespace: bool = False,
allow_unmatched_entities: bool = False,
language: Optional[str] = None,
) -> Optional[MatchContext]:
"""Return the first match of input text/words against a sentence expression."""
text = normalize_text(text).strip()
Expand Down Expand Up @@ -587,6 +607,7 @@ def is_match(
expansion_rules=expansion_rules,
ignore_whitespace=ignore_whitespace,
allow_unmatched_entities=allow_unmatched_entities,
language=language,
)

match_context = MatchContext(
Expand Down Expand Up @@ -740,6 +761,11 @@ def match_expression(
context_text = context_text.lstrip()
context_starts_with = context_text.startswith(chunk_text)

if not context_starts_with:
# Try breaking words apart
context_text = context_text.translate(BREAK_WORDS_TABLE)
context_starts_with = context_text.startswith(chunk_text)

if context_starts_with:
context_text = context_text[len(chunk_text) :]
yield MatchContext(
Expand Down Expand Up @@ -899,10 +925,16 @@ def match_expression(
# List that represents a number range.
# Numbers must currently be digits ("1" not "one").
range_list: RangeSlotList = slot_list

# Look for digits at the start of the incoming text
number_match = NUMBER_START.match(context.text)
if number_match is not None:

digits_match = False
if range_list.digits and (number_match is not None):
number_text = number_match[1]
word_number = int(number_text)

# Check if number is within range of our list
if range_list.step == 1:
# Unit step
in_range = range_list.start <= word_number <= range_list.stop
Expand All @@ -913,6 +945,8 @@ def match_expression(
)

if in_range:
# Number is in range
digits_match = True
entities = context.entities + [
MatchEntity(
name=list_ref.slot_name,
Expand Down Expand Up @@ -945,7 +979,60 @@ def match_expression(
)
],
)
elif settings.allow_unmatched_entities:

# Only check number words if:
# 1. Words are enabled for this list
# 2. We didn't already match digits
# 3. the incoming text doesn't start with digits
words_match: bool = False
if range_list.words and (not digits_match) and (number_match is None):
words_language = range_list.words_language or settings.language
if words_language:
# Load number formatting engine
engine = _ENGINE_CACHE.get(words_language)
if engine is None:
engine = RbnfEngine.for_language(words_language)
_ENGINE_CACHE[words_language] = engine

assert engine is not None

for word_number in range(
range_list.start, range_list.stop + 1, range_list.step
):
number_words = engine.format_number(
word_number, ruleset_name=range_list.words_ruleset
).translate(BREAK_WORDS_TABLE)

entities = context.entities + [
MatchEntity(
name=list_ref.slot_name,
value=word_number,
text=number_words,
)
]
yield from match_expression(
settings,
MatchContext(
text=context.text,
entities=entities,
# Copy over
intent_context=context.intent_context,
is_start_of_word=context.is_start_of_word,
unmatched_entities=context.unmatched_entities,
),
TextChunk(number_words),
)
else:
_LOGGER.warning(
"No language set, so cannot convert %s digits to words",
list_ref.slot_name,
)

if (
(not digits_match)
and (not words_match)
and settings.allow_unmatched_entities
):
# Report not a number
yield MatchContext(
# Copy over
Expand Down
44 changes: 40 additions & 4 deletions hassil/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Dict, Iterable, Optional, Set, Tuple

import yaml
from unicode_rbnf import RbnfEngine

from .expression import (
Expression,
Expand All @@ -25,13 +26,17 @@

_LOGGER = logging.getLogger("hassil.sample")

# lang -> engine
_ENGINE_CACHE: Dict[str, RbnfEngine] = {}


def sample_intents(
intents: Intents,
slot_lists: Optional[Dict[str, SlotList]] = None,
expansion_rules: Optional[Dict[str, Sentence]] = None,
max_sentences_per_intent: Optional[int] = None,
intent_names: Optional[Set[str]] = None,
language: Optional[str] = None,
) -> Iterable[Tuple[str, str]]:
"""Sample text strings for sentences from intents."""
if slot_lists is None:
Expand Down Expand Up @@ -63,6 +68,7 @@ def sample_intents(
intent_sentence,
slot_lists,
expansion_rules,
language=language,
)
for sentence_text in sentence_texts:
yield (intent_name, sentence_text)
Expand All @@ -85,6 +91,7 @@ def sample_expression(
expression: Expression,
slot_lists: Optional[Dict[str, SlotList]] = None,
expansion_rules: Optional[Dict[str, Sentence]] = None,
language: Optional[str] = None,
) -> Iterable[str]:
"""Sample possible text strings from an expression."""
if isinstance(expression, TextChunk):
Expand All @@ -98,13 +105,15 @@ def sample_expression(
item,
slot_lists,
expansion_rules,
language=language,
)
elif seq.type == SequenceType.GROUP:
seq_sentences = map(
partial(
sample_expression,
slot_lists=slot_lists,
expansion_rules=expansion_rules,
language=language,
),
seq.items,
)
Expand Down Expand Up @@ -132,13 +141,39 @@ def sample_expression(
text_value.text_in,
slot_lists,
expansion_rules,
language=language,
)
elif isinstance(slot_list, RangeSlotList):
range_list: RangeSlotList = slot_list
number_strs = map(
str, range(range_list.start, range_list.stop + 1, range_list.step)
)
yield from number_strs

if range_list.digits:
number_strs = map(
str, range(range_list.start, range_list.stop + 1, range_list.step)
)
yield from number_strs

if range_list.words:
words_language = range_list.words_language or language
if words_language:
engine = _ENGINE_CACHE.get(words_language)
if engine is None:
engine = RbnfEngine.for_language(words_language)
_ENGINE_CACHE[words_language] = engine

assert engine is not None

# digits -> words
for word_number in range(
range_list.start, range_list.stop + 1, range_list.step
):
yield engine.format_number(
word_number, ruleset_name=range_list.words_ruleset
)
else:
_LOGGER.warning(
"No language set, so cannot convert %s digits to words",
list_ref.slot_name,
)

else:
raise ValueError(f"Unexpected slot list type: {slot_list}")
Expand All @@ -153,6 +188,7 @@ def sample_expression(
rule_body,
slot_lists,
expansion_rules,
language=language,
)
else:
raise ValueError(f"Unexpected expression: {expression}")
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
PyYAML>=6.0,<7
unicode-rbnf==1.0.0
19 changes: 19 additions & 0 deletions tests/test_recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,25 @@ def test_brightness_area(intents, slot_lists):
assert result.entities["name"].value == "all"


# pylint: disable=redefined-outer-name
def test_brightness_area_words(intents, slot_lists):
result = recognize(
"set brightness in the living room to forty-two percent",
intents,
slot_lists=slot_lists,
language="en",
)
assert result is not None
assert result.intent.name == "SetBrightness"

assert result.entities["area"].value == "area.living_room"
assert result.entities["brightness_pct"].value == 42

# From YAML
assert result.entities["domain"].value == "light"
assert result.entities["name"].value == "all"


# pylint: disable=redefined-outer-name
def test_brightness_name(intents, slot_lists):
result = recognize(
Expand Down
Loading

0 comments on commit f47b6f6

Please sign in to comment.