diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..b8d4869 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,9 @@ +# Changelog + +## 1.3.0 + +- Add number to word generation using [unicode-rbnf](https://github.com/rhasspy/unicode-rbnf) for range lists + +## 1.2.5 + +- Fix degenerate wildcard case diff --git a/hassil/VERSION b/hassil/VERSION index c813fe1..f0bb29e 100644 --- a/hassil/VERSION +++ b/hassil/VERSION @@ -1 +1 @@ -1.2.5 +1.3.0 diff --git a/hassil/intents.py b/hassil/intents.py index 9463807..7973bc7 100644 --- a/hassil/intents.py +++ b/hassil/intents.py @@ -80,11 +80,16 @@ class RangeSlotList(SlotList): stop: int step: int = 1 type: RangeType = RangeType.NUMBER + digits: bool = True + words: bool = True + words_language: Optional[str] = None + words_ruleset: Optional[str] = None def __post_init__(self): """Validate number range""" assert self.start < self.stop, "start must be less than stop" assert self.step > 0, "step must be positive" + assert self.digits or self.words, "must have digits, words, or both" @dataclass diff --git a/hassil/recognize.py b/hassil/recognize.py index cbef1e7..08a9598 100644 --- a/hassil/recognize.py +++ b/hassil/recognize.py @@ -2,11 +2,14 @@ import collections.abc import itertools +import logging import re from abc import ABC from dataclasses import dataclass, field from typing import Any, Dict, Iterable, List, Optional +from unicode_rbnf import RbnfEngine + from .expression import ( Expression, ListReference, @@ -28,11 +31,17 @@ from .util import normalize_text, normalize_whitespace NUMBER_START = re.compile(r"^(\s*-?[0-9]+)") +BREAK_WORDS_TABLE = str.maketrans("-_", " ") PUNCTUATION = re.compile(r"[.。,,?¿?؟!¡!;;::]+") WHITESPACE = re.compile(r"\s+") MISSING_ENTITY = "" +_LOGGER = logging.getLogger() + +# lang -> engine +_ENGINE_CACHE: Dict[str, RbnfEngine] = {} + class HassilError(Exception): """Base class for hassil errors""" @@ -112,6 +121,10 @@ class MatchSettings: """True if whitespace should be ignored during matching.""" allow_unmatched_entities: bool = False + """True if unmatched entities are kept for better error messages (slower).""" + + language: Optional[str] = None + """Optional language to use when converting digits to words.""" @dataclass @@ -231,6 +244,7 @@ def recognize( intent_context: Optional[Dict[str, Any]] = None, default_response: Optional[str] = "default", allow_unmatched_entities: bool = False, + language: Optional[str] = None, ) -> Optional[RecognizeResult]: """Return the first match of input text/words against a collection of intents. @@ -242,6 +256,7 @@ def recognize( intent_context: Slot values to use when not found in text default_response: Response key to use if not set in intent allow_unmatched_entities: True if entity values outside slot lists are allowed (slower) + language: Optional language to use when converting digits to words Returns the first result. If allow_unmatched_entities is True, you should check for unmatched entities. @@ -255,6 +270,7 @@ def recognize( intent_context=intent_context, default_response=default_response, allow_unmatched_entities=allow_unmatched_entities, + language=language, ): return result @@ -270,6 +286,7 @@ def recognize_all( intent_context: Optional[Dict[str, Any]] = None, default_response: Optional[str] = "default", allow_unmatched_entities: bool = False, + language: Optional[str] = None, ) -> Iterable[RecognizeResult]: """Return all matches for input text/words against a collection of intents. @@ -325,6 +342,7 @@ def recognize_all( expansion_rules=expansion_rules, ignore_whitespace=intents.settings.ignore_whitespace, allow_unmatched_entities=allow_unmatched_entities, + language=language, ) # Check sentence against each intent. @@ -391,6 +409,7 @@ def recognize_all( }, ignore_whitespace=settings.ignore_whitespace, allow_unmatched_entities=allow_unmatched_entities, + language=language, ) # Check each sentence template @@ -560,6 +579,7 @@ def is_match( intent_context: Optional[Dict[str, Any]] = None, ignore_whitespace: bool = False, allow_unmatched_entities: bool = False, + language: Optional[str] = None, ) -> Optional[MatchContext]: """Return the first match of input text/words against a sentence expression.""" text = normalize_text(text).strip() @@ -587,6 +607,7 @@ def is_match( expansion_rules=expansion_rules, ignore_whitespace=ignore_whitespace, allow_unmatched_entities=allow_unmatched_entities, + language=language, ) match_context = MatchContext( @@ -740,6 +761,11 @@ def match_expression( context_text = context_text.lstrip() context_starts_with = context_text.startswith(chunk_text) + if not context_starts_with: + # Try breaking words apart + context_text = context_text.translate(BREAK_WORDS_TABLE) + context_starts_with = context_text.startswith(chunk_text) + if context_starts_with: context_text = context_text[len(chunk_text) :] yield MatchContext( @@ -899,10 +925,16 @@ def match_expression( # List that represents a number range. # Numbers must currently be digits ("1" not "one"). range_list: RangeSlotList = slot_list + + # Look for digits at the start of the incoming text number_match = NUMBER_START.match(context.text) - if number_match is not None: + + digits_match = False + if range_list.digits and (number_match is not None): number_text = number_match[1] word_number = int(number_text) + + # Check if number is within range of our list if range_list.step == 1: # Unit step in_range = range_list.start <= word_number <= range_list.stop @@ -913,6 +945,8 @@ def match_expression( ) if in_range: + # Number is in range + digits_match = True entities = context.entities + [ MatchEntity( name=list_ref.slot_name, @@ -945,7 +979,60 @@ def match_expression( ) ], ) - elif settings.allow_unmatched_entities: + + # Only check number words if: + # 1. Words are enabled for this list + # 2. We didn't already match digits + # 3. the incoming text doesn't start with digits + words_match: bool = False + if range_list.words and (not digits_match) and (number_match is None): + words_language = range_list.words_language or settings.language + if words_language: + # Load number formatting engine + engine = _ENGINE_CACHE.get(words_language) + if engine is None: + engine = RbnfEngine.for_language(words_language) + _ENGINE_CACHE[words_language] = engine + + assert engine is not None + + for word_number in range( + range_list.start, range_list.stop + 1, range_list.step + ): + number_words = engine.format_number( + word_number, ruleset_name=range_list.words_ruleset + ).translate(BREAK_WORDS_TABLE) + + entities = context.entities + [ + MatchEntity( + name=list_ref.slot_name, + value=word_number, + text=number_words, + ) + ] + yield from match_expression( + settings, + MatchContext( + text=context.text, + entities=entities, + # Copy over + intent_context=context.intent_context, + is_start_of_word=context.is_start_of_word, + unmatched_entities=context.unmatched_entities, + ), + TextChunk(number_words), + ) + else: + _LOGGER.warning( + "No language set, so cannot convert %s digits to words", + list_ref.slot_name, + ) + + if ( + (not digits_match) + and (not words_match) + and settings.allow_unmatched_entities + ): # Report not a number yield MatchContext( # Copy over diff --git a/hassil/sample.py b/hassil/sample.py index 67a582d..7e83027 100644 --- a/hassil/sample.py +++ b/hassil/sample.py @@ -9,6 +9,7 @@ from typing import Dict, Iterable, Optional, Set, Tuple import yaml +from unicode_rbnf import RbnfEngine from .expression import ( Expression, @@ -25,6 +26,9 @@ _LOGGER = logging.getLogger("hassil.sample") +# lang -> engine +_ENGINE_CACHE: Dict[str, RbnfEngine] = {} + def sample_intents( intents: Intents, @@ -32,6 +36,7 @@ def sample_intents( expansion_rules: Optional[Dict[str, Sentence]] = None, max_sentences_per_intent: Optional[int] = None, intent_names: Optional[Set[str]] = None, + language: Optional[str] = None, ) -> Iterable[Tuple[str, str]]: """Sample text strings for sentences from intents.""" if slot_lists is None: @@ -63,6 +68,7 @@ def sample_intents( intent_sentence, slot_lists, expansion_rules, + language=language, ) for sentence_text in sentence_texts: yield (intent_name, sentence_text) @@ -85,6 +91,7 @@ def sample_expression( expression: Expression, slot_lists: Optional[Dict[str, SlotList]] = None, expansion_rules: Optional[Dict[str, Sentence]] = None, + language: Optional[str] = None, ) -> Iterable[str]: """Sample possible text strings from an expression.""" if isinstance(expression, TextChunk): @@ -98,6 +105,7 @@ def sample_expression( item, slot_lists, expansion_rules, + language=language, ) elif seq.type == SequenceType.GROUP: seq_sentences = map( @@ -105,6 +113,7 @@ def sample_expression( sample_expression, slot_lists=slot_lists, expansion_rules=expansion_rules, + language=language, ), seq.items, ) @@ -132,13 +141,39 @@ def sample_expression( text_value.text_in, slot_lists, expansion_rules, + language=language, ) elif isinstance(slot_list, RangeSlotList): range_list: RangeSlotList = slot_list - number_strs = map( - str, range(range_list.start, range_list.stop + 1, range_list.step) - ) - yield from number_strs + + if range_list.digits: + number_strs = map( + str, range(range_list.start, range_list.stop + 1, range_list.step) + ) + yield from number_strs + + if range_list.words: + words_language = range_list.words_language or language + if words_language: + engine = _ENGINE_CACHE.get(words_language) + if engine is None: + engine = RbnfEngine.for_language(words_language) + _ENGINE_CACHE[words_language] = engine + + assert engine is not None + + # digits -> words + for word_number in range( + range_list.start, range_list.stop + 1, range_list.step + ): + yield engine.format_number( + word_number, ruleset_name=range_list.words_ruleset + ) + else: + _LOGGER.warning( + "No language set, so cannot convert %s digits to words", + list_ref.slot_name, + ) else: raise ValueError(f"Unexpected slot list type: {slot_list}") @@ -153,6 +188,7 @@ def sample_expression( rule_body, slot_lists, expansion_rules, + language=language, ) else: raise ValueError(f"Unexpected expression: {expression}") diff --git a/requirements.txt b/requirements.txt index 1910948..ab6339c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ PyYAML>=6.0,<7 +unicode-rbnf==1.0.0 diff --git a/tests/test_recognize.py b/tests/test_recognize.py index e82f8a8..ec8166c 100644 --- a/tests/test_recognize.py +++ b/tests/test_recognize.py @@ -149,6 +149,25 @@ def test_brightness_area(intents, slot_lists): assert result.entities["name"].value == "all" +# pylint: disable=redefined-outer-name +def test_brightness_area_words(intents, slot_lists): + result = recognize( + "set brightness in the living room to forty-two percent", + intents, + slot_lists=slot_lists, + language="en", + ) + assert result is not None + assert result.intent.name == "SetBrightness" + + assert result.entities["area"].value == "area.living_room" + assert result.entities["brightness_pct"].value == 42 + + # From YAML + assert result.entities["domain"].value == "light" + assert result.entities["name"].value == "all" + + # pylint: disable=redefined-outer-name def test_brightness_name(intents, slot_lists): result = recognize( diff --git a/tests/test_sample.py b/tests/test_sample.py index a3dae9e..648cdd4 100644 --- a/tests/test_sample.py +++ b/tests/test_sample.py @@ -63,6 +63,33 @@ def test_list_range(): } +def test_list_range_missing_language(): + sentence = parse_sentence("run test {num}") + num_list = RangeSlotList(1, 3, words=True) + + # Range slot digits cannot be converted to words without a language available. + assert set(sample_expression(sentence, slot_lists={"num": num_list})) == { + "run test 1", + "run test 2", + "run test 3", + } + + +def test_list_range_words(): + sentence = parse_sentence("run test {num}") + num_list = RangeSlotList(1, 3, words=True) + assert set( + sample_expression(sentence, slot_lists={"num": num_list}, language="en") + ) == { + "run test 1", + "run test one", + "run test 2", + "run test two", + "run test 3", + "run test three", + } + + def test_rule(): sentence = parse_sentence("turn off ") assert set(