# TODO: https://github.com/MartinBernstorff/personal-mnemonic-medium/issues/297 Implement ClozeExtractor
class BasePromptExtractor(Protocol):
    """Structural interface for anything that extracts prompts from a document.

    Implementations receive the whole ``Document`` rather than only its
    text content.
    """

    def extract_prompts(
        self, document: Document
    ) -> Sequence[BasePrompt]:
        """Return the prompts found in ``document``.

        Args:
            document: The document to extract prompts from.

        Returns:
            A sequence of prompts.
        """
        ...
class QAPromptExtractor(BasePromptExtractor):
    """Extract question/answer prompts from the text of a document.

    The document content is split into blocks on runs of two or more
    newlines. Each block that contains a question marker yields at most one
    prompt, built from the first question and the first answer in the block.
    """

    def __init__(
        self, question_prefix: str, answer_prefix: str
    ) -> None:
        """Store the literal markers that introduce questions and answers.

        Args:
            question_prefix: Marker that starts a question, e.g. ``"Q."``.
            answer_prefix: Marker that starts an answer, e.g. ``"A."``.
        """
        self.question_prefix = question_prefix
        self.answer_prefix = answer_prefix

    def _get_first_question(self, content: str) -> str:
        """Return the text of the first question in ``content``.

        The match runs from the question marker up to, but not including,
        the first occurrence of the answer marker.

        Raises:
            IndexError: If ``content`` contains no question marker.
        """
        # Prefixes are user-supplied literals, so they are escaped before
        # being embedded in the pattern. The ``{0,1}`` quantifier binds to
        # the last character of the question prefix, making it optional.
        pattern = (
            re.escape(self.question_prefix)
            + r"{0,1}\.(?:(?!"
            + re.escape(self.answer_prefix)
            + r").)*"
        )
        match = re.search(pattern, content, flags=re.DOTALL)
        if match is None:
            raise IndexError(
                f"No question starting with {self.question_prefix!r} found"
            )

        # Drop the prefix plus the single separator character after it.
        return match.group(0)[
            len(self.question_prefix) + 1 :
        ].rstrip()

    def _get_first_answer(self, content: str) -> str:
        """Return the text of the first answer in ``content``.

        The answer marker must be preceded by a newline and followed by at
        least one space or newline. Everything after it belongs to the
        answer.

        Raises:
            IndexError: If ``content`` contains no answer marker.
        """
        # Pad with two newlines so the pattern always has something to
        # match after the marker, even when the answer marker is the last
        # thing in the block.
        string_padded = f"{content.rstrip()}\n\n"

        pattern = (
            r"\n" + re.escape(self.answer_prefix) + r"[ \n]+\n*.+"
        )
        match = re.search(pattern, string_padded, flags=re.DOTALL)
        if match is None:
            raise IndexError(
                f"No answer starting with {self.answer_prefix!r} found"
            )

        # Drop the leading newline, the prefix and one separator character.
        return match.group(0)[len(self.answer_prefix) + 2 :].rstrip()

    @staticmethod
    def _string_to_blocks_by_newlines(string: str) -> list[str]:
        """Break string into a list by 2+ newlines in a row.

        The separators are kept as their own list items, at full length,
        so that the newlines of all items add up to the newlines in
        ``string``. Callers rely on this to track line numbers.
        """
        return re.split(r"(\n{2,})", string)

    def _has_qa(self, string: str) -> bool:
        """Check whether a string contains a qa prompt."""
        # Blocks starting with ":" or ">" are never treated as prompts.
        pattern = (
            r"^(?![:>]).*"
            + re.escape(self.question_prefix)
            + r"{0,1}\. "
        )
        return re.search(pattern, string, flags=re.DOTALL) is not None

    def extract_prompts(
        self, document: Document
    ) -> Sequence[QAPrompt]:
        """Extract one QA prompt per question-bearing block of ``document``.

        Blocks with a question but no answer are logged and skipped.

        Args:
            document: Document whose ``content`` is scanned for prompts.

        Returns:
            The prompts found, each tagged with the parent document and
            the 1-based line number on which its block starts.
        """
        prompts: list[QAPrompt] = []
        block_starting_line_nr = 1

        for block_string in self._string_to_blocks_by_newlines(
            document.content
        ):
            if self._has_qa(block_string):
                question = self._get_first_question(block_string)
                try:
                    answer = self._get_first_answer(block_string)
                except IndexError:
                    log.warning(
                        "Could not find answer in %s for %s",
                        document.title,
                        question,
                    )
                else:
                    prompts.append(
                        QAPromptFromDoc(
                            question=question,
                            answer=answer,
                            parent_doc=document,
                            line_nr=block_starting_line_nr,
                        )
                    )

            # Advance the counter for every item, including skipped blocks
            # and separators, so later line numbers stay accurate.
            block_starting_line_nr += block_string.count("\n")

        return prompts
@dataclass(frozen=True)
class QAPromptFromDoc(QAPrompt):
    """A QA prompt that keeps a reference to the document it came from."""

    # Document the prompt was extracted from; its tags become the prompt's.
    parent_doc: Document
    # Line number supplied by whoever constructs the prompt.
    line_nr: int

    @property
    def tags(self) -> Sequence[str]:
        """Return the tags of the parent document."""
        source_document = self.parent_doc
        return source_document.tags