Skip to content

Commit

Permalink
feat implement ClozeExtractor (#301)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff authored Dec 10, 2023
2 parents 781d3f4 + a9afe36 commit 741545a
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import hashlib
import logging
import re
from collections.abc import Sequence

from ...prompts.cloze_prompt import ClozePrompt, ClozePromptFromDoc
from ..document_ingesters.document import Document
from .base_prompt_extractor import BasePromptExtractor

log = logging.getLogger(__name__)


class ClozePromptExtractor(BasePromptExtractor):
@staticmethod
def _break_string_by_two_or_more_newlines(
string: str
) -> list[str]:
"""Break string into a list by 2+ newlines in a row."""
return re.split(r"(\n\n)+", string)

@staticmethod
def _has_cloze(string: str) -> bool:
if (
len(re.findall(r"{.*}", string)) > 0
and "BearID" not in string # Exclude BearID
and "$$" not in string # Exclude math
and r"```" not in string # Exclude code
and "Q." not in string # Exclude Q&A
and "A." not in string # Exclude Q&A
):
return True
return False

@staticmethod
def _replace_cloze_id_with_unique(
string: str, selected_cloze: str | None = None
) -> str:
"""Each cloze deletion in a note is numbered sequentially.
This function ensures that the numbering is based on the content of the cloze deletion, essentially ensuring that if you modify the contents of a cloze, only the scheduling of that specific cloze is changed.
Args:
string (str): The string to replace the cloze id with a unique id.
selected_cloze (str, optional): If you only want to replace a specific cloze, pass it here. Defaults to None.
"""
if selected_cloze is not None:
selected_clozes = [selected_cloze]
else:
selected_clozes = re.findall(
r"{(?!BearID).[^}]*}", string
)

for cloze in selected_clozes:
output_hash = (
int(
hashlib.sha256(cloze.encode("utf-8")).hexdigest(),
16,
)
% 10**3
)

new_cloze = f"{{{{c{output_hash}::{cloze[1:-1]}}}}}"

string = string.replace(cloze, new_cloze)

return string

def extract_prompts(
self, document: Document
) -> Sequence[ClozePrompt]:
prompts: list[ClozePromptFromDoc] = []

blocks = self._break_string_by_two_or_more_newlines(
document.content
)

for block_string in blocks:
if self._has_cloze(block_string):
clozes = re.findall(
r"{(?!BearID).[^}]*}", block_string
)

# TODO: https://github.com/MartinBernstorff/personal-mnemonic-medium/issues/300 refactor: move clozeid replacement to AnkiCloze generator class
for selected_cloze in clozes:
prompt_content = (
self._replace_cloze_id_with_unique(
block_string,
selected_cloze=selected_cloze,
)
)

prompts.append(
ClozePromptFromDoc(
text=prompt_content, source_doc=document
)
)

return prompts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from pathlib import Path

from ..document_ingesters.document import Document
from .cloze_prompt_extractor import ClozePromptExtractor


def test_cloze_prompt_extractor(tmpdir: Path):
doc = Document(
content=r"""
What is the meaning of life? {42}
This is another block without any cloze prompts.
#anki/tag/test_tag
""",
source_path=tmpdir / "test.md",
)

extractor = ClozePromptExtractor().extract_prompts(doc)

assert len(extractor) == 1
assert (
extractor[0].text
== r"What is the meaning of life? {{c734::42}}"
)
assert extractor[0].tags == ["#anki/tag/test_tag"]
21 changes: 15 additions & 6 deletions personal_mnemonic_medium/v2/domain/prompts/cloze_prompt.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from collections.abc import Sequence
from dataclasses import dataclass

from personal_mnemonic_medium.v2.domain.prompts.doc_mixin import (
DocMixin,
)

from ..int_hash_str import int_hash_str
from ..prompt_source.document_ingesters.document import Document
from .base_prompt import BasePrompt

# TODO: https://github.com/MartinBernstorff/personal-mnemonic-medium/issues/299 refactor: clean up doc type inheritance


@dataclass(frozen=True)
class ClozePrompt(BasePrompt):
Expand All @@ -31,5 +30,15 @@ def tags(self) -> Sequence[str]:
return self.add_tags


class ClozePromptFromDoc(ClozePrompt, DocMixin):
...
@dataclass(frozen=True)
class ClozePromptFromDoc(ClozePrompt):
text: str
source_doc: Document

@property
def uid(self) -> int:
return int_hash_str(self.text)

@property
def tags(self) -> Sequence[str]:
return self.source_doc.tags

0 comments on commit 741545a

Please sign in to comment.