Skip to content

Commit

Permalink
ruff: format
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff committed Oct 27, 2023
1 parent 082ea88 commit 804b6b8
Show file tree
Hide file tree
Showing 15 changed files with 88 additions and 36 deletions.
4 changes: 3 additions & 1 deletion application/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from personal_mnemonic_medium.prompt_extractors.cloze_extractor import (
ClozePromptExtractor,
)
from personal_mnemonic_medium.prompt_extractors.qa_extractor import QAPromptExtractor
from personal_mnemonic_medium.prompt_extractors.qa_extractor import (
QAPromptExtractor,
)
from wasabi import Printer

msg = Printer(timestamp=True)
Expand Down
2 changes: 1 addition & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ test: ## Run tests
pytest $(SRC_PATH)

lint: ## Format code
ruff . --fix --extend-select F401
ruff format .
ruff . --fix --extend-select F401

type-check: ## Type-check code
pyright $(SRC_PATH)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ reportMissingTypeStubs = false

[tool.ruff]
# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
line-length = 80
select = [
"A",
"ANN",
Expand Down
4 changes: 3 additions & 1 deletion src/personal_mnemonic_medium/card_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def run(
) -> list[AnkiCard]:
notes: list[Document] = []
if input_path.is_dir():
notes += list(self.document_factory.get_notes_from_dir(dir_path=input_path))
notes += list(
self.document_factory.get_notes_from_dir(dir_path=input_path)
)

if not input_path.is_dir():
note_from_file = self.document_factory.get_note_from_file(
Expand Down
14 changes: 7 additions & 7 deletions src/personal_mnemonic_medium/exporters/anki/card_types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,13 @@ def get_source_button(self) -> str:

def to_genanki_note(self) -> genanki.Note:
"""Produce a genanki. Note with the specified guid."""
if len(self.html_fields) > len(self.genanki_model.fields): # type: ignore
if len(self.html_fields) > len(self.genanki_model.fields): # type: ignore
raise ValueError(
f"Too many fields for model {self.genanki_model.name}: {self.html_fields}", # type: ignore
f"Too many fields for model {self.genanki_model.name}: {self.html_fields}", # type: ignore
)

if len(self.html_fields) < len(self.genanki_model.fields): # type: ignore
while len(self.html_fields) < len(self.genanki_model.fields): # type: ignore
if len(self.html_fields) < len(self.genanki_model.fields): # type: ignore
while len(self.html_fields) < len(self.genanki_model.fields): # type: ignore
before_extras_field = len(self.html_fields) == 2
if before_extras_field:
self.add_field(self.get_source_button())
Expand Down Expand Up @@ -156,12 +156,12 @@ def determine_media_references(self) -> Iterator[tuple[Path, Path]]:
results = []

def process_match(m) -> str: # noqa # type: ignore
initial_contents = m.group(1) # type: ignore
abspath, newpath = self.make_ref_pair(initial_contents) # type: ignore
initial_contents = m.group(1) # type: ignore
abspath, newpath = self.make_ref_pair(initial_contents) # type: ignore
results.append((abspath, newpath)) # noqa # type: ignore
return r'src="' + newpath + '"'

current_stage = re.sub(regex, process_match, current_stage) # type: ignore
current_stage = re.sub(regex, process_match, current_stage) # type: ignore

yield from results

Expand Down
4 changes: 3 additions & 1 deletion src/personal_mnemonic_medium/exporters/anki/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from personal_mnemonic_medium.exporters.anki.anki_css import CARD_MODEL_CSS

ANKICONNECT_URL = "http://host.docker.internal:8765" # On host machine, port is 8765
ANKICONNECT_URL = (
"http://host.docker.internal:8765"
) # On host machine, port is 8765

CARD_MATHJAX_CONTENT = textwrap.dedent(
"""\
Expand Down
16 changes: 12 additions & 4 deletions src/personal_mnemonic_medium/exporters/anki/package_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from personal_mnemonic_medium.exporters.anki.card_types.cloze import AnkiCloze
from personal_mnemonic_medium.exporters.anki.card_types.qa import AnkiQA
from personal_mnemonic_medium.exporters.base import CardExporter
from personal_mnemonic_medium.prompt_extractors.cloze_extractor import ClozePrompt
from personal_mnemonic_medium.prompt_extractors.cloze_extractor import (
ClozePrompt,
)
from personal_mnemonic_medium.prompt_extractors.prompt import Prompt
from personal_mnemonic_medium.prompt_extractors.qa_extractor import QAPrompt
from personal_mnemonic_medium.utils.hasher import simple_hash
Expand All @@ -35,7 +37,9 @@ class DeckBundle:
media: set[str]

def get_package(self) -> genanki.Package:
return genanki.Package(deck_or_decks=self.deck, media_files=list(self.media))
return genanki.Package(
deck_or_decks=self.deck, media_files=list(self.media)
)

def save_deck_to_file(self, output_path: Path) -> Path:
package = self.get_package()
Expand Down Expand Up @@ -85,7 +89,9 @@ def cards_to_deck(
try:
deck.add_note(card.to_genanki_note())
except IndexError as e:
log.debug(f"Could not add card {card} to deck {deck_name}, {e}.")
log.debug(
f"Could not add card {card} to deck {deck_name}, {e}."
)

return deck, media

Expand All @@ -112,7 +118,9 @@ def prompts_to_cards(
source_prompt=prompt,
)
else:
raise NotImplementedError(f"Prompt type {type(prompt)} not supported.")
raise NotImplementedError(
f"Prompt type {type(prompt)} not supported."
)

cards += [card]

Expand Down
21 changes: 16 additions & 5 deletions src/personal_mnemonic_medium/exporters/anki/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def invoke(action: Any, **params: Any) -> Any:
"""
requestJson = json.dumps(request(action, **params)).encode("utf-8")
response = json.load(
urllib.request.urlopen(urllib.request.Request(ANKICONNECT_URL, requestJson)),
urllib.request.urlopen(
urllib.request.Request(ANKICONNECT_URL, requestJson)
),
)
if len(response) != 2:
raise Exception("response has an unexpected number of fields")
Expand Down Expand Up @@ -102,7 +104,9 @@ def sync_deck(
msg.info("\tNotes removed: ")
msg.info(f"\t\t{removed_note_guids}")

package_path = deck_bundle.save_deck_to_file(save_dir_path / "deck.apkg")
package_path = deck_bundle.save_deck_to_file(
save_dir_path / "deck.apkg"
)
try:
sync_path = str(sync_dir_path / "deck.apkg")
invoke("importPackage", path=sync_path)
Expand All @@ -124,7 +128,9 @@ def sync_deck(
msg.good(f"Deleted {len(guids_to_delete)} notes")

except Exception:
msg.fail(f"Unable to delete cards in {deck_bundle.deck.name}")
msg.fail(
f"Unable to delete cards in {deck_bundle.deck.name}"
)
# Print full stack trace
traceback.print_exc()
except Exception as e:
Expand All @@ -144,7 +150,9 @@ def get_md_note_infos(deck_bundle: DeckBundle) -> set[str]:
return md_note_guids


def get_anki_note_infos(deck_bundle: DeckBundle) -> tuple[dict[str, Any], set[str]]:
def get_anki_note_infos(
deck_bundle: DeckBundle
) -> tuple[dict[str, Any], set[str]]:
anki_card_ids: list[int] = invoke(
"findCards",
query=f'"deck:{deck_bundle.deck.name}"',
Expand All @@ -158,7 +166,10 @@ def get_anki_note_infos(deck_bundle: DeckBundle) -> tuple[dict[str, Any], set[st

# convert the note info into a dictionary of guid to note info
anki_note_info_by_guid = {
n["fields"]["UUID"]["value"].replace("<p>", "").replace("</p>", "").strip(): n
n["fields"]["UUID"]["value"]
.replace("<p>", "")
.replace("</p>", "")
.strip(): n
for n in anki_notes_info
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ def field_to_html(field: Any) -> str:
If math is separated with dollar sign it is converted to brackets.
"""
if CONFIG["dollar"]:
for sep, (op, cl) in [("$$", (r"\\[", r"\\]")), ("$", (r"\\(", r"\\)"))]:
for sep, (op, cl) in [
("$$", (r"\\[", r"\\]")),
("$", (r"\\(", r"\\)")),
]:
escaped_sep = sep.replace(r"$", r"\$")
# ignore escaped dollar signs when splitting the field
field = re.split(rf"(?<!\\){escaped_sep}", field)
Expand All @@ -35,7 +38,9 @@ def field_to_html(field: Any) -> str:
token_instances = re.findall(pattern, field)

for instance in token_instances:
field = field.replace(instance, replacement + instance[1:-1] + replacement) # type: ignore
field = field.replace(
instance, replacement + instance[1:-1] + replacement
) # type: ignore

# Make sure every \n converts into a newline
field = field.replace("\n", " \n")
Expand Down
4 changes: 3 additions & 1 deletion src/personal_mnemonic_medium/note_factories/note.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def __init__(

import_time_formatted = datetime.datetime.now().strftime("%Y-%m-%d")

self.tags = self.get_tags(self.content, import_time=import_time_formatted)
self.tags = self.get_tags(
self.content, import_time=import_time_formatted
)

@staticmethod
def replace_alias_wiki_links(text: str) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def _replace_cloze_id_with_unique(

for cloze in selected_clozes:
output_hash = (
int(hashlib.sha256(cloze.encode("utf-8")).hexdigest(), 16) % 10**3
int(hashlib.sha256(cloze.encode("utf-8")).hexdigest(), 16)
% 10**3
)

new_cloze = f"{{{{c{output_hash}::{cloze[1:-1]}}}}}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@


class QAPrompt(Prompt):
def __init__(self, question: str, answer: str, *args: Any, **kwargs: Any) -> None:
def __init__(
self, question: str, answer: str, *args: Any, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self.question = question
self.answer = answer


class QAPromptExtractor(PromptExtractor):
def __init__(self, question_prefix: str = "Q.", answer_prefix: str = "A.") -> None:
def __init__(
self, question_prefix: str = "Q.", answer_prefix: str = "A."
) -> None:
self.question_prefix = question_prefix
self.answer_prefix = answer_prefix

Expand Down
4 changes: 3 additions & 1 deletion src/personal_mnemonic_medium/utils/hasher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

def simple_hash(text: str) -> int:
"""MD5 of text, mod 2^63. Probably not a great hash function."""
comp_hash = int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**10
comp_hash = (
int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**10
)

return comp_hash
16 changes: 10 additions & 6 deletions tests/exporters/anki/test_card_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def test_get_subtags():

def test_qa_uuid_generation():
file_path = (
Path(__file__).parent.parent.parent / "test_md_files" / "test_card_guid.md"
Path(__file__).parent.parent.parent
/ "test_md_files"
/ "test_card_guid.md"
)
cards = TestCardPipeline(prompt_extractors=[QAPromptExtractor()]).run(
input_path=file_path,
Expand All @@ -106,9 +108,13 @@ def test_qa_uuid_generation():

def test_cloze_uuid_generation():
file_path = (
Path(__file__).parent.parent.parent / "test_md_files" / "test_card_guid.md"
Path(__file__).parent.parent.parent
/ "test_md_files"
/ "test_card_guid.md"
)
cloze_cards = TestCardPipeline(prompt_extractors=[ClozePromptExtractor()]).run(
cloze_cards = TestCardPipeline(
prompt_extractors=[ClozePromptExtractor()]
).run(
input_path=file_path,
)

Expand All @@ -121,9 +127,7 @@ def test_get_bear_id():
factory = MarkdownNoteFactory()
note_str = r"Q. A card with a GUID.\nA. And here is its answer.\n\nQS. How about a card like this?\nA. Yes, an answer too.\n\nQ. How about multiline questions?\n* Like this\n* Or this?\nA. What is the hash?\n\nAnd some {cloze} deletions? For sure! Multipe {even}.\n\n<!-- {BearID:7696CDCD-803A-40BC-88D8-855DDBEC56CA-31546-000054DF17EAE2C1} -->"

expected_id = (
r"<!-- {BearID:7696CDCD-803A-40BC-88D8-855DDBEC56CA-31546-000054DF17EAE2C1} -->"
)
expected_id = r"<!-- {BearID:7696CDCD-803A-40BC-88D8-855DDBEC56CA-31546-000054DF17EAE2C1} -->"

extracted_id = factory.get_note_id(note_str)

Expand Down
14 changes: 11 additions & 3 deletions tests/prompt_extractors/test_qa_prompt_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import pytest
from personal_mnemonic_medium.note_factories.note import Document
from personal_mnemonic_medium.prompt_extractors.qa_extractor import QAPromptExtractor
from personal_mnemonic_medium.prompt_extractors.qa_extractor import (
QAPromptExtractor,
)


@pytest.fixture()
Expand Down Expand Up @@ -35,13 +37,19 @@ def test_has_qa_matches(qa_extractor: QAPromptExtractor):
"QA. Testing something else, even with QA in it!",
"\\Q. Testing newlines as well!",
]
matches = [string for string in example_strings if qa_extractor._has_qa(string)]
matches = [
string for string in example_strings if qa_extractor._has_qa(string)
]

assert len(matches) == 3


def test_has_qa_does_not_match(qa_extractor: QAPromptExtractor):
example_strings = ["\nQ.E.D.", "> A question like this, or", "::Q. A comment!::"]
example_strings = [
"\nQ.E.D.",
"> A question like this, or",
"::Q. A comment!::",
]

matches = 0

Expand Down

0 comments on commit 804b6b8

Please sign in to comment.