diff --git a/notebooks/radis_api.ipynb b/notebooks/radis_api.ipynb index d3b051a0..3145b141 100644 --- a/notebooks/radis_api.ipynb +++ b/notebooks/radis_api.ipynb @@ -10,7 +10,7 @@ "output_type": "stream", "text": [ "Status Code: 201\n", - "{'id': 104, 'document_id': 'gepacs_3dfidii5858-6633i4-ii398841', 'pacs_aet': 'gepacs', 'pacs_name': 'GE PACS', 'patient_id': '1234578', 'patient_birth_date': '1976-05-23', 'patient_sex': 'M', 'study_instance_uid': '34343-34343-34343', 'accession_number': '345348389', 'study_description': 'CT of the Thorax', 'study_datetime': '2000-08-10T00:00:00+02:00', 'series_instance_uid': '34343-676556-3343', 'modalities_in_study': ['CT', 'PET'], 'sop_instance_uid': '35858-384834-3843', 'references': ['http://gepacs.com/34343-34343-34343'], 'body': 'This is the report', 'groups': [2]}\n" + "{'id': 101, 'document_id': 'gepacs_3dfidii5858-6633i4-ii398841', 'pacs_aet': 'gepacs', 'pacs_name': 'GE PACS', 'patient_id': '1234578', 'patient_birth_date': '1976-05-23', 'patient_sex': 'M', 'study_instance_uid': '34343-34343-34343', 'accession_number': '345348389', 'study_description': 'CT of the Thorax', 'study_datetime': '2000-08-10T00:00:00+02:00', 'series_instance_uid': '34343-676556-3343', 'modalities_in_study': ['CT', 'PET'], 'sop_instance_uid': '35858-384834-3843', 'references': ['http://gepacs.com/34343-34343-34343'], 'body': 'This is the report', 'groups': [2]}\n" ] } ], @@ -53,13 +53,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'id': 102,\n", + "{'id': 101,\n", " 'document_id': 'gepacs_3dfidii5858-6633i4-ii398841',\n", " 'pacs_aet': 'gepacs',\n", " 'pacs_name': 'GE PACS',\n", @@ -78,7 +78,7 @@ " 'groups': [2]}" ] }, - "execution_count": 10, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -114,13 +114,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'id': 102,\n", + "{'id': 101,\n", " 'document_id': 'gepacs_3dfidii5858-6633i4-ii398841',\n", " 'pacs_aet': 'gepacs',\n", " 'pacs_name': 'GE PACS',\n", @@ -151,7 +151,7 @@ " 'study_datetime': 965858400}}}" ] }, - "execution_count": 11, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -166,7 +166,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -201,7 +201,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.1" }, "orig_nbformat": 4 }, diff --git a/radis/core/management/commands/populate_db.py b/radis/core/management/commands/populate_db.py index 10b6b315..8a67116c 100644 --- a/radis/core/management/commands/populate_db.py +++ b/radis/core/management/commands/populate_db.py @@ -4,15 +4,14 @@ from django.conf import settings from django.contrib.auth.models import Group, Permission -from django.core.management import call_command -from django.core.management.base import BaseCommand, CommandParser +from django.core.management.base import BaseCommand from faker import Faker from radis.accounts.factories import AdminUserFactory, GroupFactory, UserFactory from radis.accounts.models import User from radis.reports.factories import ReportFactory -from radis.search.models import ReportDocument -from radis.search.vespa_app import vespa_app +from radis.reports.models import Report +from radis.reports.site import report_event_handlers from radis.token_authentication.factories import TokenFactory from radis.token_authentication.models import FRACTION_LENGTH from radis.token_authentication.utils.crypto import hash_token @@ -34,7 +33,8 @@ def feed_report(body: str): report = ReportFactory.create(body=body) groups = fake.random_elements(elements=list(Group.objects.all()), unique=True) report.groups.set(groups) - ReportDocument(report).create() + for handler in report_event_handlers: + handler("created", report) def feed_reports(): @@ -107,16 +107,7 @@ def create_groups(users: list[User]) -> list[Group]: class Command(BaseCommand): help = "Populates the database with example data." - def add_arguments(self, parser: CommandParser) -> None: - parser.add_argument("--reset", action="store_true") - def handle(self, *args, **options): - if options["reset"]: - # Can only be done when dev server is not running and needs django_extensions installed - call_command("reset_db", "--noinput") - call_command("migrate") - vespa_app.get_client().delete_all_docs("radis_content", "report") - if User.objects.count() > 0: print("Development database already populated. Skipping.") else: @@ -124,11 +115,8 @@ def handle(self, *args, **options): users = create_users() create_groups(users) - results = vespa_app.get_client().query( - {"yql": "select * from sources * where true", "hits": 1} - ) - if results.number_documents_retrieved > 0: - print("Vespa already populated. Skipping.") + if Report.objects.first(): + print("Reports already populated. Skipping.") else: - print("Populating Vespa with example reports.") + print("Populating database with example reports.") feed_reports() diff --git a/radis/reports/api/viewsets.py b/radis/reports/api/viewsets.py index 1a72beda..83ea7d98 100644 --- a/radis/reports/api/viewsets.py +++ b/radis/reports/api/viewsets.py @@ -37,7 +37,7 @@ def retrieve(self, request: Request, *args: Any, **kwargs: Any) -> Response: instance: Report = self.get_object() extra = {} - for fetcher in document_fetchers: + for fetcher in document_fetchers.values(): document = fetcher.fetch(instance) if document: extra[fetcher.source] = document diff --git a/radis/reports/site.py b/radis/reports/site.py index 1cd14723..bcdf14db 100644 --- a/radis/reports/site.py +++ b/radis/reports/site.py @@ -28,7 +28,7 @@ class DocumentFetcher(NamedTuple): fetch: FetchDocument -document_fetchers: list[DocumentFetcher] = [] +document_fetchers: dict[str, DocumentFetcher] = {} def register_document_fetcher(source: str, fetch: FetchDocument) -> None: @@ -38,7 +38,7 @@ def register_document_fetcher(source: str, fetch: FetchDocument) -> None: database and returns a document in the form of a dictionary from another database (like Vespa). """ - document_fetchers.append(DocumentFetcher(source, fetch)) + document_fetchers[source] = DocumentFetcher(source, fetch) class ReportPanelButton(NamedTuple): diff --git a/radis/search/apps.py b/radis/search/apps.py index 5f39c8e3..5cffa968 100644 --- a/radis/search/apps.py +++ b/radis/search/apps.py @@ -14,18 +14,12 @@ def ready(self): def register_app(): from radis.core.site import register_main_menu_item - from radis.reports.site import register_document_fetcher, register_report_handler - - from .models import fetch_document, handle_report register_main_menu_item( url_name="search", label="Search", ) - register_report_handler(handle_report) - register_document_fetcher("vespa", fetch_document) - def init_db(**kwargs): create_app_settings() diff --git a/radis/search/models.py b/radis/search/models.py index a349ef12..2059cde8 100644 --- a/radis/search/models.py +++ b/radis/search/models.py @@ -1,17 +1,9 @@ import logging -from dataclasses import dataclass -from datetime import date, datetime, time -from typing import Any, Literal - -from rest_framework.status import HTTP_200_OK -from vespa.io import VespaQueryResponse +from datetime import date, datetime +from typing import Literal, NamedTuple from radis.core.models import AppSettings from radis.reports.models import Report -from radis.reports.site import ReportEventType - -from .utils.search_utils import extract_document_id -from .vespa_app import REPORT_SCHEMA_NAME, vespa_app logger = logging.getLogger(__name__) @@ -21,73 +13,7 @@ class Meta: verbose_name_plural = "Search app settings" -class ReportDocument: - def __init__(self, report: Report) -> None: - self.report = report - - def _dictify_for_vespa(self) -> dict[str, Any]: - """Dictify the report for Vespa. - - Must be in the same format as schema in vespa_app.py - """ - # Vespa can't store dates and datetimes natively, so we store them as a number. - patient_birth_date = int( - datetime.combine(self.report.patient_birth_date, time()).timestamp() - ) - study_datetime = int(self.report.study_datetime.timestamp()) - - return { - "groups": [group.id for group in self.report.groups.all()], - "pacs_aet": self.report.pacs_aet, - "pacs_name": self.report.pacs_name, - "patient_birth_date": patient_birth_date, - "patient_sex": self.report.patient_sex, - "study_description": self.report.study_description, - "study_datetime": study_datetime, - "modalities_in_study": self.report.modalities_in_study, - "references": self.report.references, - "body": self.report.body.strip(), - } - - def fetch(self) -> dict[str, Any]: - response = vespa_app.get_client().get_data(REPORT_SCHEMA_NAME, self.report.document_id) - - if response.get_status_code() != HTTP_200_OK: - message = response.get_json() - raise Exception(f"Error while fetching report from Vespa: {message}") - - return response.get_json() - - def create(self) -> None: - fields = self._dictify_for_vespa() - response = vespa_app.get_client().feed_data_point( - REPORT_SCHEMA_NAME, self.report.document_id, fields - ) - - if response.get_status_code() != HTTP_200_OK: - message = response.get_json() - raise Exception(f"Error while feeding report to Vespa: {message}") - - def update(self) -> None: - fields = self._dictify_for_vespa() - response = vespa_app.get_client().update_data( - REPORT_SCHEMA_NAME, self.report.document_id, fields - ) - - if response.get_status_code() != HTTP_200_OK: - message = response.get_json() - raise Exception(f"Error while updating report on Vespa: {message}") - - def delete(self) -> None: - response = vespa_app.get_client().delete_data(REPORT_SCHEMA_NAME, self.report.document_id) - - if response.get_status_code() != HTTP_200_OK: - message = response.get_json() - raise Exception(f"Error while deleting report on Vespa: {message}") - - -@dataclass(kw_only=True) -class ReportSummary: +class ReportDocument(NamedTuple): relevance: float | None document_id: str pacs_name: str @@ -99,71 +25,12 @@ class ReportSummary: references: list[str] body: str - @staticmethod - def from_vespa_response(record: dict) -> "ReportSummary": - patient_birth_date = date.fromtimestamp(record["fields"]["patient_birth_date"]) - study_datetime = datetime.fromtimestamp(record["fields"]["study_datetime"]) - - return ReportSummary( - relevance=record["relevance"], - document_id=extract_document_id(record["id"]), - pacs_name=record["fields"]["pacs_name"], - patient_birth_date=patient_birth_date, - patient_sex=record["fields"]["patient_sex"], - study_description=record["fields"].get("study_description", ""), - study_datetime=study_datetime, - modalities_in_study=record["fields"].get("modalities_in_study", []), - references=record["fields"].get("references", []), - body=record["fields"]["body"], - ) - @property - def report_full(self) -> Report: + def full_report(self) -> Report: return Report.objects.get(document_id=self.document_id) -@dataclass -class ReportQuery: - total_count: int - coverage: float - documents: int - reports: list[ReportSummary] - - @staticmethod - def from_vespa_response(response: VespaQueryResponse): - json = response.json - return ReportQuery( - total_count=json["root"]["fields"]["totalCount"], - coverage=json["root"]["coverage"]["coverage"], - documents=json["root"]["coverage"]["documents"], - reports=[ReportSummary.from_vespa_response(hit) for hit in response.hits], - ) - - @staticmethod - def query_reports(query: str, offset: int = 0, page_size: int = 100) -> "ReportQuery": - client = vespa_app.get_client() - response = client.query( - { - "yql": "select * from report where userQuery()", - "query": query, - "type": "web", - "hits": page_size, - "offset": offset, - } - ) - return ReportQuery.from_vespa_response(response) - - -def handle_report(event_type: ReportEventType, report: Report): - # Sync reports with Vespa - if event_type == "created": - ReportDocument(report).create() - elif event_type == "updated": - ReportDocument(report).update() - elif event_type == "deleted": - ReportDocument(report).delete() - - -def fetch_document(report: Report) -> dict[str, Any]: - doc = ReportDocument(report).fetch() - return doc +class SearchResult(NamedTuple): + total_count: int | None + coverage: float | None + documents: list[ReportDocument] diff --git a/radis/search/serializers.py b/radis/search/serializers.py index aed69c3f..06b2da35 100644 --- a/radis/search/serializers.py +++ b/radis/search/serializers.py @@ -3,5 +3,6 @@ class SearchParamsSerializer(serializers.Serializer): query = serializers.CharField(default="") + algorithm = serializers.CharField(default="") page = serializers.IntegerField(min_value=1, default=1) per_page = serializers.IntegerField(min_value=1, max_value=100, default=25) diff --git a/radis/search/site.py b/radis/search/site.py new file mode 100644 index 00000000..484741d8 --- /dev/null +++ b/radis/search/site.py @@ -0,0 +1,35 @@ +from typing import Callable, NamedTuple + +from .models import SearchResult + + +class Search(NamedTuple): + query: str + offset: int = 0 + page_size: int = 10 + + +Searcher = Callable[[Search], SearchResult] + + +class SearchHandler(NamedTuple): + name: str + searcher: Searcher + template_name: str + + +search_handlers: dict[str, SearchHandler] = {} + + +def register_search_handler( + name: str, + searcher: Searcher, + template_name: str, +) -> None: + """Register a search handler. + + The name can be selected by the user in the search form. The searcher is called + when the user submits the form and returns the results. The template name is + the partial to be rendered as help below the search form. + """ + search_handlers[name] = SearchHandler(name, searcher, template_name) diff --git a/radis/search/templates/search/_result_report.html b/radis/search/templates/search/_result_document.html similarity index 73% rename from radis/search/templates/search/_result_report.html rename to radis/search/templates/search/_result_document.html index 310ca937..c1463f7d 100644 --- a/radis/search/templates/search/_result_report.html +++ b/radis/search/templates/search/_result_document.html @@ -2,8 +2,8 @@
{% include "search/_result_header.html" with counter=forloop.counter %} -
{{ report.body|safe }}
- {% include "reports/_report_buttons_panel.html" with report=report.report_full %} +
{{ document.body|safe }}
+ {% include "reports/_report_buttons_panel.html" with report=document.full_report %}
diff --git a/radis/search/templates/search/_result_header.html b/radis/search/templates/search/_result_header.html index ce1beba0..411a173b 100644 --- a/radis/search/templates/search/_result_header.html +++ b/radis/search/templates/search/_result_header.html @@ -2,18 +2,18 @@
- Age: {% calc_age report.patient_birth_date report.study_datetime %} - Sex: {{ report.patient_sex }} + Age: {% calc_age document.patient_birth_date document.study_datetime %} + Sex: {{ document.patient_sex }}
- Modalities: {{ report.modalities_in_study|join:", " }} - Study Description: {{ report.study_description }} + Modalities: {{ document.modalities_in_study|join:", " }} + Study Description: {{ document.study_description }}
Result: #{{ counter|add:offset }} - Relevance: {{ report.relevance|floatformat:3 }} + Relevance: {{ document.relevance|floatformat:3 }}
diff --git a/radis/search/templates/search/_search_form.html b/radis/search/templates/search/_search_form.html new file mode 100644 index 00000000..13677714 --- /dev/null +++ b/radis/search/templates/search/_search_form.html @@ -0,0 +1,33 @@ +{% load bootstrap_icon from core_extras %} +
+
+ Search + + + + +
+
diff --git a/radis/search/templates/search/_search_results.html b/radis/search/templates/search/_search_results.html new file mode 100644 index 00000000..0a3e3d30 --- /dev/null +++ b/radis/search/templates/search/_search_results.html @@ -0,0 +1,11 @@ +
+
+ {% if total_count %}Ranked {{ total_count }} report{{ total_count|pluralize }}.{% endif %} +
+ {% for document in documents %} + {% include "search/_result_document.html" %} + {% empty %} + + {% endfor %} + {% include "core/_pagination.html" %} +
diff --git a/radis/search/templates/search/search.html b/radis/search/templates/search/search.html index 8acf8c74..4c7e0d1d 100644 --- a/radis/search/templates/search/search.html +++ b/radis/search/templates/search/search.html @@ -1,37 +1,22 @@ {% extends "search/search_layout.html" %} {% load static from static %} -{% load crispy from crispy_forms_tags %} {% load bootstrap_icon from core_extras %} {% block heading %} -

Search reports

+

Search Reports

{% endblock heading %} {% block content %} -
-
-
- - -
-
- {% if query and total_count %} - Ranked {{ total_count }} report{{ total_count|pluralize }}. - {% endif %} -
-
+ {% include "search/_search_form.html" %} + {% if not selected_algorithm %} + + {% else %} {% if query %} - {% for report in reports %} - {% include "search/_result_report.html" %} - {% empty %} - - {% endfor %} + {% include "search/_search_results.html" %} {% else %} - +
{% include help_template_name %}
{% endif %} -
- {% include "core/_pagination.html" %} + {% endif %} {% endblock content %} diff --git a/radis/search/utils/search_utils.py b/radis/search/utils/search_utils.py deleted file mode 100644 index 5d903652..00000000 --- a/radis/search/utils/search_utils.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Protocol, runtime_checkable - - -@runtime_checkable -class AnnotatedReport(Protocol): - total: int - - -def extract_document_id(id: str) -> str: - return id.split(":")[-1] diff --git a/radis/search/views.py b/radis/search/views.py index 7c2c08ed..db4ff334 100644 --- a/radis/search/views.py +++ b/radis/search/views.py @@ -1,3 +1,5 @@ +from typing import Any + from django.contrib.auth.mixins import LoginRequiredMixin from django.core.exceptions import BadRequest from django.core.paginator import Paginator @@ -6,8 +8,8 @@ from radis.core.types import AuthenticatedRequest -from .models import ReportQuery from .serializers import SearchParamsSerializer +from .site import Search, SearchHandler, search_handlers class SearchView(LoginRequiredMixin, View): @@ -18,22 +20,41 @@ def get(self, request: AuthenticatedRequest, *args, **kwargs): raise BadRequest("Invalid GET parameters.") query: str = serializer.validated_data["query"] + algorithm: str = serializer.validated_data["algorithm"] page_number: int = serializer.validated_data["page"] page_size: int = serializer.validated_data["per_page"] offset = (page_number - 1) * page_size - context = {} - if query: - result = ReportQuery.query_reports(query, offset, page_size) + search_handler: SearchHandler | None = None + + available_algorithms = sorted(list(search_handlers.keys())) + context: dict[str, Any] = {"available_algorithms": available_algorithms} + + if available_algorithms and algorithm: + search_handler = search_handlers.get(algorithm) + if search_handler: + context["selected_algorithm"] = algorithm + context["help_template_name"] = search_handler.template_name + + if available_algorithms and not search_handler: + algorithm = available_algorithms[0] + search_handler = search_handlers[available_algorithms[0]] + context["selected_algorithm"] = algorithm + context["help_template_name"] = search_handler.template_name + + if query and search_handler: + search = Search(query=query, offset=offset, page_size=page_size) + result = search_handler.searcher(search) total_count = result.total_count - paginator = Paginator(range(total_count), page_size) - page = paginator.get_page(page_number) + + if total_count is not None: + context["total_count"] = total_count + paginator = Paginator(range(total_count), page_size) + context["paginator"] = paginator + context["page_obj"] = paginator.get_page(page_number) context["query"] = query context["offset"] = offset - context["paginator"] = paginator - context["page_obj"] = page - context["total_count"] = total_count - context["reports"] = result.reports + context["documents"] = result.documents return render(request, "search/search.html", context) diff --git a/radis/settings/base.py b/radis/settings/base.py index 613372a1..a7e54573 100644 --- a/radis/settings/base.py +++ b/radis/settings/base.py @@ -68,6 +68,7 @@ "radis.search.apps.SearchConfig", "radis.collections.apps.CollectionsConfig", "radis.notes.apps.NotesConfig", + "radis.vespa.apps.VespaConfig", "channels", ] diff --git a/radis/search/management/__init__.py b/radis/vespa/__init__.py similarity index 100% rename from radis/search/management/__init__.py rename to radis/vespa/__init__.py diff --git a/radis/vespa/apps.py b/radis/vespa/apps.py new file mode 100644 index 00000000..aa8217db --- /dev/null +++ b/radis/vespa/apps.py @@ -0,0 +1,51 @@ +from typing import Any + +from django.apps import AppConfig + + +class VespaConfig(AppConfig): + name = "radis.vespa" + + def ready(self): + register_app() + + +def register_app(): + from radis.reports.models import Report + from radis.reports.site import ( + ReportEventType, + register_document_fetcher, + register_report_handler, + ) + from radis.search.models import SearchResult + from radis.search.site import Search, register_search_handler + from radis.vespa.utils.vespa_utils import ( + create_document, + delete_document, + fetch_document, + search_bm25, + update_document, + ) + + from .utils.vespa_utils import dictify_report_for_vespa + + def handle_report(event_type: ReportEventType, report: Report): + # Sync reports with Vespa + if event_type == "created": + create_document(report.document_id, dictify_report_for_vespa(report)) + elif event_type == "updated": + update_document(report.document_id, dictify_report_for_vespa(report)) + elif event_type == "deleted": + delete_document(report.document_id) + + register_report_handler(handle_report) + + def fetch_vespa_document(report: Report) -> dict[str, Any]: + return fetch_document(report.document_id) + + register_document_fetcher("vespa", fetch_vespa_document) + + def search_vespa_bm25(search: Search) -> SearchResult: + return search_bm25(search.query, search.offset, search.page_size) + + register_search_handler("Vespa BM25", search_vespa_bm25, "vespa/_bm25_help.html") diff --git a/radis/search/management/commands/__init__.py b/radis/vespa/management/__init__.py similarity index 100% rename from radis/search/management/commands/__init__.py rename to radis/vespa/management/__init__.py diff --git a/radis/vespa/management/commands/__init__.py b/radis/vespa/management/commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/radis/search/management/commands/setup_vespa.py b/radis/vespa/management/commands/setup_vespa.py similarity index 100% rename from radis/search/management/commands/setup_vespa.py rename to radis/vespa/management/commands/setup_vespa.py diff --git a/radis/vespa/templates/vespa/_bm25_help.html b/radis/vespa/templates/vespa/_bm25_help.html new file mode 100644 index 00000000..7087f7cd --- /dev/null +++ b/radis/vespa/templates/vespa/_bm25_help.html @@ -0,0 +1 @@ +TODO: The help of Vespa diff --git a/radis/vespa/utils/__init__.py b/radis/vespa/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/radis/vespa/utils/vespa_utils.py b/radis/vespa/utils/vespa_utils.py new file mode 100644 index 00000000..ef0fea29 --- /dev/null +++ b/radis/vespa/utils/vespa_utils.py @@ -0,0 +1,102 @@ +from datetime import date, datetime, time +from typing import Any + +from radis.reports.models import Report +from radis.search.models import ReportDocument, SearchResult + +from ..vespa_app import REPORT_SCHEMA_NAME, vespa_app + + +def dictify_report_for_vespa(report: Report) -> dict[str, Any]: + """Dictify the report for Vespa. + + Must be in the same format as schema in vespa_app.py + """ + # Vespa can't store dates and datetimes natively, so we store them as a number. + patient_birth_date = int(datetime.combine(report.patient_birth_date, time()).timestamp()) + study_datetime = int(report.study_datetime.timestamp()) + + return { + "groups": [group.id for group in report.groups.all()], + "pacs_aet": report.pacs_aet, + "pacs_name": report.pacs_name, + "patient_birth_date": patient_birth_date, + "patient_sex": report.patient_sex, + "study_description": report.study_description, + "study_datetime": study_datetime, + "modalities_in_study": report.modalities_in_study, + "references": report.references, + "body": report.body.strip(), + } + + +def fetch_document(document_id: str) -> dict[str, Any]: + response = vespa_app.get_client().get_data(REPORT_SCHEMA_NAME, document_id) + + if response.get_status_code() != 200: + message = response.get_json() + raise Exception(f"Error while fetching document from Vespa: {message}") + + return response.get_json() + + +def create_document(document_id: str, fields: dict[str, Any]) -> None: + response = vespa_app.get_client().feed_data_point(REPORT_SCHEMA_NAME, document_id, fields) + + if response.get_status_code() != 200: + message = response.get_json() + raise Exception(f"Error while feeding document to Vespa: {message}") + + +def update_document(document_id: str, fields: dict[str, Any]) -> None: + response = vespa_app.get_client().update_data(REPORT_SCHEMA_NAME, document_id, fields) + + if response.get_status_code() != 200: + message = response.get_json() + raise Exception(f"Error while updating document on Vespa: {message}") + + +def delete_document(document_id: str) -> None: + response = vespa_app.get_client().delete_data(REPORT_SCHEMA_NAME, document_id) + + if response.get_status_code() != 200: + message = response.get_json() + raise Exception(f"Error while deleting document on Vespa: {message}") + + +def document_from_vespa_response(record: dict[str, Any]) -> ReportDocument: + document_id = record["id"].split(":")[-1] + patient_birth_date = date.fromtimestamp(record["fields"]["patient_birth_date"]) + study_datetime = datetime.fromtimestamp(record["fields"]["study_datetime"]) + + return ReportDocument( + relevance=record["relevance"], + document_id=document_id, + pacs_name=record["fields"]["pacs_name"], + patient_birth_date=patient_birth_date, + patient_sex=record["fields"]["patient_sex"], + study_description=record["fields"].get("study_description", ""), + study_datetime=study_datetime, + modalities_in_study=record["fields"].get("modalities_in_study", []), + references=record["fields"].get("references", []), + body=record["fields"]["body"], + ) + + +def search_bm25(query: str, offset: int, page_size: int) -> SearchResult: + client = vespa_app.get_client() + response = client.query( + { + "yql": "select * from report where userQuery()", + "query": query, + "type": "web", + "hits": page_size, + "offset": offset, + } + ) + + return SearchResult( + total_count=response.json["root"]["fields"]["totalCount"], + coverage=response.json["root"]["coverage"]["coverage"], + documents=[document_from_vespa_response(hit) for hit in response.hits], + ) diff --git a/radis/search/vespa_app.py b/radis/vespa/vespa_app.py similarity index 100% rename from radis/search/vespa_app.py rename to radis/vespa/vespa_app.py