From 710d1db6cc3ebc59a39756c0f9ee511773514092 Mon Sep 17 00:00:00 2001
From: Martin Bernstorff
Date: Tue, 9 Apr 2024 10:56:42 +0200
Subject: [PATCH] feat: add wikilinks to question output (#135)

Add a QuestionWikilinkerStep as the last step of the question flow. For
every question/answer pair on a highlight, the step formats a prompt that
asks the completer to capitalise the important, domain-specific terms of
the question and surround them with [[wikilinks]], then stores the
completer's response as the new question text.

- Wire the step into the typer CLI flow (memorymarker/__main__.py) and
  into the pipeline in question_generator/main.py, in both places with an
  OpenAICompleter using model "gpt-4-turbo-preview". The pipeline flow is
  renamed to "chunked_reasoning_with_wikilinks".
- Drop frozen=True from the QAPrompt dataclass; the new step assigns to
  QAPrompt.question in place.
- Add QuestionResponseModel (a single `question: str` field). Nothing in
  this patch references it yet.
- Restrict document_titles in question_generator/main.py to
  "Singly Linked List".
---
Notes (ignored by git-am; not part of the commit message):

* This file was restored from a copy in which every newline and every run
  of spaces had been collapsed to one space. Line breaks were placed using
  the hunk line counts and the diffstat. Leading indentation of context
  lines is inferred from bracket nesting and may not match the target
  files byte-for-byte; check with
  `git apply --check --ignore-whitespace` before applying.
* The removed document_titles line contained UTF-8 mojibake for the letter
  o-with-stroke; it is written here as "køns".
* Review observations, intentionally NOT changed in this patch:
  - _wikilink_prompt mutates its QAPrompt argument, which is why QAPrompt
    had to lose frozen=True.
  - os.getenv("OPENAI_API_KEY", "No OPENAI_API") passes a placeholder
    string as the API key when the variable is unset.

 memorymarker/__main__.py                      | 10 +++++
 memorymarker/question_generator/main.py       | 14 ++++++-
 .../question_generator/qa_responses.py        |  6 ++-
 .../steps/question_wikilinker.py              | 42 +++++++++++++++++++
 4 files changed, 69 insertions(+), 3 deletions(-)
 create mode 100644 memorymarker/question_generator/steps/question_wikilinker.py

diff --git a/memorymarker/__main__.py b/memorymarker/__main__.py
index 587eeac..5148bfc 100644
--- a/memorymarker/__main__.py
+++ b/memorymarker/__main__.py
@@ -20,6 +20,7 @@
     AnthropicCompleter,
 )
 from memorymarker.question_generator.completers.openai_completer import (
+    OpenAICompleter,
     OpenAIModelCompleter,
 )
 from memorymarker.question_generator.flows.question_flow import QuestionFlow
@@ -27,6 +28,9 @@
 from memorymarker.question_generator.qa_responses import QAResponses
 from memorymarker.question_generator.steps.qa_extractor import QuestionExtractionStep
 from memorymarker.question_generator.steps.qa_generation import QuestionGenerationStep
+from memorymarker.question_generator.steps.question_wikilinker import (
+    QuestionWikilinkerStep,
+)
 from memorymarker.question_generator.steps.reasoning import ReasoningStep
 
 app = typer.Typer(no_args_is_help=True)
@@ -150,6 +154,12 @@ def typer_cli(
                         response_model=QAResponses,  # type: ignore
                     )
                 ),
+                QuestionWikilinkerStep(
+                    completer=OpenAICompleter(
+                        api_key=os.getenv("OPENAI_API_KEY", "No OPENAI_API"),
+                        model="gpt-4-turbo-preview",
+                    )
+                ),
             ),
         )(chunked_highlights[0:max_n])
     )
diff --git a/memorymarker/question_generator/main.py b/memorymarker/question_generator/main.py
index efb1975..7b15bfb 100644
--- a/memorymarker/question_generator/main.py
+++ b/memorymarker/question_generator/main.py
@@ -13,6 +13,7 @@
     AnthropicCompleter,
 )
 from memorymarker.question_generator.completers.openai_completer import (
+    OpenAICompleter,
     OpenAIModelCompleter,
 )
 from memorymarker.question_generator.example_repo_airtable import (
@@ -25,6 +26,9 @@
 from memorymarker.question_generator.qa_responses import QAResponses
 from memorymarker.question_generator.steps.qa_extractor import QuestionExtractionStep
 from memorymarker.question_generator.steps.qa_generation import QuestionGenerationStep
+from memorymarker.question_generator.steps.question_wikilinker import (
+    QuestionWikilinkerStep,
+)
 from memorymarker.question_generator.steps.reasoning import ReasoningStep
 
 if TYPE_CHECKING:
@@ -100,7 +104,7 @@ async def main():
     #     "stack is a data structure that contains a collection of elements where you can add and delete elements from just one end ",
     #     "A semaphore manages an internal counter",
     # }
-    document_titles = {"Singly Linked List", "Jeg har set mit køns smerte"}
+    document_titles = {"Singly Linked List"}
     input_highlights = _select_highlights_from_omnivore()
     selected_highlights = input_highlights.filter(
         lambda _: any(title in _.source_document.title for title in document_titles)
@@ -126,7 +130,7 @@ async def main():
         grouped_highlights,
         [
             QuestionFlow(
-                _name="chunked_reasoning",
+                _name="chunked_reasoning_with_wikilinks",
                 steps=(
                     ReasoningStep(completer=base_completer),
                     QuestionGenerationStep(
@@ -139,6 +143,12 @@ async def main():
                             response_model=QAResponses,  # type: ignore
                         )
                     ),
+                    QuestionWikilinkerStep(
+                        completer=OpenAICompleter(
+                            api_key=os.getenv("OPENAI_API_KEY", "No OPENAI_API"),
+                            model="gpt-4-turbo-preview",
+                        )
+                    ),
                 ),
             )
         ],
diff --git a/memorymarker/question_generator/qa_responses.py b/memorymarker/question_generator/qa_responses.py
index 038dd7d..f1d9426 100644
--- a/memorymarker/question_generator/qa_responses.py
+++ b/memorymarker/question_generator/qa_responses.py
@@ -8,7 +8,7 @@
     from memorymarker.question_generator.reasoned_highlight import Highlights
 
 
-@dataclass(frozen=True)
+@dataclass
 class QAPrompt:
     hydrated_highlight: "Highlights | None"
     question: str
@@ -34,5 +34,9 @@ def to_qaprompt(self, reasoned_highlight: "Highlights") -> QAPrompt:
         )
 
 
+class QuestionResponseModel(BaseModel):
+    question: str
+
+
 class QAResponses(pydantic.BaseModel):
     items: Sequence[QAPromptResponseModel]
diff --git a/memorymarker/question_generator/steps/question_wikilinker.py b/memorymarker/question_generator/steps/question_wikilinker.py
new file mode 100644
index 0000000..d53bb95
--- /dev/null
+++ b/memorymarker/question_generator/steps/question_wikilinker.py
@@ -0,0 +1,42 @@
+import asyncio
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from memorymarker.question_generator.steps.step import FlowStep
+
+if TYPE_CHECKING:
+    from memorymarker.question_generator.completers.completer import Completer
+    from memorymarker.question_generator.qa_responses import QAPrompt
+    from memorymarker.question_generator.reasoned_highlight import Highlights
+
+
+@dataclass(frozen=True)
+class QuestionWikilinkerStep(FlowStep):
+    completer: "Completer"
+    prompt = """In the following, identify the important, domain-specific terms. Then, capitalise them, and surround them with wikilinks. There can be more than one important term. Identify terms as you would in a wikipedia article.
+
+E.g.:
+When working with version control, why is the git amend command misleading?
+
+Turns into:
+When working with [[Version control]], why is the [[Git amend]] command misleading?
+Here is the question:
+{question}
+"""
+
+    def identity(self) -> str:
+        return f"{self.__class__.__name__}_{self.completer.identity()}"
+
+    async def _wikilink_prompt(self, question: "QAPrompt") -> "QAPrompt":
+        prompt = self.prompt.format(question=question.question)
+        response = await self.completer(prompt)  # type: ignore
+        question.question = response
+        return question
+
+    async def __call__(self, highlight: "Highlights") -> "Highlights":
+        prompts = highlight.question_answer_pairs
+        wikilinked_prompts = await asyncio.gather(
+            *[self._wikilink_prompt(prompt) for prompt in prompts]
+        )
+        highlight.question_answer_pairs = wikilinked_prompts
+        return highlight