Skip to content

Commit

Permalink
Added missing fuzzy match to OpenAI classifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
x-tabdeveloping committed Nov 21, 2024
1 parent 95a5611 commit 8669f8d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
9 changes: 8 additions & 1 deletion stormtrooper/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,16 @@ def fit(self, X: Optional[Iterable[str]], y: Iterable[str]):
self.n_classes = len(self.classes_)
return self

def fuzzy_match_label(self, label: str) -> str:
    """Map a raw label onto one of the fitted class labels.

    Parameters
    ----------
    label: str
        Candidate label, e.g. the text returned by the model.

    Returns
    -------
    str
        ``label`` itself when it is already one of ``self.classes_``,
        otherwise the best candidate reported by ``process.extractOne``.
        If no candidate is reported, ``label`` is returned unchanged.
    """
    # Exact hits need no fuzzy search.
    if label in self.classes_:
        return label
    # NOTE(review): `process` is presumably the thefuzz/rapidfuzz `process`
    # module imported at the top of the file - confirm against the imports.
    best = process.extractOne(label, self.classes_)
    if best is None:
        # extractOne reports None when there is nothing to match against;
        # unpacking it would raise TypeError, so fall back to the input.
        return label
    # Index instead of unpacking: the result may be (choice, score) or
    # (choice, score, key) depending on the library and choices type.
    return best[0]

def get_user_prompt(self, text: str) -> str:
if getattr(self, "classes_", None) is None:
raise NotFittedError("No class labels have been learnt yet, fit the model.")
raise NotFittedError(
"No class labels have been learnt yet, fit the model."
)
if getattr(self, "examples_", None) is not None:
text_examples = []
for label, examples in self.examples_.items():
Expand Down
10 changes: 8 additions & 2 deletions stormtrooper/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ async def predict_one_async(self, text: str) -> str:
return response.choices[0].message.content

def predict_one(self, text: str) -> str:
    """Predict the label of a single text synchronously.

    Runs ``self.predict_one_async`` to completion with ``asyncio.run`` and,
    when ``self.fuzzy_match`` is truthy, passes the result through
    ``self.fuzzy_match_label``.

    Parameters
    ----------
    text: str
        Text to classify.

    Returns
    -------
    str
        Predicted label.
    """
    # NOTE: asyncio.run() starts a fresh event loop and cannot be called
    # from code that is already running inside one.
    label = asyncio.run(self.predict_one_async(text))
    # No early return before this point, otherwise the fuzzy-matching step
    # below would never run.
    if self.fuzzy_match:
        label = self.fuzzy_match_label(label)
    return label

async def predict_async(self, X: Iterable[str]) -> np.ndarray:
if self.classes_ is None:
Expand All @@ -101,4 +104,7 @@ def predict(self, X: Iterable[str]) -> np.ndarray:
array of shape (n_texts)
Array of string class labels.
"""
return asyncio.run(self.predict_async(X))
labels = asyncio.run(self.predict_async(X))
if self.fuzzy_match:
labels = [self.fuzzy_match_label(label) for label in labels]
return labels

0 comments on commit 8669f8d

Please sign in to comment.