Skip to content

Commit

Permalink
Allow the providers to specify max_results for search and RAG retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
medihack committed Mar 18, 2024
1 parent 2678a50 commit 470177b
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 48 deletions.
6 changes: 5 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## High Priority

- Check if ranking should be turned off for RAG retrieval for performance improvements (using some fixed sort order instead)
- Somehow present provider.max_results to the user; this is especially important if the number of query results (step 1) is larger
- task control panel
- Limit RAG search result choice (stop when there are x positive results)
- Limit RAG search by maximum number of reports that gets processed
Expand Down Expand Up @@ -60,10 +62,12 @@
- When a new question is added to the catalog all existing and also upcoming reports will be evaluated
- Users can filter by those categories in the normal search (make this pluggable)
- Subscriptions app
- Users can subscribe to Patient IDs, modalities, keywords, questions (see RAG app), categories (see Categories app)
- Users can subscribe to Patient IDs, questions (see RAG app), categories (see Categories app)
- Can also filter by modalities, study description, patient sex, patient age range
- Caveat: make sure the categories app is evaluated before subscriptions
- Users get notified by Email when new matching reports arrive in the future
- Maybe link to report in RADIS in Email, optionally full report text in Email
- Maybe only allow a maximum number of hits
- Maybe set a maximum number of reports in Email
- Allow to export collections to ADIT to transfer the corresponding studies

Expand Down
17 changes: 15 additions & 2 deletions radis/rag/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,25 @@ class RetrievalResult(NamedTuple):


class RetrievalProvider(NamedTuple):
"""
A class representing a retrieval provider.
Attributes:
- name (str): The name of the retrieval provider.
- handler (RetrievalHandler): The function that handles the retrieval.
- max_results (int): The maximum number of results that can be returned.
The sum of offset and limit must not exceed this value when searching.
- info_template (str): The template to be rendered as info.
"""

name: str
handler: RetrievalHandler
max_results: int
info_template: str


retrieval_providers: dict[str, RetrievalProvider] = {}


def register_retrieval_provider(name: str, handler: RetrievalHandler):
retrieval_providers[name] = RetrievalProvider(name, handler)
def register_retrieval_provider(retrieval_provider: RetrievalProvider):
retrieval_providers[retrieval_provider.name] = retrieval_provider
14 changes: 5 additions & 9 deletions radis/rag/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,13 @@ def collect_tasks(self, job: RagJob) -> Iterator[RagTask]:
elif job.patient_sex == "F":
patient_sex = "F"

# TODO: We have to handle the result size somehow. Maybe page through
# the results or alter the Vespa setting that infinite results are
# returned.
# Maybe also set a limit of maximum results that can be collected to
# be processed by a LLM.
provider = job.provider
retrieval_provider = retrieval_providers[provider]

search = Search(
query=job.query,
offset=0,
limit=100,
limit=retrieval_provider.max_results,
filters=SearchFilters(
study_date_from=job.study_date_from,
study_date_till=job.study_date_till,
Expand All @@ -151,9 +149,7 @@ def collect_tasks(self, job: RagJob) -> Iterator[RagTask]:

logger.debug("Searching reports for task with search: %s", search)

provider = job.provider
search_provider = retrieval_providers[provider]
result = search_provider.handler(search)
result = retrieval_provider.handler(search)

for document_id in result.document_ids:
task = RagTask.objects.create(
Expand Down
40 changes: 28 additions & 12 deletions radis/search/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ class SearchFilters:


class Search(NamedTuple):
"""
A class representing a search.
If both offset and limit are set to 0, then the search provider
should return the most accurate total count it can calculate.
Attributes:
- query (str): The query to search.
- offset (int): The offset of the search results.
- limit (int): The limit of the search results.
- filters (SearchFilters): The filters to apply to the search.
"""

query: str
offset: int = 0
limit: int = 10
Expand All @@ -51,23 +64,26 @@ class Search(NamedTuple):


class SearchProvider(NamedTuple):
"""
A class representing a search provider.
Attributes:
- name (str): The name of the search provider.
- handler (SearchHandler): The function that handles the search.
- max_results (int): The maximum number of results that can be returned.
The sum of offset and limit must not exceed this value when searching.
- info_template (str): The template to be rendered as info.
"""

name: str
handler: SearchHandler
max_results: int
info_template: str


search_providers: dict[str, SearchProvider] = {}


def register_search_provider(
name: str,
handler: SearchHandler,
info_template: str,
) -> None:
"""Register a search handler.
The name can be selected by the user in the search form. The searcher is called
when the user submits the form and returns the results. The template name is
the partial to be rendered as info below the search form.
"""
search_providers[name] = SearchProvider(name, handler, info_template)
def register_search_provider(search_provider: SearchProvider) -> None:
"""Register a search provider."""
search_providers[search_provider.name] = search_provider
5 changes: 4 additions & 1 deletion radis/search/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def get(self, request: AuthenticatedHttpRequest, *args, **kwargs):

if total_count is not None:
context["total_count"] = total_count
paginator = Paginator(range(total_count), page_size)
# We don't allow paginating through all results; the provider tells
# us the maximum number of results it can return
max_size = min(total_count, search_provider.max_results)
paginator = Paginator(range(max_size), page_size)
context["paginator"] = paginator
context["page_obj"] = paginator.get_page(page_number)

Expand Down
41 changes: 29 additions & 12 deletions radis/vespa/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ def ready(self):


def register_app():
from radis.rag.site import register_retrieval_provider
from radis.rag.site import RetrievalProvider, register_retrieval_provider
from radis.reports.models import Report
from radis.reports.site import (
ReportEventType,
register_document_fetcher,
register_report_handler,
)
from radis.search.site import register_search_provider
from radis.search.site import SearchProvider, register_search_provider
from radis.vespa.providers import retrieve_bm25
from radis.vespa.vespa_app import MAX_RETRIEVAL_HITS, MAX_SEARCH_HITS

from .providers import search_bm25, search_hybrid, search_semantic
from .utils.document_utils import (
Expand Down Expand Up @@ -49,21 +50,37 @@ def fetch_vespa_document(report: Report) -> dict[str, Any]:
register_document_fetcher("vespa", fetch_vespa_document)

register_search_provider(
name="Vespa Hybrid Ranking",
handler=search_hybrid,
info_template="vespa/_hybrid_info.html",
SearchProvider(
name="Vespa Hybrid Ranking",
handler=search_hybrid,
max_results=MAX_SEARCH_HITS,
info_template="vespa/_hybrid_info.html",
)
)

register_search_provider(
name="Vespa BM25 Ranking",
handler=search_bm25,
info_template="vespa/_bm25_info.html",
SearchProvider(
name="Vespa BM25 Ranking",
handler=search_bm25,
max_results=MAX_SEARCH_HITS,
info_template="vespa/_bm25_info.html",
)
)

register_search_provider(
name="Vespa Semantic Ranking",
handler=search_semantic,
info_template="vespa/_bm25_info.html",
SearchProvider(
name="Vespa Semantic Ranking",
handler=search_semantic,
max_results=MAX_SEARCH_HITS,
info_template="vespa/_bm25_info.html",
)
)

register_retrieval_provider(name="Keyword Search", handler=retrieve_bm25)
register_retrieval_provider(
RetrievalProvider(
name="Vespa BM25",
handler=retrieve_bm25,
max_results=MAX_RETRIEVAL_HITS,
info_template="vespa/_bm25_info.html",
)
)
45 changes: 34 additions & 11 deletions radis/vespa/vespa_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@
SEARCH_QUERY_PROFILE = "SearchProfile"
RETRIEVAL_QUERY_PROFILE = "RetrievalProfile"

# We set max hits to the same value as max offset because our search and retrieval
# providers (like most other full-text search databases) only allow setting
# a maximum number of results (offset + limit). That way we make sure that the maximum
# number of results can really be reached regardless of the actual values of offset
# and limit.
MAX_SEARCH_HITS = 1000
MAX_SEARCH_OFFSET = 1000
SEARCH_TIMEOUT = 1
MAX_RETRIEVAL_HITS = 10000
MAX_RETRIEVAL_OFFSET = 10000
RETRIEVAL_TIMEOUT = 10


def _create_report_schema():
return Schema(
Expand Down Expand Up @@ -170,18 +182,37 @@ def __init__(self, app_folder: PathLike) -> None:
self.services_file = Path(app_folder) / "services.xml"
self.services_doc = ET.parse(self.services_file)

def _add_query_profile(self, name: str, max_hits: int, max_offset: int, timeout: int):
def apply(self):
# We overwrite the generated default query profile
self._add_query_profile(
SEARCH_QUERY_PROFILE,
MAX_SEARCH_HITS,
MAX_SEARCH_OFFSET,
SEARCH_TIMEOUT,
)
self._add_query_profile(
RETRIEVAL_QUERY_PROFILE,
MAX_RETRIEVAL_HITS,
MAX_RETRIEVAL_OFFSET,
RETRIEVAL_TIMEOUT,
)
self._add_bolding_config()
self._add_dynamic_snippet_config()
self._write()

def _add_query_profile(self, profile_name: str, max_hits: int, max_offset: int, timeout: int):
query_profile_el = ET.fromstring(
f"""
<query-profile id="{name}">
<query-profile id="{profile_name}">
<field name="maxHits">{max_hits}</field>
<field name="maxOffset">{max_offset}</field>
<field name="timeout">{timeout}</field>
</query-profile>
"""
)
tree = ET.ElementTree(query_profile_el)
ET.indent(tree, space="\t", level=0)
with open(self.query_profiles_folder / f"{name}.xml", "wb") as f:
with open(self.query_profiles_folder / f"{profile_name}.xml", "wb") as f:
tree.write(f, encoding="UTF-8")

# https://docs.vespa.ai/en/reference/schema-reference.html#bolding
Expand Down Expand Up @@ -221,14 +252,6 @@ def _write(self):
ET.indent(self.services_doc, " ")
self.services_doc.write(self.services_file, encoding="UTF-8", xml_declaration=True)

def apply(self):
# We overwrite the generated default query profile
self._add_query_profile(SEARCH_QUERY_PROFILE, 100, 900, 1)
self._add_query_profile(RETRIEVAL_QUERY_PROFILE, 1000, 9000, 10)
self._add_bolding_config()
self._add_dynamic_snippet_config()
self._write()


class VespaApp:
_vespa_host = settings.VESPA_HOST
Expand Down

0 comments on commit 470177b

Please sign in to comment.