Skip to content

Commit

Permalink
Modular search with separate vespa app
Browse files Browse the repository at this point in the history
  • Loading branch information
medihack committed Feb 7, 2024
1 parent 6a2a444 commit 083e969
Show file tree
Hide file tree
Showing 25 changed files with 313 additions and 233 deletions.
18 changes: 9 additions & 9 deletions notebooks/radis_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"output_type": "stream",
"text": [
"Status Code: 201\n",
"{'id': 104, 'document_id': 'gepacs_3dfidii5858-6633i4-ii398841', 'pacs_aet': 'gepacs', 'pacs_name': 'GE PACS', 'patient_id': '1234578', 'patient_birth_date': '1976-05-23', 'patient_sex': 'M', 'study_instance_uid': '34343-34343-34343', 'accession_number': '345348389', 'study_description': 'CT of the Thorax', 'study_datetime': '2000-08-10T00:00:00+02:00', 'series_instance_uid': '34343-676556-3343', 'modalities_in_study': ['CT', 'PET'], 'sop_instance_uid': '35858-384834-3843', 'references': ['http://gepacs.com/34343-34343-34343'], 'body': 'This is the report', 'groups': [2]}\n"
"{'id': 101, 'document_id': 'gepacs_3dfidii5858-6633i4-ii398841', 'pacs_aet': 'gepacs', 'pacs_name': 'GE PACS', 'patient_id': '1234578', 'patient_birth_date': '1976-05-23', 'patient_sex': 'M', 'study_instance_uid': '34343-34343-34343', 'accession_number': '345348389', 'study_description': 'CT of the Thorax', 'study_datetime': '2000-08-10T00:00:00+02:00', 'series_instance_uid': '34343-676556-3343', 'modalities_in_study': ['CT', 'PET'], 'sop_instance_uid': '35858-384834-3843', 'references': ['http://gepacs.com/34343-34343-34343'], 'body': 'This is the report', 'groups': [2]}\n"
]
}
],
Expand Down Expand Up @@ -53,13 +53,13 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'id': 102,\n",
"{'id': 101,\n",
" 'document_id': 'gepacs_3dfidii5858-6633i4-ii398841',\n",
" 'pacs_aet': 'gepacs',\n",
" 'pacs_name': 'GE PACS',\n",
Expand All @@ -78,7 +78,7 @@
" 'groups': [2]}"
]
},
"execution_count": 10,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -114,13 +114,13 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'id': 102,\n",
"{'id': 101,\n",
" 'document_id': 'gepacs_3dfidii5858-6633i4-ii398841',\n",
" 'pacs_aet': 'gepacs',\n",
" 'pacs_name': 'GE PACS',\n",
Expand Down Expand Up @@ -151,7 +151,7 @@
" 'study_datetime': 965858400}}}"
]
},
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -166,7 +166,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -201,7 +201,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.12.1"
},
"orig_nbformat": 4
},
Expand Down
28 changes: 8 additions & 20 deletions radis/core/management/commands/populate_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@

from django.conf import settings
from django.contrib.auth.models import Group, Permission
from django.core.management import call_command
from django.core.management.base import BaseCommand, CommandParser
from django.core.management.base import BaseCommand
from faker import Faker

from radis.accounts.factories import AdminUserFactory, GroupFactory, UserFactory
from radis.accounts.models import User
from radis.reports.factories import ReportFactory
from radis.search.models import ReportDocument
from radis.search.vespa_app import vespa_app
from radis.reports.models import Report
from radis.reports.site import report_event_handlers
from radis.token_authentication.factories import TokenFactory
from radis.token_authentication.models import FRACTION_LENGTH
from radis.token_authentication.utils.crypto import hash_token
Expand All @@ -34,7 +33,8 @@ def feed_report(body: str):
report = ReportFactory.create(body=body)
groups = fake.random_elements(elements=list(Group.objects.all()), unique=True)
report.groups.set(groups)
ReportDocument(report).create()
for handler in report_event_handlers:
handler("created", report)


def feed_reports():
Expand Down Expand Up @@ -107,28 +107,16 @@ def create_groups(users: list[User]) -> list[Group]:
class Command(BaseCommand):
help = "Populates the database with example data."

def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument("--reset", action="store_true")

def handle(self, *args, **options):
if options["reset"]:
# Can only be done when dev server is not running and needs django_extensions installed
call_command("reset_db", "--noinput")
call_command("migrate")
vespa_app.get_client().delete_all_docs("radis_content", "report")

if User.objects.count() > 0:
print("Development database already populated. Skipping.")
else:
print("Populating development database with test data.")
users = create_users()
create_groups(users)

results = vespa_app.get_client().query(
{"yql": "select * from sources * where true", "hits": 1}
)
if results.number_documents_retrieved > 0:
print("Vespa already populated. Skipping.")
if Report.objects.first():
print("Reports already populated. Skipping.")
else:
print("Populating Vespa with example reports.")
print("Populating database with example reports.")
feed_reports()
2 changes: 1 addition & 1 deletion radis/reports/api/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def retrieve(self, request: Request, *args: Any, **kwargs: Any) -> Response:
instance: Report = self.get_object()

extra = {}
for fetcher in document_fetchers:
for fetcher in document_fetchers.values():
document = fetcher.fetch(instance)
if document:
extra[fetcher.source] = document
Expand Down
4 changes: 2 additions & 2 deletions radis/reports/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class DocumentFetcher(NamedTuple):
fetch: FetchDocument


document_fetchers: list[DocumentFetcher] = []
document_fetchers: dict[str, DocumentFetcher] = {}


def register_document_fetcher(source: str, fetch: FetchDocument) -> None:
Expand All @@ -38,7 +38,7 @@ def register_document_fetcher(source: str, fetch: FetchDocument) -> None:
database and returns a document in the form of a dictionary from another
database (like Vespa).
"""
document_fetchers.append(DocumentFetcher(source, fetch))
document_fetchers[source] = DocumentFetcher(source, fetch)


class ReportPanelButton(NamedTuple):
Expand Down
6 changes: 0 additions & 6 deletions radis/search/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,12 @@ def ready(self):

def register_app():
from radis.core.site import register_main_menu_item
from radis.reports.site import register_document_fetcher, register_report_handler

from .models import fetch_document, handle_report

register_main_menu_item(
url_name="search",
label="Search",
)

register_report_handler(handle_report)
register_document_fetcher("vespa", fetch_document)


def init_db(**kwargs):
create_app_settings()
Expand Down
149 changes: 8 additions & 141 deletions radis/search/models.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
import logging
from dataclasses import dataclass
from datetime import date, datetime, time
from typing import Any, Literal

from rest_framework.status import HTTP_200_OK
from vespa.io import VespaQueryResponse
from datetime import date, datetime
from typing import Literal, NamedTuple

from radis.core.models import AppSettings
from radis.reports.models import Report
from radis.reports.site import ReportEventType

from .utils.search_utils import extract_document_id
from .vespa_app import REPORT_SCHEMA_NAME, vespa_app

logger = logging.getLogger(__name__)

Expand All @@ -21,73 +13,7 @@ class Meta:
verbose_name_plural = "Search app settings"


class ReportDocument:
def __init__(self, report: Report) -> None:
self.report = report

def _dictify_for_vespa(self) -> dict[str, Any]:
"""Dictify the report for Vespa.
Must be in the same format as schema in vespa_app.py
"""
# Vespa can't store dates and datetimes natively, so we store them as a number.
patient_birth_date = int(
datetime.combine(self.report.patient_birth_date, time()).timestamp()
)
study_datetime = int(self.report.study_datetime.timestamp())

return {
"groups": [group.id for group in self.report.groups.all()],
"pacs_aet": self.report.pacs_aet,
"pacs_name": self.report.pacs_name,
"patient_birth_date": patient_birth_date,
"patient_sex": self.report.patient_sex,
"study_description": self.report.study_description,
"study_datetime": study_datetime,
"modalities_in_study": self.report.modalities_in_study,
"references": self.report.references,
"body": self.report.body.strip(),
}

def fetch(self) -> dict[str, Any]:
response = vespa_app.get_client().get_data(REPORT_SCHEMA_NAME, self.report.document_id)

if response.get_status_code() != HTTP_200_OK:
message = response.get_json()
raise Exception(f"Error while fetching report from Vespa: {message}")

return response.get_json()

def create(self) -> None:
fields = self._dictify_for_vespa()
response = vespa_app.get_client().feed_data_point(
REPORT_SCHEMA_NAME, self.report.document_id, fields
)

if response.get_status_code() != HTTP_200_OK:
message = response.get_json()
raise Exception(f"Error while feeding report to Vespa: {message}")

def update(self) -> None:
fields = self._dictify_for_vespa()
response = vespa_app.get_client().update_data(
REPORT_SCHEMA_NAME, self.report.document_id, fields
)

if response.get_status_code() != HTTP_200_OK:
message = response.get_json()
raise Exception(f"Error while updating report on Vespa: {message}")

def delete(self) -> None:
response = vespa_app.get_client().delete_data(REPORT_SCHEMA_NAME, self.report.document_id)

if response.get_status_code() != HTTP_200_OK:
message = response.get_json()
raise Exception(f"Error while deleting report on Vespa: {message}")


@dataclass(kw_only=True)
class ReportSummary:
class ReportDocument(NamedTuple):
relevance: float | None
document_id: str
pacs_name: str
Expand All @@ -99,71 +25,12 @@ class ReportSummary:
references: list[str]
body: str

@staticmethod
def from_vespa_response(record: dict) -> "ReportSummary":
patient_birth_date = date.fromtimestamp(record["fields"]["patient_birth_date"])
study_datetime = datetime.fromtimestamp(record["fields"]["study_datetime"])

return ReportSummary(
relevance=record["relevance"],
document_id=extract_document_id(record["id"]),
pacs_name=record["fields"]["pacs_name"],
patient_birth_date=patient_birth_date,
patient_sex=record["fields"]["patient_sex"],
study_description=record["fields"].get("study_description", ""),
study_datetime=study_datetime,
modalities_in_study=record["fields"].get("modalities_in_study", []),
references=record["fields"].get("references", []),
body=record["fields"]["body"],
)

@property
def report_full(self) -> Report:
def full_report(self) -> Report:
return Report.objects.get(document_id=self.document_id)


@dataclass
class ReportQuery:
total_count: int
coverage: float
documents: int
reports: list[ReportSummary]

@staticmethod
def from_vespa_response(response: VespaQueryResponse):
json = response.json
return ReportQuery(
total_count=json["root"]["fields"]["totalCount"],
coverage=json["root"]["coverage"]["coverage"],
documents=json["root"]["coverage"]["documents"],
reports=[ReportSummary.from_vespa_response(hit) for hit in response.hits],
)

@staticmethod
def query_reports(query: str, offset: int = 0, page_size: int = 100) -> "ReportQuery":
client = vespa_app.get_client()
response = client.query(
{
"yql": "select * from report where userQuery()",
"query": query,
"type": "web",
"hits": page_size,
"offset": offset,
}
)
return ReportQuery.from_vespa_response(response)


def handle_report(event_type: ReportEventType, report: Report):
# Sync reports with Vespa
if event_type == "created":
ReportDocument(report).create()
elif event_type == "updated":
ReportDocument(report).update()
elif event_type == "deleted":
ReportDocument(report).delete()


def fetch_document(report: Report) -> dict[str, Any]:
doc = ReportDocument(report).fetch()
return doc
class SearchResult(NamedTuple):
    """Outcome of a search: two optional counters plus the returned documents."""

    # Total number of hits; may be None
    # (NOTE(review): presumably when the search backend does not report it — confirm).
    total_count: int | None
    # Coverage value for the search; may be None
    # (NOTE(review): presumably backend-specific — confirm).
    coverage: float | None
    # The report documents returned for the search.
    documents: list[ReportDocument]
1 change: 1 addition & 0 deletions radis/search/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@

class SearchParamsSerializer(serializers.Serializer):
    """Serializer describing the parameters of a search request and their defaults."""

    # Search query text; falls back to the empty string when omitted.
    query = serializers.CharField(default="")
    # Algorithm selector; falls back to the empty string when omitted.
    # NOTE(review): presumably the name of a registered search handler — confirm.
    algorithm = serializers.CharField(default="")
    # Page number, at least 1; defaults to the first page.
    page = serializers.IntegerField(min_value=1, default=1)
    # Results per page, restricted to 1..100; defaults to 25.
    per_page = serializers.IntegerField(min_value=1, max_value=100, default=25)
35 changes: 35 additions & 0 deletions radis/search/site.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Callable, NamedTuple

from .models import SearchResult


class Search(NamedTuple):
    """Parameters of a single search, passed as the sole argument to a searcher."""

    # The query string.
    query: str
    # Defaults to 0. NOTE(review): presumably the index of the first result to return — confirm.
    offset: int = 0
    # Defaults to 10. NOTE(review): presumably the maximum number of results per page — confirm.
    page_size: int = 10


# Signature of a searcher: a callable taking a Search and returning a SearchResult.
Searcher = Callable[[Search], SearchResult]


class SearchHandler(NamedTuple):
    """Entry stored in the ``search_handlers`` registry."""

    # Name the handler is registered under.
    name: str
    # Callable that performs the search (takes a Search, returns a SearchResult).
    searcher: Searcher
    # Name of the template partial associated with this handler.
    template_name: str


# Module-level registry of search handlers, keyed by handler name.
search_handlers: dict[str, SearchHandler] = {}


def register_search_handler(
    name: str,
    searcher: Searcher,
    template_name: str,
) -> None:
    """Add a search handler to the module-level ``search_handlers`` registry.

    ``name`` is what the user can pick in the search form. ``searcher`` is
    invoked when the form is submitted and produces the results.
    ``template_name`` names the partial rendered as help below the search form.

    A handler already stored under the same name is replaced.
    """
    handler = SearchHandler(
        name=name,
        searcher=searcher,
        template_name=template_name,
    )
    search_handlers[handler.name] = handler
Loading

0 comments on commit 083e969

Please sign in to comment.