Skip to content

Commit

Permalink
Allow setting the query language in Vespa and add German example reports
Browse files Browse the repository at this point in the history
  • Loading branch information
medihack committed Feb 17, 2024
1 parent eecc8c8 commit d37a5d1
Show file tree
Hide file tree
Showing 10 changed files with 328 additions and 210 deletions.
2 changes: 1 addition & 1 deletion compose/docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ services:
./manage.py collectstatic --no-input &&
wait-for-it -s vespa.local:19071 -t 60 &&
./manage.py setup_vespa --generate --deploy &&
./manage.py populate_db &&
./manage.py populate_db --report-language de &&
./manage.py runserver 0.0.0.0:8000
"
profiles:
Expand Down
3 changes: 2 additions & 1 deletion example.env
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ SSL_IP_ADDRESSES=127.0.0.1
USER_TIME_ZONE=Europe/Berlin
FORCE_DEBUG_TOOLBAR=false
BACKUP_DIR=/mnt/backups
OPENAI_API_KEY=xxx
VESPA_QUERY_LANGUAGE=de
OPENAI_API_KEY=

# Docker swarm mode does not respect the Docker Proxy client configuration
# (see https://docs.docker.com/network/proxy/#configure-the-docker-client).
Expand Down
48 changes: 9 additions & 39 deletions notebooks/openai.ipynb → notebooks/generate_reports.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,47 +21,17 @@
"\n",
"api_key = config[\"OPENAI_API_KEY\"]\n",
"\n",
"generator = ReportGenerator(api_key)\n",
"generator = ReportGenerator(api_key, language=\"en\")\n",
"\n",
"reports = generator.generate_reports(10)\n",
"reports = []\n",
"for _ in range(100):\n",
" report = generator.generate_report()\n",
" reports.append(report)\n",
" generator.reset_context()\n",
"\n",
"reports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Clean report texts\n",
"\n",
"import re\n",
"\n",
"items = reports.copy()\n",
"\n",
"lines_to_strip = [\n",
" \"Radiology Report\",\n",
" \"Patient Information:\",\n",
" \"Referring Physician:\",\n",
" \"Patient:\",\n",
" \"Patient ID:\" \"Patient Name:\",\n",
" \"Name:\",\n",
" \"Date of Birth:\",\n",
" \"Gender:\",\n",
" \"Age:\",\n",
" \"Note:\",\n",
" \"Signed:\",\n",
" \"Date:\",\n",
"]\n",
"for line_start in lines_to_strip:\n",
" items = [re.sub(rf\"{line_start}.*\\n\", \"\", item) for item in items]\n",
"\n",
"items = [item.strip() for item in items]\n",
"\n",
"items"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -72,9 +42,9 @@
"\n",
"import json\n",
"\n",
"json_data = json.dumps(items, indent=4)\n",
"json_data = json.dumps(reports, indent=4)\n",
"\n",
"with open(\"../samples/reports.json\", \"w\") as outfile:\n",
"with open(\"../samples/reports_en.json\", \"w\") as outfile:\n",
" outfile.write(json_data)"
]
}
Expand All @@ -95,7 +65,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.1"
"version": "3.12.1"
},
"orig_nbformat": 4
},
Expand Down
25 changes: 19 additions & 6 deletions radis/core/management/commands/populate_db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from os import environ
from pathlib import Path
from typing import Literal

from django.conf import settings
from django.contrib.auth.models import Group, Permission
Expand Down Expand Up @@ -28,21 +29,28 @@
fake = Faker()


def feed_report(body: str):
report = ReportFactory.create(body=body)
def feed_report(body: str, language: Literal["en", "de"]):
report = ReportFactory.create(language=language, body=body)
groups = fake.random_elements(elements=list(Group.objects.all()), unique=True)
report.groups.set(groups)
for handler in report_event_handlers:
handler("created", report)


def feed_reports():
samples_path = Path(settings.BASE_DIR / "samples" / "reports.json")
def feed_reports(language: Literal["en", "de"]):
if language == "en":
sample_file = "reports_en.json"
elif language == "de":
sample_file = "reports_de.json"
else:
raise ValueError(f"Language {language} is not supported.")

samples_path = Path(settings.BASE_DIR / "samples" / sample_file)
with open(samples_path, "r") as f:
reports = json.load(f)

for report in reports:
feed_report(report)
feed_report(report, language)


def create_admin() -> User:
Expand Down Expand Up @@ -123,6 +131,11 @@ def add_arguments(self, parser: CommandParser) -> None:
action="store_true",
help="Skip populating the database with example reports.",
)
parser.add_argument(
"--report-language",
default="en",
help="Which report language to use (en or de).",
)

def handle(self, *args, **options):
if User.objects.count() > 0:
Expand All @@ -137,4 +150,4 @@ def handle(self, *args, **options):
print("Reports already populated. Skipping.")
else:
print("Populating database with example reports.")
feed_reports()
feed_reports(options["report_language"])
100 changes: 57 additions & 43 deletions radis/core/utils/report_generator.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,35 @@
from random import randint
from time import sleep
from typing import Any, Callable
from typing import Any, Callable, Literal

import openai
import tiktoken
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam

ReportGeneratedCallback = Callable[[str, int], None]
INITIAL_SYSTEM_PROMPT = {
"de": "Du bist ein radiologischer Facharzt.",
"en": "You are a senior radiologist.",
}

INITIAL_INSTRUCTION = {
"de": "Schreibe einen radiologischen Befund als Beispiel.",
"en": "Write an example radiology report.",
}

FOLLOWUP_INSTRUCTION = {
"de": "Schreibe einen weiteren radiologischen Befund als Beispiel.",
"en": "Write another example radiology report.",
}


def num_tokens_from_messages(
messages: list[ChatCompletionMessageParam], model="gpt-3.5-turbo-0613", silent=False
):
"""Return the number of tokens used by a list of messages.
# from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613", silent=False):
"""Return the number of tokens used by a list of messages."""
From https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken
"""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
Expand Down Expand Up @@ -52,84 +70,80 @@ def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613", silent=False)
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
assert isinstance(value, str)
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens


ReportGeneratedCallback = Callable[[str, int], None]


class ReportGenerator:
def __init__(
self,
api_key: str,
model="gpt-3.5-turbo",
max_tokens=4096,
model: str = "gpt-3.5-turbo",
max_tokens: int = 4096,
language: Literal["en", "de"] = "en",
callback: ReportGeneratedCallback | None = None,
) -> None:
self.api_key = api_key
self.model = model
self.max_tokens = max_tokens
self.callback = callback
self._client = OpenAI(api_key=api_key)
self._model = model
self._max_tokens = max_tokens
self._language = language
self._callback = callback

self.messages: list[ChatCompletionMessageParam] = [
{"role": "system", "content": "You are an senior radiologist."}
{"role": "system", "content": INITIAL_SYSTEM_PROMPT[self._language]}
]

def reset_context(self, full_reset=False):
if len(self.messages) < 3 or full_reset:
self.messages = [self.messages[0]]
else:
# Retain first question and last answer.
self.messages = [self.messages[0], self.messages[1], self.messages[-1]]

def generate_report(self, freshly=False) -> str:
if freshly:
self.reset_context(full_reset=True)
self.messages = [self.messages[0]]

def generate_report(self) -> str:
if len(self.messages) == 1:
self.messages.append({"role": "user", "content": "Write an example radiology report."})
self.messages.append({"role": "user", "content": INITIAL_INSTRUCTION[self._language]})
else:
self.messages.append(
{"role": "user", "content": "Write another example radiology report."}
)

token_count = num_tokens_from_messages(self.messages, self.model, silent=True)
if token_count > self.max_tokens:
self.reset_context()
self.messages.append({"role": "user", "content": FOLLOWUP_INSTRUCTION[self._language]})

token_count = num_tokens_from_messages(self.messages, self._model, silent=True)
if token_count > self._max_tokens:
# Retain system prompt, initial instruction, last answer and last instruction
assert len(self.messages) >= 4
self.messages = [
self.messages[0],
self.messages[1],
self.messages[-2],
self.messages[-1],
]

response: Any = None
retries = 0
while not response:
try:
response = openai.chat.completions.create(
response = self._client.chat.completions.create(
messages=self.messages,
model=self.model,
model=self._model,
# model=self.model, messages=self.messages, api_key=self.api_key
)
# For available errors see https://github.com/openai/openai-python#handling-errors
except openai.APIStatusError as err:
retries += 1
if retries == 3:
print("Error! Service unavailable even after 3 retries.")
print(f"Error! Service unavailable even after 3 retries: {err}")
raise err

# maybe use rate limiter like https://github.com/tomasbasham/ratelimit
sleep(randint(3, 10))

answer = response.choices[0].message.content

if self.callback:
token_count = num_tokens_from_messages(self.messages, self.model, silent=True)
self.callback(answer, token_count)
if self._callback:
token_count = num_tokens_from_messages(self.messages, self._model, silent=True)
self._callback(answer, token_count)

self.messages.append({"role": "assistant", "content": answer})
return answer

def generate_reports(self, num: int, freshly=False) -> list[str]:
reports = []
for i in range(num):
report = self.generate_report(freshly=freshly)
reports.append(report)

return reports
5 changes: 5 additions & 0 deletions radis/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@
VESPA_CONFIG_PORT = env.int("VESPA_CONFIG_PORT", default=19071) # type: ignore
VESPA_DATA_PORT = env.int("VESPA_DATA_PORT", default=8080) # type: ignore

# The language of the Vespa query. If set to "auto" (the default), Vespa will try
# to autodetect it (not a good idea, because query strings are mostly short).
# It should be set to the same language in which the reports were indexed (the
# language field of the report model). Examples: en, de, es, fr, it
VESPA_QUERY_LANGUAGE = env.str("VESPA_QUERY_LANGUAGE", default="auto") # type: ignore

# A timezone that is used for users of the web interface.
USER_TIME_ZONE = env.str("USER_TIME_ZONE", default="Europe/Berlin") # type: ignore
Expand Down
49 changes: 31 additions & 18 deletions radis/vespa/utils/search_methods.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
from django.conf import settings

from radis.search.models import SearchResult

from ..vespa_app import vespa_app
from .document_utils import document_from_vespa_response


def search_bm25(query: str, offset: int, page_size: int) -> SearchResult:
params = {
"yql": "select * from sources * where userQuery()",
"query": query,
"type": "web",
"hits": page_size,
"offset": offset,
"ranking": "bm25",
}

if settings.VESPA_QUERY_LANGUAGE != "auto":
params["language"] = settings.VESPA_QUERY_LANGUAGE

client = vespa_app.get_client()
response = client.query(
yql="select * from report where userQuery()",
query=query,
type="web",
hits=page_size,
offset=offset,
ranking="bm25",
)
response = client.query(**params)

return SearchResult(
total_count=response.json["root"]["fields"]["totalCount"],
Expand All @@ -24,17 +31,23 @@ def search_bm25(query: str, offset: int, page_size: int) -> SearchResult:

# https://pyvespa.readthedocs.io/en/latest/getting-started-pyvespa.html#Hybrid-search-with-the-OR-query-operator
def search_hybrid(query: str, offset: int, page_size: int) -> SearchResult:
client = vespa_app.get_client()
response = client.query(
yql="select * from sources * where userQuery() or \
params = {
"yql": "select * from sources * where userQuery() or \
({targetHits:1000}nearestNeighbor(embedding,q))",
query=query,
type="web",
hits=page_size,
offset=offset,
ranking="fusion",
body={"input.query(q)": f"embed({query})"},
)
"query": query,
"type": "web",
"hits": page_size,
"offset": offset,
"ranking": "fusion",
"body": {"input.query(q)": f"embed({query})"},
}

if settings.VESPA_QUERY_LANGUAGE != "auto":
params["language"] = settings.VESPA_QUERY_LANGUAGE

client = vespa_app.get_client()
response = client.query(**params)

return SearchResult(
total_count=response.json["root"]["fields"]["totalCount"],
coverage=response.json["root"]["coverage"]["coverage"],
Expand Down
Loading

0 comments on commit d37a5d1

Please sign in to comment.