Skip to content

Commit

Permalink
feat implement QA and Cloze promptextractors (#298)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff authored Dec 10, 2023
2 parents a5574b6 + 825dc3f commit 781d3f4
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _get_prompts_from_document(
prompts: list[BasePrompt] = []
for extractor in self._prompt_extractors:
extractor_prompts = list(
extractor.extract_prompts(document.content)
extractor.extract_prompts(document)
)
prompts += extractor_prompts

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from typing import Protocol

from ...prompts.base_prompt import BasePrompt
from ..document_ingesters.document import Document


# TODO: https://github.com/MartinBernstorff/personal-mnemonic-medium/issues/294 Implement QA and Cloze promptextractors
# TODO: https://github.com/MartinBernstorff/personal-mnemonic-medium/issues/297 Implement ClozeExtractor
class BasePromptExtractor(Protocol):
# Structural (typing.Protocol) interface that prompt extractors implement.
# NOTE(review): this is a diff view, so two signatures appear below. The
# one-line signature taking a raw `content` string appears to be the removed
# version; confirm against the repository.
def extract_prompts(self, content: str) -> Sequence[BasePrompt]:
# Replacement signature: receives the Document itself and returns a sequence
# of BasePrompt.
def extract_prompts(
self, document: Document
) -> Sequence[BasePrompt]:
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import logging
import re
from collections.abc import Sequence

from ...prompts.qa_prompt import QAPrompt, QAPromptFromDoc
from ..document_ingesters.document import Document
from .base_prompt_extractor import BasePromptExtractor

log = logging.getLogger(__name__)


class QAPromptExtractor(BasePromptExtractor):
    """Extract question/answer prompts from the text of a document.

    The document content is split into blocks on runs of blank lines. For
    every block that contains a question marker, the first question and the
    first answer of that block are turned into one prompt.

    Args:
        question_prefix: Literal text that introduces a question, e.g. ``"Q."``.
        answer_prefix: Literal text that introduces an answer, e.g. ``"A."``.
    """

    def __init__(
        self, question_prefix: str, answer_prefix: str
    ) -> None:
        self.question_prefix = question_prefix
        self.answer_prefix = answer_prefix

    def _get_first_question(self, content: str) -> str:
        """Return the text of the first question found in ``content``.

        The question runs from the question prefix up to (but excluding) the
        next occurrence of the answer prefix.

        Raises:
            IndexError: If ``content`` contains no question.
        """
        # The prefixes are literal text (they are sliced off by length below),
        # so they are escaped before being spliced into the pattern. The
        # lookahead uses the configured answer prefix instead of a fixed "A.".
        question = re.findall(
            re.escape(self.question_prefix)
            + r"{0,1}\.(?:(?!"
            + re.escape(self.answer_prefix)
            + r").)*",
            content,
            flags=re.DOTALL,
        )[0]

        # Skip the prefix plus the single separator character that follows it.
        return question[len(self.question_prefix) + 1 :].rstrip()

    def _get_first_answer(self, content: str) -> str:
        """Return the text following the first answer prefix in ``content``.

        The answer prefix must start a line; everything from there to the end
        of ``content`` (trailing whitespace removed) is the answer.

        Raises:
            IndexError: If ``content`` contains no answer.
        """
        # Pad with two newlines so that an answer prefix sitting on the very
        # last line still has text after it for the pattern to match.
        string_padded = f"{content.rstrip()}\n\n"

        answer = re.findall(
            r"\n" + re.escape(self.answer_prefix) + r"[ \n]+\n*.+",
            string_padded,
            flags=re.DOTALL,
        )[0]

        # Skip the leading newline, the prefix and one separator character.
        return answer[len(self.answer_prefix) + 2 :].rstrip()

    @staticmethod
    def _string_to_blocks_by_newlines(string: str) -> list[str]:
        """Break string into a list by 2+ newlines in a row.

        The separators are kept as their own list entries, and each entry holds
        the complete run of newline pairs that was consumed, so that summing
        the newlines of all entries yields the newline count of ``string``.
        """
        # Capture the whole repeated run; capturing only the inner "\n\n" would
        # keep just the last pair and lose lines when blocks are separated by
        # four or more newlines.
        return re.split(r"((?:\n\n)+)", string)

    def _has_qa(self, string: str) -> bool:
        """Check whether a string contains a qa prompt.

        Strings whose first character is ``:`` or ``>`` never count.
        """
        pattern = (
            r"^(?![:>]).*"
            + re.escape(self.question_prefix)
            + r"{0,1}\. "
        )
        return re.search(pattern, string, flags=re.DOTALL) is not None

    def extract_prompts(
        self, document: Document
    ) -> Sequence[QAPrompt]:
        """Return one prompt per block of ``document`` that holds a question.

        Blocks with a question but no answer are logged and skipped. Each
        prompt records the line number at which its block starts.
        """
        prompts: list[QAPrompt] = []

        block_starting_line_nr = 1

        for block_string in self._string_to_blocks_by_newlines(
            document.content
        ):
            if self._has_qa(block_string):
                question = self._get_first_question(block_string)
                try:
                    answer = self._get_first_answer(block_string)
                except IndexError:
                    log.warning(
                        "Could not find answer in %s for %s",
                        document.title,
                        question,
                    )
                else:
                    prompts.append(
                        QAPromptFromDoc(
                            question=question,
                            answer=answer,
                            parent_doc=document,
                            line_nr=block_starting_line_nr,
                        )
                    )

            # Always advance the counter, also for skipped blocks, so that the
            # line numbers of later prompts stay correct.
            block_starting_line_nr += block_string.count("\n")

        return prompts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from pathlib import Path

from ..document_ingesters.document import Document
from .qa_prompt_extractor import QAPromptExtractor


def test_qa_prompt_extractor(tmp_path: Path):
    """A single Q/A block in a document yields exactly one prompt."""
    # `tmp_path` is the pytest fixture that provides a pathlib.Path; the legacy
    # `tmpdir` fixture yields a py.path.local, which does not match the
    # annotation.
    doc = Document(
        content="""
Q. What is the meaning of life?
A. 42
#anki/tag/test_tag
""",
        source_path=tmp_path / "test.md",
    )

    prompts = QAPromptExtractor(
        question_prefix="Q.", answer_prefix="A."
    ).extract_prompts(doc)

    assert len(prompts) == 1
    assert prompts[0].question == "What is the meaning of life?"
    assert prompts[0].answer == "42"
    assert prompts[0].tags == ["#anki/tag/test_tag"]
12 changes: 9 additions & 3 deletions personal_mnemonic_medium/v2/domain/prompts/qa_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
int_hash_str,
)

from ..prompt_source.document_ingesters.document import Document
from .base_prompt import BasePrompt
from .doc_mixin import DocMixin


@dataclass(frozen=True)
Expand All @@ -33,5 +33,11 @@ def tags(self) -> Sequence[str]:
return self.add_tags


class QAPromptFromDoc(QAPrompt, DocMixin):
...
@dataclass(frozen=True)
class QAPromptFromDoc(QAPrompt):
    """A QA prompt that keeps a reference to the document it belongs to."""

    # Document this prompt is attached to.
    parent_doc: Document
    # Line number stored alongside the prompt.
    # NOTE(review): presumably the start line of the prompt's block within
    # parent_doc; confirm against the extractor that builds these objects.
    line_nr: int

    @property
    def tags(self) -> Sequence[str]:
        """Return the tags of the parent document."""
        return self.parent_doc.tags

0 comments on commit 781d3f4

Please sign in to comment.