diff --git a/memium/source/extractors/extractor_cloze.py b/memium/source/extractors/extractor_cloze.py index d5b2845..664314b 100644 --- a/memium/source/extractors/extractor_cloze.py +++ b/memium/source/extractors/extractor_cloze.py @@ -14,24 +14,21 @@ class ClozePromptExtractor(BasePromptExtractor): @staticmethod def _get_blocks(string: str) -> list[str]: """Break string into a list by 2+ newlines in a row.""" - return re.split(r"(\n\n)+", string) + # Exclude entire code blocks + string_sans_code_blocks = re.sub( + r"```.*?```", "", string, flags=re.DOTALL + ) + return re.split(r"(\n\n)+", string_sans_code_blocks) @staticmethod def _has_cloze(string: str) -> bool: - if ( - len(re.findall(r"{.*}", string)) > 0 - and "BearID" not in string # Exclude BearID - and "$$" not in string # Exclude math - and r"```" not in string # Exclude code - and "Q." not in string # Exclude Q&A - and "A." not in string # Exclude Q&A - ): + if len(re.findall(r"{.*}", string)) > 0: return True return False @staticmethod - def _is_code_block(string: str) -> bool: - if string.startswith("```"): + def _is_math_block(string: str) -> bool: + if string.startswith("$$"): return True return False @@ -75,8 +72,12 @@ def extract_prompts(self, document: Document) -> Sequence[ClozePrompt]: blocks = self._get_blocks(document.content) for block_string in blocks: - if self._is_code_block(block_string) or self._is_html_comment( - block_string + if any( + exclusion_criterion(block_string) + for exclusion_criterion in ( + self._is_html_comment, + self._is_math_block, + ) ): continue if self._has_cloze(block_string): diff --git a/memium/source/extractors/test_prompt_extractor_cloze.py b/memium/source/extractors/test_prompt_extractor_cloze.py index d3e06d3..d44a363 100644 --- a/memium/source/extractors/test_prompt_extractor_cloze.py +++ b/memium/source/extractors/test_prompt_extractor_cloze.py @@ -40,6 +40,14 @@ def test_cloze_prompt_extractor(tmpdir: Path): ), ("""""", True), ("""{Cloze}""", False), + ( + """```html +Some content + +Content in another {block} + ```""", + True, + ), ], ) def test_ignore_block_types(content: str, skipped: bool):