diff --git a/memium/destination/ankiconnect/anki_converter.py b/memium/destination/ankiconnect/anki_converter.py index 7305c2d5..130db186 100644 --- a/memium/destination/ankiconnect/anki_converter.py +++ b/memium/destination/ankiconnect/anki_converter.py @@ -1,13 +1,10 @@ -from collections.abc import Sequence - -from functionalpy._sequence import Seq - -from ...source.prompts.prompt import BasePrompt -from ...source.prompts.prompt_cloze import ClozePrompt -from ...source.prompts.prompt_qa import QAPrompt +from ...source.prompts.prompt import BasePrompt, DestinationPrompt +from ...source.prompts.prompt_cloze import ClozePrompt, ClozeWithoutDoc +from ...source.prompts.prompt_qa import QAPrompt, QAWithoutDoc from .anki_prompt import AnkiPrompt from .anki_prompt_cloze import AnkiCloze from .anki_prompt_qa import AnkiQA +from .ankiconnect_gateway import NoteInfo class AnkiPromptConverter: @@ -18,7 +15,7 @@ def __init__( self.deck_prefix = deck_prefix self.card_css = card_css - def _prompt_to_card(self, prompt: BasePrompt) -> AnkiPrompt: + def prompt_to_card(self, prompt: BasePrompt) -> AnkiPrompt: deck_in_tags = [ tag for tag in prompt.tags if tag.startswith(self.deck_prefix) ] @@ -45,9 +42,25 @@ def _prompt_to_card(self, prompt: BasePrompt) -> AnkiPrompt: "BasePrompt is the base class for all prompts, use a subclass" ) - def prompts_to_cards( - self, prompts: Sequence[BasePrompt] - ) -> Sequence[AnkiPrompt]: - """Takes an iterable of prompts and turns them into AnkiCards""" + def note_info_to_prompt(self, note_info: NoteInfo) -> DestinationPrompt: + if "Question" in note_info.fields and "Answer" in note_info.fields: + return DestinationPrompt( + QAWithoutDoc( + question=note_info.fields["Question"].value, + answer=note_info.fields["Answer"].value, + add_tags=note_info.tags, + ), + destination_id=str(note_info.noteId), + ) + + if "Text" in note_info.fields: + return DestinationPrompt( + ClozeWithoutDoc( + text=note_info.fields["Text"].value, add_tags=note_info.tags + ), + destination_id=str(note_info.noteId), + ) - return Seq(prompts).map(self._prompt_to_card).to_list() + raise ValueError( + f"NoteInfo {note_info} has neither Question nor Text field" + ) diff --git a/memium/destination/ankiconnect/anki_prompt_cloze.py b/memium/destination/ankiconnect/anki_prompt_cloze.py index 68980e32..98ddb4f3 100644 --- a/memium/destination/ankiconnect/anki_prompt_cloze.py +++ b/memium/destination/ankiconnect/anki_prompt_cloze.py @@ -3,7 +3,7 @@ import genanki -from ...utils.hash_cleaned_str import clean_str, int_hash_str +from ...utils.hash_cleaned_str import int_hash_str, remove_punctuation from .anki_prompt import AnkiPrompt @@ -16,7 +16,7 @@ class AnkiCloze(AnkiPrompt): @property def uuid(self) -> int: - return int_hash_str(clean_str(self.text)) + return int_hash_str(remove_punctuation(self.text)) @property def genanki_model(self) -> genanki.Model: diff --git a/memium/destination/ankiconnect/anki_prompt_qa.py b/memium/destination/ankiconnect/anki_prompt_qa.py index 8a940f0a..58f2af15 100644 --- a/memium/destination/ankiconnect/anki_prompt_qa.py +++ b/memium/destination/ankiconnect/anki_prompt_qa.py @@ -3,7 +3,7 @@ import genanki -from ...utils.hash_cleaned_str import clean_str, int_hash_str +from ...utils.hash_cleaned_str import int_hash_str, remove_punctuation from .anki_prompt import AnkiPrompt @@ -17,7 +17,7 @@ class AnkiQA(AnkiPrompt): @property def uuid(self) -> int: - return int_hash_str(clean_str(f"{self.question}{self.answer}")) + return int_hash_str(remove_punctuation(f"{self.question}{self.answer}")) @property def genanki_model(self) -> genanki.Model: diff --git a/memium/destination/ankiconnect/test_anki_converter.py b/memium/destination/ankiconnect/test_anki_converter.py index 53f82335..ce4acff7 100644 --- a/memium/destination/ankiconnect/test_anki_converter.py +++ b/memium/destination/ankiconnect/test_anki_converter.py @@ -44,7 +44,7 @@ def test_anki_prompt_converter( """Tests the AnkiPromptConverter class""" card = AnkiPromptConverter( base_deck="FakeDeck", card_css="FakeCSS" - ).prompts_to_cards([input_prompt])[0] + ).prompt_to_card(input_prompt) assert card.uuid == expected_card.uuid for attr in expected_card.__dict__: diff --git a/memium/destination/destination_ankiconnect.py b/memium/destination/destination_ankiconnect.py index 9f7c72bf..c0d53c0c 100644 --- a/memium/destination/destination_ankiconnect.py +++ b/memium/destination/destination_ankiconnect.py @@ -5,12 +5,10 @@ from functionalpy._sequence import Seq from ..source.prompts.prompt import DestinationPrompt -from ..source.prompts.prompt_cloze import ClozeWithoutDoc -from ..source.prompts.prompt_qa import QAWithoutDoc from ..utils.hash_cleaned_str import hash_cleaned_str from .ankiconnect.anki_converter import AnkiPromptConverter from .ankiconnect.anki_prompt import AnkiPrompt -from .ankiconnect.ankiconnect_gateway import AnkiConnectGateway, NoteInfo +from .ankiconnect.ankiconnect_gateway import AnkiConnectGateway from .destination import ( DeletePrompts, PromptDestination, @@ -28,33 +26,10 @@ def __init__( self.gateway = gateway self.prompt_converter = prompt_converter - def _note_info_to_prompt(self, note_info: NoteInfo) -> DestinationPrompt: - if "Question" in note_info.fields and "Answer" in note_info.fields: - return DestinationPrompt( - QAWithoutDoc( - question=note_info.fields["Question"].value, - answer=note_info.fields["Answer"].value, - add_tags=note_info.tags, - ), - destination_id=str(note_info.noteId), - ) - - if "Text" in note_info.fields: - return DestinationPrompt( - ClozeWithoutDoc( - text=note_info.fields["Text"].value, add_tags=note_info.tags - ), - destination_id=str(note_info.noteId), - ) - - raise ValueError( - f"NoteInfo {note_info} has neither Question nor Text field" - ) - def get_all_prompts(self) -> Sequence[DestinationPrompt]: return ( Seq(self.gateway.get_all_note_infos()) - .map(self._note_info_to_prompt) + .map(self.prompt_converter.note_info_to_prompt) .to_list() ) @@ -85,7 +60,9 @@ def _create_package(self, cards: Sequence[AnkiPrompt]) -> genanki.Package: return genanki.Package(deck_or_decks=decks) def _push_prompts(self, command: PushPrompts) -> None: - cards = self.prompt_converter.prompts_to_cards(command.prompts) + cards = [ + self.prompt_converter.prompt_to_card(e) for e in command.prompts + ] models = [card.genanki_model for card in cards] unique_models: dict[int, genanki.Model] = { diff --git a/memium/source/prompts/prompt.py b/memium/source/prompts/prompt.py index 502f4834..60082907 100644 --- a/memium/source/prompts/prompt.py +++ b/memium/source/prompts/prompt.py @@ -1,11 +1,12 @@ from collections.abc import Sequence -from typing import Protocol +from typing import Protocol, runtime_checkable +@runtime_checkable class BasePrompt(Protocol): @property def scheduling_uid(self) -> int: - """UID used when scheduling the prompt. If this UID changes, the scheduling of the prompt is reset.""" + """UID used when scheduling the prompt. If this UID changes, the old prompt is deleted and a new prompt is created.""" ... @property diff --git a/memium/source/prompts/prompt_cloze.py b/memium/source/prompts/prompt_cloze.py index 565ad25d..41831a34 100644 --- a/memium/source/prompts/prompt_cloze.py +++ b/memium/source/prompts/prompt_cloze.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from dataclasses import dataclass -from ...utils.hash_cleaned_str import clean_str, hash_cleaned_str, int_hash_str +from ...utils.hash_cleaned_str import hash_cleaned_str, int_hash_str from ..document import Document from .prompt import BasePrompt @@ -16,7 +16,7 @@ def scheduling_uid(self) -> int: @property def update_uid(self) -> int: - return int_hash_str(f"{clean_str(self.text)}{self.tags}") + return int_hash_str(f"{(self.text)}{self.tags}") @property def tags(self) -> Sequence[str]: diff --git a/memium/source/prompts/prompt_qa.py b/memium/source/prompts/prompt_qa.py index 3b059cf5..1288a99b 100644 --- a/memium/source/prompts/prompt_qa.py +++ b/memium/source/prompts/prompt_qa.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from dataclasses import dataclass -from ...utils.hash_cleaned_str import hash_cleaned_str +from ...utils.hash_cleaned_str import hash_cleaned_str, int_hash_str from ..document import Document from .prompt import BasePrompt @@ -17,7 +17,7 @@ def scheduling_uid(self) -> int: @property def update_uid(self) -> int: - return hash_cleaned_str(f"{self.question}_{self.answer}_{self.tags}") + return int_hash_str(f"{self.question}_{self.answer}_{self.tags}") @property def tags(self) -> Sequence[str]: @@ -38,9 +38,6 @@ class QAFromDoc(QAPrompt): parent_doc: Document line_nr: int - def __repr__(self) -> str: - return f"{self.parent_doc.source_path}:{self.line_nr}: \n\tQ. {self.question}\n\tA. {self.answer}" - @property def tags(self) -> Sequence[str]: return self.parent_doc.tags diff --git a/memium/test_main.py b/memium/test_main.py index 94f7e944..84470475 100644 --- a/memium/test_main.py +++ b/memium/test_main.py @@ -13,13 +13,22 @@ reason="Tests require a running AnkiConnect server", ) def test_main( - tmp_path: Path, caplog: pytest.LogCaptureFixture, - base_deck: str = "Tests::Integration Test deck", + base_deck: str = "Tests::Main Integration Test", ): caplog.set_level(logging.INFO) - output_path = tmp_path / "test.md" - with output_path.open("w") as f: + + # Clear and delete output path + test_input_path = Path("/output") + + if test_input_path.exists(): + for entity in test_input_path.iterdir(): + if not entity.is_dir(): + entity.unlink() + else: + test_input_path.mkdir(parents=True) + + with (test_input_path / "test.md").open("w") as f: f.write( """# Test note Q. Test question? @@ -29,9 +38,15 @@ def test_main( main( base_deck=base_deck, - input_dir=tmp_path, - max_deletions_per_run=1, - dry_run=True, + input_dir=test_input_path, + max_deletions_per_run=2, + dry_run=False, + ) + + # Test idempotency + main( + base_deck=base_deck, + input_dir=test_input_path, + max_deletions_per_run=0, # 0 deletions allowed to test idempotency + dry_run=False, ) - assert "Pushing prompt" in caplog.text - assert "Test question?" in caplog.text diff --git a/memium/utils/hash_cleaned_str.py b/memium/utils/hash_cleaned_str.py index 00bdf14d..3fe6eeb2 100644 --- a/memium/utils/hash_cleaned_str.py +++ b/memium/utils/hash_cleaned_str.py @@ -1,14 +1,35 @@ import hashlib +from collections.abc import Callable +from bs4 import BeautifulSoup -def clean_str(input_str: str) -> str: + +def remove_spaces(text: str) -> str: + """Remove spaces from a string.""" + return text.replace(" ", "") + + +def remove_html_tags(text: str) -> str: + clean_text = BeautifulSoup(text, "html.parser").text + return clean_text + + +def remove_punctuation(text: str) -> str: """Clean string before hashing, so changes to spacing, punctuation, newlines etc. do not affect the hash.""" - lowered = input_str.lower() + lowered = text.lower() punctuation = r"""!"#$%&'()*+,-./:;<=>?@[\]^_`|~""" - cleaned = lowered.translate(str.maketrans("", "", punctuation)).replace( - " ", "" - ) + cleaned = lowered.translate(str.maketrans("", "", punctuation)) + + return cleaned + + +def clean_str(input_str: str) -> str: + """Clean string before hashing, so changes to spacing, punctuation, newlines etc. do not affect the hash.""" + cleaned = input_str + + for cleaner in [remove_html_tags, remove_punctuation, remove_spaces]: + cleaned = cleaner(cleaned) return cleaned @@ -27,8 +48,9 @@ def int_hash_str(input_string: str, max_length: int = 10) -> int: return shortened -def hash_cleaned_str(input_str: str) -> int: +def hash_cleaned_str( + input_str: str, cleaner: Callable[[str], str] = clean_str +) -> int: """Hash a string after cleaning it.""" - cleaned = clean_str(input_str) - hashed = int_hash_str(cleaned) + hashed = int_hash_str(cleaner(input_str)) return hashed diff --git a/memium/utils/test_hash_cleaned_str.py b/memium/utils/test_hash_cleaned_str.py index 0aa1d67a..73b6d22c 100644 --- a/memium/utils/test_hash_cleaned_str.py +++ b/memium/utils/test_hash_cleaned_str.py @@ -1,7 +1,9 @@ -from .hash_cleaned_str import hash_cleaned_str +import pytest +from .hash_cleaned_str import clean_str, hash_cleaned_str -def test_hash_cleaned_str(): + +def test_hash_cleaned_str_should_ignore_punctuation(): strings_should_hash_to_identical = ["this", "This"] punctuation = r"!().,:;/" @@ -13,3 +15,18 @@ def test_hash_cleaned_str(): len({hash_cleaned_str(s) for s in strings_should_hash_to_identical}) == 1 ) + + +def test_hash_cleaned_str_should_remove_html_tags(): + strings_should_hash_to_identical = ["

Test

", "Test"] + + assert hash_cleaned_str( + strings_should_hash_to_identical[0] + ) == hash_cleaned_str(strings_should_hash_to_identical[1]) + + +@pytest.mark.parametrize( + ("input_str", "expected"), [("Is <2, but >4", "is2but4")] +) +def test_str_cleaner(input_str: str, expected: str): + assert clean_str(input_str) == clean_str(expected) diff --git a/pyproject.toml b/pyproject.toml index 59e59787..c35607ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "typer==0.9.0", "tqdm==4.66.1", "wasabi==1.1.2", + "bs4==0.0.1", ] [project.license]