Skip to content

Commit

Permalink
refactor tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
tatsumoto-ren committed Mar 25, 2024
1 parent f9c4d5e commit 4ec4bed
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 51 deletions.
36 changes: 22 additions & 14 deletions helpers/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import dataclasses
import enum
from typing import Iterable
import typing
from collections.abc import Iterable, Container


@enum.unique
Expand Down Expand Up @@ -35,7 +36,12 @@ def all_names(cls) -> Iterable[str]:

@classmethod
def all_comma_separated_names(cls) -> str:
return ','.join(cls.all_names())
return ",".join(cls.all_names())


class AnkiNoteProtocol(typing.Protocol):
def __contains__(self, key: str) -> bool:
...


@dataclasses.dataclass(frozen=True)
Expand All @@ -52,7 +58,7 @@ class Profile:
_subclasses_map = {} # "furigana" (str) -> ProfileFurigana

def __init_subclass__(cls, **kwargs):
mode = kwargs.pop('mode') # suppresses ide warning
mode = kwargs.pop("mode") # suppresses ide warning
super().__init_subclass__(**kwargs)
cls._subclasses_map[mode] = cls
cls.mode = mode
Expand All @@ -62,11 +68,7 @@ def __new__(cls, mode: str, *args, **kwargs):
return object.__new__(subclass)

def enabled_callers(self) -> list[TaskCaller]:
return [
TaskCaller[name]
for name in self.triggered_by.split(',')
if name
]
return [TaskCaller[name] for name in self.triggered_by.split(",") if name]

def should_answer_to(self, caller: TaskCaller) -> bool:
"""
Expand All @@ -75,6 +77,12 @@ def should_answer_to(self, caller: TaskCaller) -> bool:
"""
return caller in self.enabled_callers()

def applies_to_note(self, note: AnkiNoteProtocol) -> bool:
"""
Field names must not be empty or None. The note must have fields with these names.
"""
return (self.source and self.destination) and (self.source in note and self.destination in note)

@classmethod
def class_by_mode(cls, mode: str):
return cls._subclasses_map[mode]
Expand All @@ -96,7 +104,7 @@ def get_default(cls, mode: str):
return cls.class_by_mode(mode).new()

@classmethod
def clone(cls, profile: 'Profile'):
def clone(cls, profile: "Profile"):
return cls(**dataclasses.asdict(profile))


Expand Down Expand Up @@ -133,15 +141,15 @@ def new(cls):
)


def test():
def main():
import json

with open('../config.json') as f:
with open("../config.json") as f:
config = json.load(f)

for p in config.get('profiles'):
for p in config.get("profiles"):
print(Profile(**p))


if __name__ == '__main__':
test()
if __name__ == "__main__":
main()
81 changes: 44 additions & 37 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,26 @@
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html

import functools
from typing import Optional, Callable, Any
import io
from typing import Optional

import anki.collection
from anki import hooks
from anki.decks import DeckId
from anki.models import NotetypeDict
from anki.notes import Note
from anki.utils import strip_html_media
from aqt import mw
from aqt.utils import tooltip

from .audio import format_audio_tags, AnkiAudioSourceManager
from .audio import format_audio_tags, AnkiAudioSourceManager, FileSaveResults
from .config_view import config_view as cfg
from .helpers import *
from .helpers.profiles import Profile, ProfileFurigana, PitchOutputFormat, ProfilePitch, ProfileAudio, TaskCaller
from .reading import format_pronunciations, get_pronunciations, generate_furigana


def note_type_matches(note_type: dict[str, Any], profile: Profile) -> bool:
def note_type_matches(note_type: NotetypeDict, profile: Profile) -> bool:
return profile.note_type.lower() in note_type["name"].lower()


Expand Down Expand Up @@ -51,7 +54,11 @@ def _generate_text(self, src_text: str) -> str:
raise NotImplementedError()

def run(self, src_text: str, dest_text: str) -> str:
return out if (out := self._generate_text(src_text)) and (out != src_text or not dest_text) else dest_text
if src_text:
out_text = self._generate_text(src_text)
if out_text and (not dest_text or out_text != src_text):
return out_text
return dest_text


class AddFurigana(DoTask, task_type=ProfileFurigana):
Expand Down Expand Up @@ -117,12 +124,12 @@ def html_to_media_line(txt: str) -> str:

class DoTasks:
def __init__(
self,
note: Note,
*,
caller: TaskCaller,
src_field: Optional[str] = None,
overwrite: bool = False,
self,
note: Note,
*,
caller: TaskCaller,
src_field: Optional[str] = None,
overwrite: bool = False,
):
self._note = note
self._caller = caller
Expand All @@ -134,12 +141,17 @@ def run(self, changed: bool = False) -> bool:

with aud_src_mgr.request_new_session() as aud_mgr:
for task in self._tasks:
if task.should_answer_to(self._caller):
if task.should_answer_to(self._caller) and task.applies_to_note(self._note):
changed = self._do_task(task, aud_mgr=aud_mgr) or changed
return changed

def _do_task(self, task: Profile, aud_mgr: AnkiAudioSourceManager) -> bool:
changed = False

if self._field_contains_garbage(task.destination):
self._note[task.destination] = "" # immediately clear garbage
changed = True

if self._can_fill_destination(task) and (src_text := self._src_text(task)):
self._note[task.destination] = DoTask(
task,
Expand All @@ -152,36 +164,31 @@ def _do_task(self, task: Profile, aud_mgr: AnkiAudioSourceManager) -> bool:
changed = True
return changed

def _can_fill_destination(self, task: Profile) -> bool:
"""
The add-on can fill the destination field if it's empty
or if the user wants to fill it with new data and erase the old data.
"""
return self._is_overwrite_permitted(task) or not html_to_media_line(self._note[task.destination])

def _is_overwrite_permitted(self, task: Profile) -> bool:
"""
Has the user allowed the add-on to erase existing content (if any) in the destination field?
"""
return self._overwrite or task.overwrite_destination

def _src_text(self, task: Profile) -> str:
"""
Return source text with sound and image tags removed.
"""
return mw.col.media.strip(self._note[task.source]).strip()

def _field_contains_garbage(self, field_name: str) -> bool:
# Yomichan added `No pitch accent data` to the field when creating the note.
if "No pitch accent data".lower() in self._note[field_name].lower():
return True
return False

def _can_fill_destination(self, task: Profile) -> bool:
# Field names are empty or None
if not task.source or not task.destination:
return False

# The note doesn't have fields with these names
if task.source not in self._note or task.destination not in self._note:
return False

if self._field_contains_garbage(task.destination):
return True

# Must overwrite any existing data.
if self._overwrite is True or task.overwrite_destination is True:
return True

# Field is empty.
if not html_to_media_line(self._note[task.destination]):
return True

return False
"""
Yomichan added `No pitch accent data` to the field when creating the note.
Rikaitan doesn't have this problem.
"""
return "No pitch accent data".lower() in self._note[field_name].lower()


def on_focus_lost(changed: bool, note: Note, field_idx: int) -> bool:
Expand Down

0 comments on commit 4ec4bed

Please sign in to comment.