From aa7079c815e892f3ba3c1b5adcd628b9b4dd07f0 Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Sun, 10 Dec 2023 18:49:32 +0000 Subject: [PATCH 1/4] feat: implement QA extractor Fixes #294 --- .../base_prompt_extractor.py | 5 +- .../prompt_extractors/qa_prompt_extractor.py | 95 +++++++++++++++++++ .../test_qa_prompt_extractor.py | 24 +++++ .../v2/domain/prompts/qa_prompt.py | 12 ++- 4 files changed, 132 insertions(+), 4 deletions(-) create mode 100644 personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/qa_prompt_extractor.py create mode 100644 personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/test_qa_prompt_extractor.py diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/base_prompt_extractor.py b/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/base_prompt_extractor.py index 0761eff0..7bd6df3a 100644 --- a/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/base_prompt_extractor.py +++ b/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/base_prompt_extractor.py @@ -2,9 +2,12 @@ from typing import Protocol from ...prompts.base_prompt import BasePrompt +from ..document_ingesters.document import Document # TODO: https://github.com/MartinBernstorff/personal-mnemonic-medium/issues/294 Implement QA and Cloze promptextractors class BasePromptExtractor(Protocol): - def extract_prompts(self, content: str) -> Sequence[BasePrompt]: + def extract_prompts( + self, document: Document + ) -> Sequence[BasePrompt]: ... diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/qa_prompt_extractor.py b/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/qa_prompt_extractor.py new file mode 100644 index 00000000..e16f69c6 --- /dev/null +++ b/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/qa_prompt_extractor.py @@ -0,0 +1,95 @@ +import logging +import re +from collections.abc import Sequence + +from ...prompts.qa_prompt import QAPrompt, QAPromptFromDoc +from ..document_ingesters.document import Document +from .base_prompt_extractor import BasePromptExtractor + +log = logging.getLogger(__name__) + + +class QAPromptExtractor(BasePromptExtractor): + def __init__( + self, question_prefix: str, answer_prefix: str + ) -> None: + self.question_prefix = question_prefix + self.answer_prefix = answer_prefix + + def _get_first_question(self, content: str) -> str: + question = re.findall( + self.question_prefix + r"{0,1}\.(?:(?!A\.).)*", + content, + flags=re.DOTALL, + )[0] + + return question[len(self.question_prefix) + 1 :].rstrip() + + def _get_first_answer(self, content: str) -> str: + # Have to use positive lookahead to match code-blocks + # To ensure the last answer is matched as well, we add 2 newlines to string. + string_padded = f"{content.rstrip()}\n\n" + + answer = re.findall( + r"\n" + self.answer_prefix + r"[ \n]+\n*.+", + string_padded, + re.DOTALL, + )[0] + + return answer[len(self.answer_prefix) + 2 :].rstrip() + + @staticmethod + def _string_to_blocks_by_newlines(string: str) -> list[str]: + """Break string into a list by 2+ newlines in a row.""" + return re.split(r"(\n\n)+", string) + + def _has_qa(self, string: str) -> bool: + """Check whether a string contains a qa prompt""" + return ( + len( + re.findall( + r"^(?![:>]).*" + + self.question_prefix + + r"{0,1}\. ", + string, + flags=re.DOTALL, + ) + ) + != 0 + ) + + def extract_prompts( + self, document: Document + ) -> Sequence[QAPrompt]: + prompts: list[QAPrompt] = [] + + blocks = self._string_to_blocks_by_newlines(document.content) + + block_starting_line_nr = 1 + + for block_string in blocks: + if self._has_qa(block_string): + question = self._get_first_question(block_string) + try: + answer = self._get_first_answer(block_string) + except IndexError: + logging.warn( + f"Could not find answer in {document.title} for {question}" + ) + continue + + prompts.append( + QAPromptFromDoc( + question=question, + answer=answer, + parent_doc=document, + line_nr=block_starting_line_nr, + ) + ) + + block_lines = len( + re.findall(r"\n", block_string, flags=re.DOTALL) + ) + block_starting_line_nr += block_lines + + return prompts diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/test_qa_prompt_extractor.py b/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/test_qa_prompt_extractor.py new file mode 100644 index 00000000..2cb1a648 --- /dev/null +++ b/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/test_qa_prompt_extractor.py @@ -0,0 +1,24 @@ +from pathlib import Path + +from ..document_ingesters.document import Document +from .qa_prompt_extractor import QAPromptExtractor + + +def test_qa_prompt_extractor(tmpdir: Path): + doc = Document( + content=""" + +Q. What is the meaning of life? +A. 42 + +""", + source_path=tmpdir / "test.md", + ) + + extractor = QAPromptExtractor( + question_prefix="Q.", answer_prefix="A." + ).extract_prompts(doc) + + assert len(extractor) == 1 + assert extractor[0].question == "What is the meaning of life?" + assert extractor[0].answer == "42" diff --git a/personal_mnemonic_medium/v2/domain/prompts/qa_prompt.py b/personal_mnemonic_medium/v2/domain/prompts/qa_prompt.py index 061699ad..6e59c10f 100644 --- a/personal_mnemonic_medium/v2/domain/prompts/qa_prompt.py +++ b/personal_mnemonic_medium/v2/domain/prompts/qa_prompt.py @@ -6,8 +6,8 @@ int_hash_str, ) +from ..prompt_source.document_ingesters.document import Document from .base_prompt import BasePrompt -from .doc_mixin import DocMixin @dataclass(frozen=True) @@ -33,5 +33,11 @@ def tags(self) -> Sequence[str]: return self.add_tags -class QAPromptFromDoc(QAPrompt, DocMixin): - ... +@dataclass(frozen=True) +class QAPromptFromDoc(QAPrompt): + parent_doc: Document + line_nr: int + + @property + def tags(self) -> Sequence[str]: + return self.parent_doc.tags From 037ddc3b8c55e787a1ec1f41e42a0c147321e05c Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Sun, 10 Dec 2023 18:49:53 +0000 Subject: [PATCH 2/4] misc: todo --- .../prompt_source/prompt_extractors/base_prompt_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/base_prompt_extractor.py b/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/base_prompt_extractor.py index 7bd6df3a..ba064789 100644 --- a/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/base_prompt_extractor.py +++ b/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/base_prompt_extractor.py @@ -5,7 +5,7 @@ from ..document_ingesters.document import Document -# TODO: https://github.com/MartinBernstorff/personal-mnemonic-medium/issues/294 Implement QA and Cloze promptextractors +# TODO: https://github.com/MartinBernstorff/personal-mnemonic-medium/issues/297 Implement ClozeExtractor class BasePromptExtractor(Protocol): def extract_prompts( self, document: Document From f493ed492c55766d56c85dcbb633b0de3988032e Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Sun, 10 Dec 2023 18:51:11 +0000 Subject: [PATCH 3/4] test: tags --- .../prompt_extractors/test_qa_prompt_extractor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/test_qa_prompt_extractor.py b/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/test_qa_prompt_extractor.py index 2cb1a648..e071572f 100644 --- a/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/test_qa_prompt_extractor.py +++ b/personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/test_qa_prompt_extractor.py @@ -11,6 +11,8 @@ def test_qa_prompt_extractor(tmpdir: Path): Q. What is the meaning of life? A. 42 +#anki/tag/test_tag + """, source_path=tmpdir / "test.md", ) @@ -22,3 +24,4 @@ def test_qa_prompt_extractor(tmpdir: Path): assert len(extractor) == 1 assert extractor[0].question == "What is the meaning of life?" assert extractor[0].answer == "42" + assert extractor[0].tags == ["#anki/tag/test_tag"] From 825dc3f998d35b5dd2e075d5cc99729626094d56 Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Sun, 10 Dec 2023 18:51:31 +0000 Subject: [PATCH 4/4] misc. --- .../v2/domain/prompt_source/document_prompt_source.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/document_prompt_source.py b/personal_mnemonic_medium/v2/domain/prompt_source/document_prompt_source.py index d8357476..3cd4b7ab 100644 --- a/personal_mnemonic_medium/v2/domain/prompt_source/document_prompt_source.py +++ b/personal_mnemonic_medium/v2/domain/prompt_source/document_prompt_source.py @@ -27,7 +27,7 @@ def _get_prompts_from_document( prompts: list[BasePrompt] = [] for extractor in self._prompt_extractors: extractor_prompts = list( - extractor.extract_prompts(document.content) + extractor.extract_prompts(document) ) prompts += extractor_prompts