From 470177bfa207a02f7239b00b1952cdd963ae9ad3 Mon Sep 17 00:00:00 2001 From: Kai Schlamp Date: Mon, 18 Mar 2024 22:45:59 +0000 Subject: [PATCH] Allow the providers to specify max_results for search and RAG retrieval --- TODO.md | 6 +++++- radis/rag/site.py | 17 +++++++++++++-- radis/rag/tasks.py | 14 +++++-------- radis/search/site.py | 40 ++++++++++++++++++++++++----------- radis/search/views.py | 5 ++++- radis/vespa/apps.py | 41 +++++++++++++++++++++++++----------- radis/vespa/vespa_app.py | 45 ++++++++++++++++++++++++++++++---------- 7 files changed, 120 insertions(+), 48 deletions(-) diff --git a/TODO.md b/TODO.md index 8f8890b4..e094a22d 100644 --- a/TODO.md +++ b/TODO.md @@ -2,6 +2,8 @@ ## High Priority +- Check if ranking should be turned off for RAG for performance improvements (and use some fixed sort order instead) +- Somehow present provider.max_results to the user, especially important if the number of query results (step 1) is larger than that limit - task control panel - Limit RAG search result choice (stop when there x positive results) - Limit RAG search by maximum number of reports that gets processed @@ -60,10 +62,12 @@ - When a new question is added to the catalog all existing and also upcoming reports will be evaluated - Users can filter by those categories in the normal search (make this plug-in able) - Subscriptions app - - Users can subscribe to Patient IDs, modalities, keywords, questions (see RAG app), categories (see Categories app) + - Users can subscribe to Patient IDs, questions (see RAG app), categories (see Categories app) + - Can also filter by modalities, study description, patient sex, patient age range - Cave, make sure categories app are evaluated before subscriptions - Users get notified by Email when new matching reports arrive in the future - Maybe link to report in RADIS in Email, optionally full report text in Email + - Maybe only allow a maximum number of hits - Maybe set a maximum number of reports in Email - Allow to export collections to ADIT to 
transfer the corresponding studies diff --git a/radis/rag/site.py b/radis/rag/site.py index fe472189..8c167760 100644 --- a/radis/rag/site.py +++ b/radis/rag/site.py @@ -13,12 +13,25 @@ class RetrievalResult(NamedTuple): class RetrievalProvider(NamedTuple): + """ + A class representing a retrieval provider. + + Attributes: + - name (str): The name of the retrieval provider. + - handler (RetrievalHandler): The function that handles the retrieval. + - max_results (int): The maximum number of results that can be returned. + offset + limit must not exceed this value when searching. + - info_template (str): The template to be rendered as info. + """ + name: str handler: RetrievalHandler + max_results: int + info_template: str retrieval_providers: dict[str, RetrievalProvider] = {} -def register_retrieval_provider(name: str, handler: RetrievalHandler): - retrieval_providers[name] = RetrievalProvider(name, handler) +def register_retrieval_provider(retrieval_provider: RetrievalProvider): + retrieval_providers[retrieval_provider.name] = retrieval_provider diff --git a/radis/rag/tasks.py b/radis/rag/tasks.py index b18954e7..0aabbddb 100644 --- a/radis/rag/tasks.py +++ b/radis/rag/tasks.py @@ -129,15 +129,13 @@ def collect_tasks(self, job: RagJob) -> Iterator[RagTask]: elif job.patient_sex == "F": patient_sex = "F" - # TODO: We have to handle the result size somehow. Maybe page through - # the results or alter the Vespa setting that infinite results are - # returned. - # Maybe also set a limit of maximum results that can be collected to - # be processed by a LLM. 
+ provider = job.provider + retrieval_provider = retrieval_providers[provider] + search = Search( query=job.query, offset=0, - limit=100, + limit=retrieval_provider.max_results, filters=SearchFilters( study_date_from=job.study_date_from, study_date_till=job.study_date_till, @@ -151,9 +149,7 @@ def collect_tasks(self, job: RagJob) -> Iterator[RagTask]: logger.debug("Searching reports for task with search: %s", search) - provider = job.provider - search_provider = retrieval_providers[provider] - result = search_provider.handler(search) + result = retrieval_provider.handler(search) for document_id in result.document_ids: task = RagTask.objects.create( diff --git a/radis/search/site.py b/radis/search/site.py index 6da1baf1..3d9cdf38 100644 --- a/radis/search/site.py +++ b/radis/search/site.py @@ -41,6 +41,19 @@ class SearchFilters: class Search(NamedTuple): + """ + A class representing a search. + + If both offset and limit are set to 0, then the search provider + should return the most accurate total count it can calculate. + + Attributes: + - query (str): The query to search. + - offset (int): The offset of the search results. + - limit (int): The limit of the search results. + - filters (SearchFilters): The filters to apply to the search. + """ + query: str offset: int = 0 limit: int = 10 @@ -51,23 +64,26 @@ class Search(NamedTuple): class SearchProvider(NamedTuple): + """ + A class representing a search provider. + + Attributes: + - name (str): The name of the search provider. + - handler (SearchHandler): The function that handles the search. + - max_results (int): The maximum number of results that can be returned. + offset + limit must not exceed this value when searching. + - info_template (str): The template to be rendered as info. 
+ """ + name: str handler: SearchHandler + max_results: int info_template: str search_providers: dict[str, SearchProvider] = {} -def register_search_provider( - name: str, - handler: SearchHandler, - info_template: str, -) -> None: - """Register a search handler. - - The name can be selected by the user in the search form. The searcher is called - when the user submits the form and returns the results. The template name is - the partial to be rendered as info below the search form. - """ - search_providers[name] = SearchProvider(name, handler, info_template) +def register_search_provider(search_provider: SearchProvider) -> None: + """Register a search provider.""" + search_providers[search_provider.name] = search_provider diff --git a/radis/search/views.py b/radis/search/views.py index 194392e0..c7acca8f 100644 --- a/radis/search/views.py +++ b/radis/search/views.py @@ -60,7 +60,10 @@ def get(self, request: AuthenticatedHttpRequest, *args, **kwargs): if total_count is not None: context["total_count"] = total_count - paginator = Paginator(range(total_count), page_size) + # We don't allow to paginate through all results, but the provider tells + # us how many results it can return + max_size = min(total_count, search_provider.max_results) + paginator = Paginator(range(max_size), page_size) context["paginator"] = paginator context["page_obj"] = paginator.get_page(page_number) diff --git a/radis/vespa/apps.py b/radis/vespa/apps.py index e6d9ebd2..297e0a4d 100644 --- a/radis/vespa/apps.py +++ b/radis/vespa/apps.py @@ -11,15 +11,16 @@ def ready(self): def register_app(): - from radis.rag.site import register_retrieval_provider + from radis.rag.site import RetrievalProvider, register_retrieval_provider from radis.reports.models import Report from radis.reports.site import ( ReportEventType, register_document_fetcher, register_report_handler, ) - from radis.search.site import register_search_provider + from radis.search.site import SearchProvider, register_search_provider 
from radis.vespa.providers import retrieve_bm25 + from radis.vespa.vespa_app import MAX_RETRIEVAL_HITS, MAX_SEARCH_HITS from .providers import search_bm25, search_hybrid, search_semantic from .utils.document_utils import ( @@ -49,21 +50,37 @@ def fetch_vespa_document(report: Report) -> dict[str, Any]: register_document_fetcher("vespa", fetch_vespa_document) register_search_provider( - name="Vespa Hybrid Ranking", - handler=search_hybrid, - info_template="vespa/_hybrid_info.html", + SearchProvider( + name="Vespa Hybrid Ranking", + handler=search_hybrid, + max_results=MAX_SEARCH_HITS, + info_template="vespa/_hybrid_info.html", + ) ) register_search_provider( - name="Vespa BM25 Ranking", - handler=search_bm25, - info_template="vespa/_bm25_info.html", + SearchProvider( + name="Vespa BM25 Ranking", + handler=search_bm25, + max_results=MAX_SEARCH_HITS, + info_template="vespa/_bm25_info.html", + ) ) register_search_provider( - name="Vespa Semantic Ranking", - handler=search_semantic, - info_template="vespa/_bm25_info.html", + SearchProvider( + name="Vespa Semantic Ranking", + handler=search_semantic, + max_results=MAX_SEARCH_HITS, + info_template="vespa/_bm25_info.html", + ) ) - register_retrieval_provider(name="Keyword Search", handler=retrieve_bm25) + register_retrieval_provider( + RetrievalProvider( + name="Vespa BM25", + handler=retrieve_bm25, + max_results=MAX_RETRIEVAL_HITS, + info_template="vespa/_bm25_info.html", + ) + ) diff --git a/radis/vespa/vespa_app.py b/radis/vespa/vespa_app.py index dff67ee0..ed96aeaf 100644 --- a/radis/vespa/vespa_app.py +++ b/radis/vespa/vespa_app.py @@ -27,6 +27,18 @@ SEARCH_QUERY_PROFILE = "SearchProfile" RETRIEVAL_QUERY_PROFILE = "RetrievalProfile" +# We set max hits to the same value as max offset as our search and retrieval +# provider (as most other full text search databases) only allow to set +# a maximum results (offset + limit). 
That way we make sure that the maximum +# results can really be reached regardless of the actual number of offset and +# limit. +MAX_SEARCH_HITS = 1000 +MAX_SEARCH_OFFSET = 1000 +SEARCH_TIMEOUT = 1 +MAX_RETRIEVAL_HITS = 10000 +MAX_RETRIEVAL_OFFSET = 10000 +RETRIEVAL_TIMEOUT = 10 + def _create_report_schema(): return Schema( @@ -170,18 +182,37 @@ def __init__(self, app_folder: PathLike) -> None: self.services_file = Path(app_folder) / "services.xml" self.services_doc = ET.parse(self.services_file) - def _add_query_profile(self, name: str, max_hits: int, max_offset: int, timeout: int): + def apply(self): + # We overwrite the generated default query profile + self._add_query_profile( + SEARCH_QUERY_PROFILE, + MAX_SEARCH_HITS, + MAX_SEARCH_OFFSET, + SEARCH_TIMEOUT, + ) + self._add_query_profile( + RETRIEVAL_QUERY_PROFILE, + MAX_RETRIEVAL_HITS, + MAX_RETRIEVAL_OFFSET, + RETRIEVAL_TIMEOUT, + ) + self._add_bolding_config() + self._add_dynamic_snippet_config() + self._write() + + def _add_query_profile(self, profile_name: str, max_hits: int, max_offset: int, timeout: int): query_profile_el = ET.fromstring( f""" - + {max_hits} {max_offset} + {timeout} """ ) tree = ET.ElementTree(query_profile_el) ET.indent(tree, space="\t", level=0) - with open(self.query_profiles_folder / f"{name}.xml", "wb") as f: + with open(self.query_profiles_folder / f"{profile_name}.xml", "wb") as f: tree.write(f, encoding="UTF-8") # https://docs.vespa.ai/en/reference/schema-reference.html#bolding @@ -221,14 +252,6 @@ def _write(self): ET.indent(self.services_doc, " ") self.services_doc.write(self.services_file, encoding="UTF-8", xml_declaration=True) - def apply(self): - # We overwrite the generated default query profile - self._add_query_profile(SEARCH_QUERY_PROFILE, 100, 900, 1) - self._add_query_profile(RETRIEVAL_QUERY_PROFILE, 1000, 9000, 10) - self._add_bolding_config() - self._add_dynamic_snippet_config() - self._write() - class VespaApp: _vespa_host = settings.VESPA_HOST