Skip to content

Commit

Permalink
fix: AnkiConnect sync does not appear idempotent (#505)
Browse files Browse the repository at this point in the history
fix: AnkiConnect sync does not appear idempotent

Fixes #503

refactor: pure-ish functions for ankiconverter

fix: apply the html cleaner
  • Loading branch information
MartinBernstorff authored Jan 3, 2024
2 parents 6fee03d + 03319e8 commit afe06ba
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 74 deletions.
39 changes: 26 additions & 13 deletions memium/destination/ankiconnect/anki_converter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from collections.abc import Sequence

from functionalpy._sequence import Seq

from ...source.prompts.prompt import BasePrompt
from ...source.prompts.prompt_cloze import ClozePrompt
from ...source.prompts.prompt_qa import QAPrompt
from ...source.prompts.prompt import BasePrompt, DestinationPrompt
from ...source.prompts.prompt_cloze import ClozePrompt, ClozeWithoutDoc
from ...source.prompts.prompt_qa import QAPrompt, QAWithoutDoc
from .anki_prompt import AnkiPrompt
from .anki_prompt_cloze import AnkiCloze
from .anki_prompt_qa import AnkiQA
from .ankiconnect_gateway import NoteInfo


class AnkiPromptConverter:
Expand All @@ -18,7 +15,7 @@ def __init__(
self.deck_prefix = deck_prefix
self.card_css = card_css

def _prompt_to_card(self, prompt: BasePrompt) -> AnkiPrompt:
def prompt_to_card(self, prompt: BasePrompt) -> AnkiPrompt:
deck_in_tags = [
tag for tag in prompt.tags if tag.startswith(self.deck_prefix)
]
Expand All @@ -45,9 +42,25 @@ def _prompt_to_card(self, prompt: BasePrompt) -> AnkiPrompt:
"BasePrompt is the base class for all prompts, use a subclass"
)

def prompts_to_cards(
self, prompts: Sequence[BasePrompt]
) -> Sequence[AnkiPrompt]:
"""Takes an iterable of prompts and turns them into AnkiCards"""
def note_info_to_prompt(self, note_info: NoteInfo) -> DestinationPrompt:
if "Question" in note_info.fields and "Answer" in note_info.fields:
return DestinationPrompt(
QAWithoutDoc(
question=note_info.fields["Question"].value,
answer=note_info.fields["Answer"].value,
add_tags=note_info.tags,
),
destination_id=str(note_info.noteId),
)

if "Text" in note_info.fields:
return DestinationPrompt(
ClozeWithoutDoc(
text=note_info.fields["Text"].value, add_tags=note_info.tags
),
destination_id=str(note_info.noteId),
)

return Seq(prompts).map(self._prompt_to_card).to_list()
raise ValueError(
f"NoteInfo {note_info} has neither Question nor Text field"
)
4 changes: 2 additions & 2 deletions memium/destination/ankiconnect/anki_prompt_cloze.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import genanki

from ...utils.hash_cleaned_str import clean_str, int_hash_str
from ...utils.hash_cleaned_str import int_hash_str, remove_punctuation
from .anki_prompt import AnkiPrompt


Expand All @@ -16,7 +16,7 @@ class AnkiCloze(AnkiPrompt):

@property
def uuid(self) -> int:
return int_hash_str(clean_str(self.text))
return int_hash_str(remove_punctuation(self.text))

@property
def genanki_model(self) -> genanki.Model:
Expand Down
4 changes: 2 additions & 2 deletions memium/destination/ankiconnect/anki_prompt_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import genanki

from ...utils.hash_cleaned_str import clean_str, int_hash_str
from ...utils.hash_cleaned_str import int_hash_str, remove_punctuation
from .anki_prompt import AnkiPrompt


Expand All @@ -17,7 +17,7 @@ class AnkiQA(AnkiPrompt):

@property
def uuid(self) -> int:
return int_hash_str(clean_str(f"{self.question}{self.answer}"))
return int_hash_str(remove_punctuation(f"{self.question}{self.answer}"))

@property
def genanki_model(self) -> genanki.Model:
Expand Down
2 changes: 1 addition & 1 deletion memium/destination/ankiconnect/test_anki_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_anki_prompt_converter(
"""Tests the AnkiPromptConverter class"""
card = AnkiPromptConverter(
base_deck="FakeDeck", card_css="FakeCSS"
).prompts_to_cards([input_prompt])[0]
).prompt_to_card(input_prompt)

assert card.uuid == expected_card.uuid
for attr in expected_card.__dict__:
Expand Down
33 changes: 5 additions & 28 deletions memium/destination/destination_ankiconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
from functionalpy._sequence import Seq

from ..source.prompts.prompt import DestinationPrompt
from ..source.prompts.prompt_cloze import ClozeWithoutDoc
from ..source.prompts.prompt_qa import QAWithoutDoc
from ..utils.hash_cleaned_str import hash_cleaned_str
from .ankiconnect.anki_converter import AnkiPromptConverter
from .ankiconnect.anki_prompt import AnkiPrompt
from .ankiconnect.ankiconnect_gateway import AnkiConnectGateway, NoteInfo
from .ankiconnect.ankiconnect_gateway import AnkiConnectGateway
from .destination import (
DeletePrompts,
PromptDestination,
Expand All @@ -28,33 +26,10 @@ def __init__(
self.gateway = gateway
self.prompt_converter = prompt_converter

def _note_info_to_prompt(self, note_info: NoteInfo) -> DestinationPrompt:
if "Question" in note_info.fields and "Answer" in note_info.fields:
return DestinationPrompt(
QAWithoutDoc(
question=note_info.fields["Question"].value,
answer=note_info.fields["Answer"].value,
add_tags=note_info.tags,
),
destination_id=str(note_info.noteId),
)

if "Text" in note_info.fields:
return DestinationPrompt(
ClozeWithoutDoc(
text=note_info.fields["Text"].value, add_tags=note_info.tags
),
destination_id=str(note_info.noteId),
)

raise ValueError(
f"NoteInfo {note_info} has neither Question nor Text field"
)

def get_all_prompts(self) -> Sequence[DestinationPrompt]:
return (
Seq(self.gateway.get_all_note_infos())
.map(self._note_info_to_prompt)
.map(self.prompt_converter.note_info_to_prompt)
.to_list()
)

Expand Down Expand Up @@ -85,7 +60,9 @@ def _create_package(self, cards: Sequence[AnkiPrompt]) -> genanki.Package:
return genanki.Package(deck_or_decks=decks)

def _push_prompts(self, command: PushPrompts) -> None:
cards = self.prompt_converter.prompts_to_cards(command.prompts)
cards = [
self.prompt_converter.prompt_to_card(e) for e in command.prompts
]

models = [card.genanki_model for card in cards]
unique_models: dict[int, genanki.Model] = {
Expand Down
5 changes: 3 additions & 2 deletions memium/source/prompts/prompt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from collections.abc import Sequence
from typing import Protocol
from typing import Protocol, runtime_checkable


@runtime_checkable
class BasePrompt(Protocol):
@property
def scheduling_uid(self) -> int:
"""UID used when scheduling the prompt. If this UID changes, the scheduling of the prompt is reset."""
"""UID used when scheduling the prompt. If this UID changes, the old prompt is deleted and a new prompt is created."""
...

@property
Expand Down
4 changes: 2 additions & 2 deletions memium/source/prompts/prompt_cloze.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Sequence
from dataclasses import dataclass

from ...utils.hash_cleaned_str import clean_str, hash_cleaned_str, int_hash_str
from ...utils.hash_cleaned_str import hash_cleaned_str, int_hash_str
from ..document import Document
from .prompt import BasePrompt

Expand All @@ -16,7 +16,7 @@ def scheduling_uid(self) -> int:

@property
def update_uid(self) -> int:
return int_hash_str(f"{clean_str(self.text)}{self.tags}")
return int_hash_str(f"{(self.text)}{self.tags}")

@property
def tags(self) -> Sequence[str]:
Expand Down
7 changes: 2 additions & 5 deletions memium/source/prompts/prompt_qa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Sequence
from dataclasses import dataclass

from ...utils.hash_cleaned_str import hash_cleaned_str
from ...utils.hash_cleaned_str import hash_cleaned_str, int_hash_str
from ..document import Document
from .prompt import BasePrompt

Expand All @@ -17,7 +17,7 @@ def scheduling_uid(self) -> int:

@property
def update_uid(self) -> int:
return hash_cleaned_str(f"{self.question}_{self.answer}_{self.tags}")
return int_hash_str(f"{self.question}_{self.answer}_{self.tags}")

@property
def tags(self) -> Sequence[str]:
Expand All @@ -38,9 +38,6 @@ class QAFromDoc(QAPrompt):
parent_doc: Document
line_nr: int

def __repr__(self) -> str:
return f"{self.parent_doc.source_path}:{self.line_nr}: \n\tQ. {self.question}\n\tA. {self.answer}"

@property
def tags(self) -> Sequence[str]:
return self.parent_doc.tags
33 changes: 24 additions & 9 deletions memium/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@
reason="Tests require a running AnkiConnect server",
)
def test_main(
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
base_deck: str = "Tests::Integration Test deck",
base_deck: str = "Tests::Main Integration Test",
):
caplog.set_level(logging.INFO)
output_path = tmp_path / "test.md"
with output_path.open("w") as f:

# Clear and delete output path
test_input_path = Path("/output")

if test_input_path.exists():
for entity in test_input_path.iterdir():
if not entity.is_dir():
entity.unlink()
else:
test_input_path.mkdir(parents=True)

with (test_input_path / "test.md").open("w") as f:
f.write(
"""# Test note
Q. Test question?
Expand All @@ -29,9 +38,15 @@ def test_main(

main(
base_deck=base_deck,
input_dir=tmp_path,
max_deletions_per_run=1,
dry_run=True,
input_dir=test_input_path,
max_deletions_per_run=2,
dry_run=False,
)

# Test idempotency
main(
base_deck=base_deck,
input_dir=test_input_path,
max_deletions_per_run=0, # 0 deletions allowed to test idempotency
dry_run=False,
)
assert "Pushing prompt" in caplog.text
assert "Test question?" in caplog.text
38 changes: 30 additions & 8 deletions memium/utils/hash_cleaned_str.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,35 @@
import hashlib
from collections.abc import Callable

from bs4 import BeautifulSoup

def clean_str(input_str: str) -> str:

def remove_spaces(text: str) -> str:
"""Remove spaces from a string."""
return text.replace(" ", "")


def remove_html_tags(text: str) -> str:
clean_text = BeautifulSoup(text, "html.parser").text
return clean_text


def remove_punctuation(text: str) -> str:
"""Clean string before hashing, so changes to spacing, punctuation, newlines etc. do not affect the hash."""
lowered = input_str.lower()
lowered = text.lower()

punctuation = r"""!"#$%&'()*+,-./:;<=>?@[\]^_`|~"""
cleaned = lowered.translate(str.maketrans("", "", punctuation)).replace(
" ", ""
)
cleaned = lowered.translate(str.maketrans("", "", punctuation))

return cleaned


def clean_str(input_str: str) -> str:
"""Clean string before hashing, so changes to spacing, punctuation, newlines etc. do not affect the hash."""
cleaned = input_str

for cleaner in [remove_html_tags, remove_punctuation, remove_spaces]:
cleaned = cleaner(cleaned)

return cleaned

Expand All @@ -27,8 +48,9 @@ def int_hash_str(input_string: str, max_length: int = 10) -> int:
return shortened


def hash_cleaned_str(input_str: str) -> int:
def hash_cleaned_str(
input_str: str, cleaner: Callable[[str], str] = clean_str
) -> int:
"""Hash a string after cleaning it."""
cleaned = clean_str(input_str)
hashed = int_hash_str(cleaned)
hashed = int_hash_str(cleaner(input_str))
return hashed
21 changes: 19 additions & 2 deletions memium/utils/test_hash_cleaned_str.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .hash_cleaned_str import hash_cleaned_str
import pytest

from .hash_cleaned_str import clean_str, hash_cleaned_str

def test_hash_cleaned_str():

def test_hash_cleaned_str_should_ignore_punctuation():
strings_should_hash_to_identical = ["this", "This"]

punctuation = r"!().,:;/"
Expand All @@ -13,3 +15,18 @@ def test_hash_cleaned_str():
len({hash_cleaned_str(s) for s in strings_should_hash_to_identical})
== 1
)


def test_hash_cleaned_str_should_remove_html_tags():
strings_should_hash_to_identical = ["<p>Test</p>", "Test"]

assert hash_cleaned_str(
strings_should_hash_to_identical[0]
) == hash_cleaned_str(strings_should_hash_to_identical[1])


@pytest.mark.parametrize(
("input_str", "expected"), [("Is <2, but >4", "is2but4")]
)
def test_str_cleaner(input_str: str, expected: str):
assert clean_str(input_str) == clean_str(expected)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"typer==0.9.0",
"tqdm==4.66.1",
"wasabi==1.1.2",
"bs4==0.0.1",
]

[project.license]
Expand Down

0 comments on commit afe06ba

Please sign in to comment.