From 651326dec9058b7146ea2e75058bb1aa55c0cd07 Mon Sep 17 00:00:00 2001
From: Mehmet Can Ay
Date: Tue, 3 Sep 2024 13:59:45 +0200
Subject: [PATCH 1/3] refactor: enable filtering

---
 api/routes.py | 38 +++++++++++++++++++++++++++-----------
 1 file changed, 27 insertions(+), 11 deletions(-)

diff --git a/api/routes.py b/api/routes.py
index 64d52f3..d373fbb 100644
--- a/api/routes.py
+++ b/api/routes.py
@@ -5,15 +5,15 @@
 
 import uvicorn
 from datastew import DataDictionarySource
-from datastew.embedding import MPNetAdapter
+from datastew.embedding import GPT4Adapter, MPNetAdapter
 from datastew.process.ols import OLSTerminologyImportTask
 from datastew.repository import WeaviateRepository
-from datastew.repository.model import Terminology, Concept, Mapping
+from datastew.repository.model import Concept, Mapping, Terminology
 from datastew.visualisation import get_plot_for_current_database_state
-from fastapi import FastAPI, HTTPException, File, UploadFile
+from fastapi import FastAPI, File, HTTPException, UploadFile
 from starlette.background import BackgroundTasks
 from starlette.middleware.cors import CORSMiddleware
-from starlette.responses import RedirectResponse, HTMLResponse
+from starlette.responses import HTMLResponse, RedirectResponse
 
 app = FastAPI(
     title="INDEX",
@@ -200,11 +200,25 @@ async def get_closest_mappings_for_text(text: str, terminology_name: str = "SNOM
 
 # Endpoint to get mappings for a data dictionary source
 @app.post("/mappings/dict", tags=["mappings"], description="Get mappings for a data dictionary source.")
-async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), variable_field: str = 'variable',
-                                              description_field: str = 'description'):
+async def get_closest_mappings_for_dictionary(file: UploadFile = File(...),
+                                              selected_model: str = "sentence-transformers/all-mpnet-base-v2",
+                                              selected_terminology: str = "SNOMED CT",
+                                              variable_field: str = "variable",
+                                              description_field: str = "description"):
     try:
+        if selected_model == "text-embedding-ada-002":
+            embedding_model = GPT4Adapter(selected_model)
+        elif selected_model == "sentence-transformers/all-mpnet-base-v2":
+            embedding_model = MPNetAdapter(selected_model)
+        else:
+            raise HTTPException(status_code=400, detail="Unsupported embedding model.")
+
         # Determine file extension and create a temporary file with the correct extension
-        _, file_extension = os.path.splitext(file.filename)
+        if file.filename is not None:
+            file_extension = os.path.splitext(file.filename)[1].lower()
+        else:
+            raise HTTPException(status_code=400, detail="Invalid file type. The file must have a suffix.")
+
         with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
             tmp_file.write(await file.read())
             tmp_file_path = tmp_file.name
@@ -219,15 +233,17 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari
             variable = row['variable']
             description = row['description']
             embedding = embedding_model.get_embedding(description)
-            closest_mappings, similarities = repository.get_closest_mappings(embedding, limit=5)
+            closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(
+                embedding, selected_terminology, selected_model, limit=5
+            )
             mappings_list = []
-            for mapping, similarity in zip(closest_mappings, similarities):
+            for mapping, similarity in closest_mappings:
                 concept = mapping.concept
                 terminology = concept.terminology
                 mappings_list.append({
                     "concept": {
-                        "id": concept.concept_id,
-                        "name": concept.name,
+                        "id": concept.concept_identifier,
+                        "name": concept.pref_label,
                         "terminology": {
                             "id": terminology.id,
                             "name": terminology.name

From 08d20cde49e17a5b16f7ce167304f37100b148fd Mon Sep 17 00:00:00 2001
From: Mehmet Can Ay
Date: Mon, 16 Sep 2024 14:51:20 +0200
Subject: [PATCH 2/3] fix: model selection

---
 api/routes.py | 29 +++++++++++++++++++----------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/api/routes.py b/api/routes.py
index d373fbb..16a2750 100644
--- a/api/routes.py
+++ b/api/routes.py
@@ -174,10 +174,19 @@ async def create_mapping(concept_id: str, text: str):
 
 
 @app.post("/mappings", tags=["mappings"])
-async def get_closest_mappings_for_text(text: str, terminology_name: str = "SNOMED CT",
-                                        sentence_embedder: str = "sentence-transformers/all-mpnet-base-v2", limit: int = 5):
+async def get_closest_mappings_for_text(text: str,
+                                        terminology_name: str = "SNOMED CT",
+                                        model: str = "sentence-transformers/all-mpnet-base-v2",
+                                        limit: int = 5):
+    if model == "text-embedding-ada-002":
+        embedding_model = GPT4Adapter(model)
+    elif model == "sentence-transformers/all-mpnet-base-v2":
+        embedding_model = MPNetAdapter(model)
+    else:
+        raise HTTPException(status_code=400, detail="Unsupported embedding model.")
+
     embedding = embedding_model.get_embedding(text).tolist()
-    closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(embedding, terminology_name, sentence_embedder, limit)
+    closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(embedding, terminology_name, model, limit)
     mappings = []
     for mapping, similarity in closest_mappings:
         concept = mapping.concept
@@ -201,15 +210,15 @@ async def get_closest_mappings_for_text(text: str, terminology_name: str = "SNOM
 # Endpoint to get mappings for a data dictionary source
 @app.post("/mappings/dict", tags=["mappings"], description="Get mappings for a data dictionary source.")
 async def get_closest_mappings_for_dictionary(file: UploadFile = File(...),
-                                              selected_model: str = "sentence-transformers/all-mpnet-base-v2",
-                                              selected_terminology: str = "SNOMED CT",
+                                              model: str = "sentence-transformers/all-mpnet-base-v2",
+                                              terminology_name: str = "SNOMED CT",
                                               variable_field: str = "variable",
                                               description_field: str = "description"):
     try:
-        if selected_model == "text-embedding-ada-002":
-            embedding_model = GPT4Adapter(selected_model)
-        elif selected_model == "sentence-transformers/all-mpnet-base-v2":
-            embedding_model = MPNetAdapter(selected_model)
+        if model == "text-embedding-ada-002":
+            embedding_model = GPT4Adapter(model)
+        elif model == "sentence-transformers/all-mpnet-base-v2":
+            embedding_model = MPNetAdapter(model)
         else:
             raise HTTPException(status_code=400, detail="Unsupported embedding model.")
 

From b732aa194b88b7b97f94de75693e793026bd9f57 Mon Sep 17 00:00:00 2001
From: Mehmet Can Ay
Date: Tue, 17 Sep 2024 09:51:57 +0200
Subject: [PATCH 3/3] refactor: remove token using models

---
 api/routes.py | 57 ++++++++++++++++++++++++---------------------------------
 1 file changed, 24 insertions(+), 33 deletions(-)

diff --git a/api/routes.py b/api/routes.py
index 16a2750..f42e6a1 100644
--- a/api/routes.py
+++ b/api/routes.py
@@ -178,33 +178,30 @@ async def get_closest_mappings_for_text(text: str,
                                         terminology_name: str = "SNOMED CT",
                                         model: str = "sentence-transformers/all-mpnet-base-v2",
                                         limit: int = 5):
-    if model == "text-embedding-ada-002":
-        embedding_model = GPT4Adapter(model)
-    elif model == "sentence-transformers/all-mpnet-base-v2":
-        embedding_model = MPNetAdapter(model)
-    else:
-        raise HTTPException(status_code=400, detail="Unsupported embedding model.")
-
-    embedding = embedding_model.get_embedding(text).tolist()
-    closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(embedding, terminology_name, model, limit)
-    mappings = []
-    for mapping, similarity in closest_mappings:
-        concept = mapping.concept
-        terminology = concept.terminology
-        mappings.append({
-            "concept": {
-                "id": concept.concept_identifier,
-                "name": concept.pref_label,
-                "terminology": {
-                    "id": terminology.id,
-                    "name": terminology.name
-                }
-            },
-            "text": mapping.text,
-            "similarity": similarity
-        })
+    try:
+        embedding_model = MPNetAdapter(model)
+        embedding = embedding_model.get_embedding(text).tolist()
+        closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(embedding, terminology_name, model, limit)
+        mappings = []
+        for mapping, similarity in closest_mappings:
+            concept = mapping.concept
+            terminology = concept.terminology
+            mappings.append({
+                "concept": {
+                    "id": concept.concept_identifier,
+                    "name": concept.pref_label,
+                    "terminology": {
+                        "id": terminology.id,
+                        "name": terminology.name
+                    }
+                },
+                "text": mapping.text,
+                "similarity": similarity
+            })
 
-    return mappings
+        return mappings
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Failed to get closest mappings: {str(e)}")
 
 
 # Endpoint to get mappings for a data dictionary source
@@ -215,13 +212,7 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...),
                                               variable_field: str = "variable",
                                               description_field: str = "description"):
     try:
-        if model == "text-embedding-ada-002":
-            embedding_model = GPT4Adapter(model)
-        elif model == "sentence-transformers/all-mpnet-base-v2":
-            embedding_model = MPNetAdapter(model)
-        else:
-            raise HTTPException(status_code=400, detail="Unsupported embedding model.")
-
+        embedding_model = MPNetAdapter(model)
         # Determine file extension and create a temporary file with the correct extension
         if file.filename is not None:
             file_extension = os.path.splitext(file.filename)[1].lower()