-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat implement QA and Cloze promptextractors (#298)
- Loading branch information
Showing
5 changed files
with
137 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 95 additions & 0 deletions
95
personal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/qa_prompt_extractor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import logging | ||
import re | ||
from collections.abc import Sequence | ||
|
||
from ...prompts.qa_prompt import QAPrompt, QAPromptFromDoc | ||
from ..document_ingesters.document import Document | ||
from .base_prompt_extractor import BasePromptExtractor | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class QAPromptExtractor(BasePromptExtractor): | ||
def __init__( | ||
self, question_prefix: str, answer_prefix: str | ||
) -> None: | ||
self.question_prefix = question_prefix | ||
self.answer_prefix = answer_prefix | ||
|
||
def _get_first_question(self, content: str) -> str: | ||
question = re.findall( | ||
self.question_prefix + r"{0,1}\.(?:(?!A\.).)*", | ||
content, | ||
flags=re.DOTALL, | ||
)[0] | ||
|
||
return question[len(self.question_prefix) + 1 :].rstrip() | ||
|
||
def _get_first_answer(self, content: str) -> str: | ||
# Have to use positive lookahead to match code-blocks | ||
# To ensure the last answer is matched as well, we add 2 newlines to string. | ||
string_padded = f"{content.rstrip()}\n\n" | ||
|
||
answer = re.findall( | ||
r"\n" + self.answer_prefix + r"[ \n]+\n*.+", | ||
string_padded, | ||
re.DOTALL, | ||
)[0] | ||
|
||
return answer[len(self.answer_prefix) + 2 :].rstrip() | ||
|
||
@staticmethod | ||
def _string_to_blocks_by_newlines(string: str) -> list[str]: | ||
"""Break string into a list by 2+ newlines in a row.""" | ||
return re.split(r"(\n\n)+", string) | ||
|
||
def _has_qa(self, string: str) -> bool: | ||
"""Check whether a string contains a qa prompt""" | ||
return ( | ||
len( | ||
re.findall( | ||
r"^(?![:>]).*" | ||
+ self.question_prefix | ||
+ r"{0,1}\. ", | ||
string, | ||
flags=re.DOTALL, | ||
) | ||
) | ||
!= 0 | ||
) | ||
|
||
def extract_prompts( | ||
self, document: Document | ||
) -> Sequence[QAPrompt]: | ||
prompts: list[QAPrompt] = [] | ||
|
||
blocks = self._string_to_blocks_by_newlines(document.content) | ||
|
||
block_starting_line_nr = 1 | ||
|
||
for block_string in blocks: | ||
if self._has_qa(block_string): | ||
question = self._get_first_question(block_string) | ||
try: | ||
answer = self._get_first_answer(block_string) | ||
except IndexError: | ||
logging.warn( | ||
f"Could not find answer in {document.title} for {question}" | ||
) | ||
continue | ||
|
||
prompts.append( | ||
QAPromptFromDoc( | ||
question=question, | ||
answer=answer, | ||
parent_doc=document, | ||
line_nr=block_starting_line_nr, | ||
) | ||
) | ||
|
||
block_lines = len( | ||
re.findall(r"\n", block_string, flags=re.DOTALL) | ||
) | ||
block_starting_line_nr += block_lines | ||
|
||
return prompts |
27 changes: 27 additions & 0 deletions
27
...nal_mnemonic_medium/v2/domain/prompt_source/prompt_extractors/test_qa_prompt_extractor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from pathlib import Path | ||
|
||
from ..document_ingesters.document import Document | ||
from .qa_prompt_extractor import QAPromptExtractor | ||
|
||
|
||
def test_qa_prompt_extractor(tmpdir: Path): | ||
doc = Document( | ||
content=""" | ||
Q. What is the meaning of life? | ||
A. 42 | ||
#anki/tag/test_tag | ||
""", | ||
source_path=tmpdir / "test.md", | ||
) | ||
|
||
extractor = QAPromptExtractor( | ||
question_prefix="Q.", answer_prefix="A." | ||
).extract_prompts(doc) | ||
|
||
assert len(extractor) == 1 | ||
assert extractor[0].question == "What is the meaning of life?" | ||
assert extractor[0].answer == "42" | ||
assert extractor[0].tags == ["#anki/tag/test_tag"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters