Skip to content

Commit

Permalink
refactor: reduce redundancy
Browse files: browse the repository at this point in the history
  • Loading branch information
mehmetcanay committed Feb 23, 2024
1 parent d00cfe3 commit eafe8a4
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 288 deletions.
8 changes: 4 additions & 4 deletions index/embedding.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,10 +38,10 @@ def get_embeddings(self, messages: [str], model="text-embedding-ada-002"):
return [item["embedding"] for item in response["data"]]


class MPNetAdapter(EmbeddingModel):
class SentenceTransformerAdapter(EmbeddingModel):
def __init__(self, model="sentence-transformers/all-mpnet-base-v2"):
logging.getLogger().setLevel(logging.INFO)
self.mpnet_model = SentenceTransformer(model)
self.model = SentenceTransformer(model)

def get_embedding(self, text: str):
logging.info(f"Getting embedding for {text}")
Expand All @@ -51,13 +51,13 @@ def get_embedding(self, text: str):
return None
if isinstance(text, str):
text = text.replace("\n", " ")
return self.mpnet_model.encode(text)
return self.model.encode(text)
except Exception as e:
logging.error(f"Error getting embedding for {text}: {e}")
return None

def get_embeddings(self, messages: [str]) -> [[float]]:
embeddings = self.mpnet_model.encode(messages)
embeddings = self.model.encode(messages)
flattened_embeddings = [[float(element) for element in row] for row in embeddings]
return flattened_embeddings

Expand Down
10 changes: 5 additions & 5 deletions index/evaluation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -197,19 +197,19 @@ def score_mappings(matches: pd.DataFrame) -> float:
return accuracy


def evaluate(datasets, labels, store_results=False, model="gpt", results_root_dir="resources/results/pd"):
def evaluate(datasets, labels, model: str, matching_method="euclidean", store_results=False, results_root_dir="resources/results/pd"):
data = {}
for idx, source in enumerate(datasets):
acc = []
for idy, target in enumerate(datasets):
if model == "gpt":
if matching_method == "euclidean":
map = match_closest_descriptions(source, target)
elif model == "mpnet":
elif matching_method == "cosine":
map = match_closest_descriptions(source,target, matching_method=MatchingMethod.COSINE_EMBEDDING_DISTANCE)
elif model == "fuzzy":
elif matching_method == "fuzzy":
map = match_closest_descriptions(source, target, matching_method=MatchingMethod.FUZZY_STRING_MATCHING)
else:
raise NotImplementedError("Specified model is not implemented!")
raise NotImplementedError("Matching method is not implemented!")
if store_results:
map.to_excel(results_root_dir + f"/{model}_" + f"{labels[idx]}_to_{labels[idy]}.xlsx")
acc.append(round(score_mappings(map), 2))
Expand Down
Loading

0 comments on commit eafe8a4

Please sign in to comment.