diff --git a/helpers/profiles.py b/helpers/profiles.py index e13b710..def608d 100644 --- a/helpers/profiles.py +++ b/helpers/profiles.py @@ -3,7 +3,8 @@ import dataclasses import enum -from typing import Iterable +import typing +from collections.abc import Iterable, Container @enum.unique @@ -35,7 +36,12 @@ def all_names(cls) -> Iterable[str]: @classmethod def all_comma_separated_names(cls) -> str: - return ','.join(cls.all_names()) + return ",".join(cls.all_names()) + + +class AnkiNoteProtocol(typing.Protocol): + def __contains__(self, key: str) -> bool: + ... @dataclasses.dataclass(frozen=True) @@ -52,7 +58,7 @@ class Profile: _subclasses_map = {} # "furigana" (str) -> ProfileFurigana def __init_subclass__(cls, **kwargs): - mode = kwargs.pop('mode') # suppresses ide warning + mode = kwargs.pop("mode") # suppresses ide warning super().__init_subclass__(**kwargs) cls._subclasses_map[mode] = cls cls.mode = mode @@ -62,11 +68,7 @@ def __new__(cls, mode: str, *args, **kwargs): return object.__new__(subclass) def enabled_callers(self) -> list[TaskCaller]: - return [ - TaskCaller[name] - for name in self.triggered_by.split(',') - if name - ] + return [TaskCaller[name] for name in self.triggered_by.split(",") if name] def should_answer_to(self, caller: TaskCaller) -> bool: """ @@ -75,6 +77,12 @@ def should_answer_to(self, caller: TaskCaller) -> bool: """ return caller in self.enabled_callers() + def applies_to_note(self, note: AnkiNoteProtocol) -> bool: + """ + Field names must not be empty or None. The note must have fields with these names. + """ + return (self.source and self.destination) and (self.source in note and self.destination in note) + @classmethod def class_by_mode(cls, mode: str): return cls._subclasses_map[mode] @@ -96,7 +104,7 @@ def get_default(cls, mode: str): return cls.class_by_mode(mode).new() @classmethod - def clone(cls, profile: 'Profile'): + def clone(cls, profile: "Profile"): return cls(**dataclasses.asdict(profile)) @@ -133,15 +141,15 @@ def new(cls): ) -def test(): +def main(): import json - with open('../config.json') as f: + with open("../config.json") as f: config = json.load(f) - for p in config.get('profiles'): + for p in config.get("profiles"): print(Profile(**p)) -if __name__ == '__main__': - test() +if __name__ == "__main__": + main() diff --git a/tasks.py b/tasks.py index 066130c..7150aed 100644 --- a/tasks.py +++ b/tasks.py @@ -2,23 +2,26 @@ # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html import functools -from typing import Optional, Callable, Any +import io +from typing import Optional import anki.collection from anki import hooks from anki.decks import DeckId +from anki.models import NotetypeDict from anki.notes import Note from anki.utils import strip_html_media from aqt import mw +from aqt.utils import tooltip -from .audio import format_audio_tags, AnkiAudioSourceManager +from .audio import format_audio_tags, AnkiAudioSourceManager, FileSaveResults from .config_view import config_view as cfg from .helpers import * from .helpers.profiles import Profile, ProfileFurigana, PitchOutputFormat, ProfilePitch, ProfileAudio, TaskCaller from .reading import format_pronunciations, get_pronunciations, generate_furigana -def note_type_matches(note_type: dict[str, Any], profile: Profile) -> bool: +def note_type_matches(note_type: NotetypeDict, profile: Profile) -> bool: return profile.note_type.lower() in note_type["name"].lower() @@ -51,7 +54,11 @@ def _generate_text(self, src_text: str) -> str: raise NotImplementedError() def run(self, src_text: str, dest_text: str) -> str: - return out if (out := self._generate_text(src_text)) and (out != src_text or not dest_text) else dest_text + if src_text: + out_text = self._generate_text(src_text) + if out_text and (not dest_text or out_text != src_text): + return out_text + return dest_text class AddFurigana(DoTask, task_type=ProfileFurigana): @@ -117,12 +124,12 @@ def html_to_media_line(txt: str) -> str: class DoTasks: def __init__( - self, - note: Note, - *, - caller: TaskCaller, - src_field: Optional[str] = None, - overwrite: bool = False, + self, + note: Note, + *, + caller: TaskCaller, + src_field: Optional[str] = None, + overwrite: bool = False, ): self._note = note self._caller = caller @@ -134,12 +141,17 @@ def run(self, changed: bool = False) -> bool: with aud_src_mgr.request_new_session() as aud_mgr: for task in self._tasks: - if task.should_answer_to(self._caller): + if task.should_answer_to(self._caller) and task.applies_to_note(self._note): changed = self._do_task(task, aud_mgr=aud_mgr) or changed return changed def _do_task(self, task: Profile, aud_mgr: AnkiAudioSourceManager) -> bool: changed = False + + if self._field_contains_garbage(task.destination): + self._note[task.destination] = "" # immediately clear garbage + changed = True + if self._can_fill_destination(task) and (src_text := self._src_text(task)): self._note[task.destination] = DoTask( task, @@ -152,36 +164,31 @@ def _do_task(self, task: Profile, aud_mgr: AnkiAudioSourceManager) -> bool: changed = True return changed + def _can_fill_destination(self, task: Profile) -> bool: + """ + The add-on can fill the destination field if it's empty + or if the user wants to fill it with new data and erase the old data. + """ + return self._is_overwrite_permitted(task) or not html_to_media_line(self._note[task.destination]) + + def _is_overwrite_permitted(self, task: Profile) -> bool: + """ + Has the user allowed the add-on to erase existing content (if any) in the destination field? + """ + return self._overwrite or task.overwrite_destination + def _src_text(self, task: Profile) -> str: + """ + Return source text with sound and image tags removed. + """ return mw.col.media.strip(self._note[task.source]).strip() def _field_contains_garbage(self, field_name: str) -> bool: - # Yomichan added `No pitch accent data` to the field when creating the note. - if "No pitch accent data".lower() in self._note[field_name].lower(): - return True - return False - - def _can_fill_destination(self, task: Profile) -> bool: - # Field names are empty or None - if not task.source or not task.destination: - return False - - # The note doesn't have fields with these names - if task.source not in self._note or task.destination not in self._note: - return False - - if self._field_contains_garbage(task.destination): - return True - - # Must overwrite any existing data. - if self._overwrite is True or task.overwrite_destination is True: - return True - - # Field is empty. - if not html_to_media_line(self._note[task.destination]): - return True - - return False + """ + Yomichan added `No pitch accent data` to the field when creating the note. + Rikaitan doesn't have this problem. + """ + return "No pitch accent data".lower() in self._note[field_name].lower() def on_focus_lost(changed: bool, note: Note, field_idx: int) -> bool: