Skip to content

Commit

Permalink
Merge pull request #178 from home-assistant/synesthesiam-20241120-restrict-trie
Browse files Browse the repository at this point in the history

Make trie more restrictive
  • Loading branch information
synesthesiam authored Nov 20, 2024
2 parents f80c0ed + d0f0d5e commit c5f9a76
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 2.0.3

- Make trie more restrictive (`two` will not match `t|wo`)

## 2.0.2

- Require `unicode-rbnf>=2.1` which includes important bugfixes
Expand Down
2 changes: 1 addition & 1 deletion hassil/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.2
2.0.3
10 changes: 4 additions & 6 deletions hassil/trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def insert(self, text: str, value: Any) -> None:

current_children = current_node.children

def find(self, text: str) -> Iterable[Tuple[int, str, Any]]:
def find(self, text: str, unique: bool = True) -> Iterable[Tuple[int, str, Any]]:
"""Yield (end_pos, text, value) pairs of all words found in the string."""
q = deque([(self.roots, 0)])
q = deque([(self.roots, i) for i in range(len(text))])
visited = set()

while q:
Expand All @@ -60,15 +60,13 @@ def find(self, text: str) -> Iterable[Tuple[int, str, Any]]:

current_char = text[current_position]

if current_position < len(text):
q.append((current_children, current_position + 1))

node = current_children.get(current_char)
if (node is not None) and (node.id not in visited):
visited.add(node.id)

if node.text is not None:
# End is one past the current position
if unique:
visited.add(node.id)
yield (current_position + 1, node.text, node.value)

if node.children and (current_position < len(text)):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ def test_insert_find() -> None:
(45, "twenty two", 22),
]

# With unique=False, both the standalone *two* and the "two" inside "twenty two" yield 2
assert list(
trie.find("set to 1, then *two*, then finally twenty two please!", unique=False)
) == [
(8, "1", 1),
(19, "two", 2),
(45, "two", 2),
(45, "twenty two", 22),
]

# Test a character in between
assert not list(trie.find("tw|o"))

# Test non-existent value
assert not list(trie.find("three"))

Expand Down

0 comments on commit c5f9a76

Please sign in to comment.