Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Dec 12, 2023
1 parent 4570bc5 commit 5aa6c77
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 27 deletions.
16 changes: 4 additions & 12 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ def _query(self, text: str, top_k: int = 5):
scores, idx = top_scores(sim, top_k)
# get the utterance categories (route names)
routes = self.categories[idx] if self.categories is not None else []
return [
{"route": d, "score": s.item()} for d, s in zip(routes, scores)
]
return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
else:
return []

Expand All @@ -92,9 +90,7 @@ def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float
scores_by_class[route] = [score]

# Calculate total score for each class
total_scores = {
route: sum(scores) for route, scores in scores_by_class.items()
}
total_scores = {route: sum(scores) for route, scores in scores_by_class.items()}
top_class = max(total_scores, key=lambda x: total_scores[x], default=None)

# Return the top class and its associated scores
Expand Down Expand Up @@ -201,9 +197,7 @@ def _query(self, text: str, top_k: int = 5):
scores = total_sim[idx]
# get the utterance categories (route names)
routes = self.categories[idx] if self.categories is not None else []
return [
{"route": d, "score": s.item()} for d, s in zip(routes, scores)
]
return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
else:
return []

Expand All @@ -224,9 +218,7 @@ def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float
scores_by_class[route] = [score]

# Calculate total score for each class
total_scores = {
route: sum(scores) for route, scores in scores_by_class.items()
}
total_scores = {route: sum(scores) for route, scores in scores_by_class.items()}
top_class = max(total_scores, key=lambda x: total_scores[x], default=None)

# Return the top class and its associated scores
Expand Down
16 changes: 4 additions & 12 deletions tests/unit/test_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,7 @@ def test_failover_score_threshold(self, base_encoder):

class TestHybridRouteLayer:
def test_initialization(self, openai_encoder, routes):
route_layer = HybridRouteLayer(
encoder=openai_encoder, routes=routes
)
route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes)
assert route_layer.score_threshold == 0.82
assert len(route_layer.index) == 5
assert len(set(route_layer.categories)) == 2
Expand All @@ -146,9 +144,7 @@ def test_add_multiple_routes(self, openai_encoder, routes):
assert len(set(route_layer.categories)) == 2

def test_query_and_classification(self, openai_encoder, routes):
route_layer = HybridRouteLayer(
encoder=openai_encoder, routes=routes
)
route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes)
query_result = route_layer("Hello")
assert query_result in ["Route 1", "Route 2"]

Expand All @@ -157,9 +153,7 @@ def test_query_with_no_index(self, openai_encoder):
assert route_layer("Anything") is None

def test_semantic_classify(self, openai_encoder, routes):
route_layer = HybridRouteLayer(
encoder=openai_encoder, routes=routes
)
route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes)
classification, score = route_layer._semantic_classify(
[
{"route": "Route 1", "score": 0.9},
Expand All @@ -170,9 +164,7 @@ def test_semantic_classify(self, openai_encoder, routes):
assert score == [0.9]

def test_semantic_classify_multiple_routes(self, openai_encoder, routes):
route_layer = HybridRouteLayer(
encoder=openai_encoder, routes=routes
)
route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes)
classification, score = route_layer._semantic_classify(
[
{"route": "Route 1", "score": 0.9},
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def test_semanticspace_initialization(self):
assert semantic_space.routes == []

def test_semanticspace_add_route(self):
route = Route(
name="test", utterances=["hello", "hi"], description="greeting"
)
route = Route(name="test", utterances=["hello", "hi"], description="greeting")
semantic_space = SemanticSpace()
semantic_space.add(route)

Expand Down

0 comments on commit 5aa6c77

Please sign in to comment.