Skip to content

Commit

Permalink
Merge pull request #436 from aurelio-labs/update_route_thresholds_man…
Browse files Browse the repository at this point in the history
…ually

feat: Update route thresholds manually
  • Loading branch information
jamescalam authored Oct 1, 2024
2 parents 8600d9b + e6f1414 commit c7d82c5
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
36 changes: 34 additions & 2 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,40 @@ def add(self, route: Route):
def list_route_names(self) -> List[str]:
return [route.name for route in self.routes]

def update(self, route_name: str, utterances: List[str]):
raise NotImplementedError("This method has not yet been implemented.")
def update(
self,
name: str,
threshold: Optional[float] = None,
utterances: Optional[List[str]] = None,
):
"""Updates the route specified in name. Allows the update of
threshold and/or utterances. If no values are provided via the
threshold or utterances parameters, those fields are not updated.
If neither field is provided raises a ValueError.
The name must exist within the local RouteLayer, if not a
KeyError will be raised.
"""

if threshold is None and utterances is None:
raise ValueError(
"At least one of 'threshold' or 'utterances' must be provided."
)
if utterances:
raise NotImplementedError(
"The update method cannot be used for updating utterances yet."
)

route = self.get(name)
if route:
if threshold:
old_threshold = route.score_threshold
route.score_threshold = threshold
logger.info(
f"Updated threshold for route '{route.name}' from {old_threshold} to {threshold}"
)
else:
raise ValueError(f"Route '{name}' not found. Nothing updated.")

def delete(self, route_name: str):
"""Deletes a route given a specific route name.
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/test_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,45 @@ def test_refresh_routes_not_implemented(self, openai_encoder, routes, index_cls)
):
route_layer._refresh_routes()

def test_update_threshold(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
route_name = "Route 1"
new_threshold = 0.8
route_layer.update(name=route_name, threshold=new_threshold)
updated_route = route_layer.get(route_name)
assert (
updated_route.score_threshold == new_threshold
), f"Expected threshold to be updated to {new_threshold}, but got {updated_route.score_threshold}"

def test_update_non_existent_route(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
non_existent_route = "Non-existent Route"
with pytest.raises(
ValueError,
match=f"Route '{non_existent_route}' not found. Nothing updated.",
):
route_layer.update(name=non_existent_route, threshold=0.7)

def test_update_without_parameters(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
with pytest.raises(
ValueError,
match="At least one of 'threshold' or 'utterances' must be provided.",
):
route_layer.update(name="Route 1")

def test_update_utterances_not_implemented(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
with pytest.raises(
NotImplementedError,
match="The update method cannot be used for updating utterances yet.",
):
route_layer.update(name="Route 1", utterances=["New utterance"])


class TestLayerFit:
def test_eval(self, openai_encoder, routes, test_data):
Expand Down

0 comments on commit c7d82c5

Please sign in to comment.