Skip to content

Commit

Permalink
整理: 型安全な辞書保存と読み出しを追加 (#1334)
Browse files Browse the repository at this point in the history
* refactor: 型安全な辞書保存と読み出しを追加

* fix: pydantic v2 化

* refactor: BaseModel から dataclass へ変更

* refactor: 保存用単語をリネーム

* fix: 変数をリネームして追従

* refactor: 保存用単語の dump 箇所を変更

* fix: 変数をリネームして追従

* 微調整

---------

Co-authored-by: Hiroshiba Kazuyuki <[email protected]>
  • Loading branch information
tarepan and Hiroshiba authored Jun 19, 2024
1 parent cd559a2 commit 00fb774
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 32 deletions.
50 changes: 18 additions & 32 deletions voicevox_engine/user_dict/user_dict_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
from uuid import UUID, uuid4

import pyopenjtalk
from pydantic import TypeAdapter

from ..utility.path_utility import get_save_dir, resource_root
from .model import UserDictWord, WordTypes
from .model import UserDictWord
from .user_dict_word import (
SaveFormatUserDictWord,
UserDictInputError,
WordProperty,
cost2priority,
convert_from_save_format,
convert_to_save_format,
create_word,
part_of_speech_data,
priority2cost,
Expand Down Expand Up @@ -55,6 +58,9 @@ def func(*args: Any, **kw: Any) -> Any:
mutex_openjtalk_dict = threading.Lock()


_save_format_dict_adapter = TypeAdapter(dict[str, SaveFormatUserDictWord])


class UserDictionary:
"""ユーザー辞書"""

Expand Down Expand Up @@ -82,21 +88,12 @@ def __init__(
@mutex_wrapper(mutex_user_dict)
def _write_to_json(self, user_dict: dict[str, UserDictWord]) -> None:
"""ユーザー辞書データをファイルへ書き込む。"""
user_dict_path = self._user_dict_path

converted_user_dict = {}
save_format_user_dict: dict[str, SaveFormatUserDictWord] = {}
for word_uuid, word in user_dict.items():
word_dict = word.model_dump()
word_dict["cost"] = priority2cost(
word_dict["context_id"], word_dict["priority"]
)
del word_dict["priority"]
converted_user_dict[word_uuid] = word_dict
# 予めjsonに変換できることを確かめる
user_dict_json = json.dumps(converted_user_dict, ensure_ascii=False)

# ユーザー辞書ファイルへの書き込み
user_dict_path.write_text(user_dict_json, encoding="utf-8")
save_format_word = convert_to_save_format(word)
save_format_user_dict[word_uuid] = save_format_word
user_dict_json = _save_format_dict_adapter.dump_json(save_format_user_dict)
self._user_dict_path.write_bytes(user_dict_json)

@mutex_wrapper(mutex_openjtalk_dict)
def update_dict(self) -> None:
Expand Down Expand Up @@ -180,26 +177,15 @@ def update_dict(self) -> None:
@mutex_wrapper(mutex_user_dict)
def read_dict(self) -> dict[str, UserDictWord]:
"""ユーザー辞書を読み出す。"""
user_dict_path = self._user_dict_path

# 指定ユーザー辞書が存在しない場合、空辞書を返す
if not user_dict_path.is_file():
if not self._user_dict_path.is_file():
return {}

with user_dict_path.open(encoding="utf-8") as f:
with self._user_dict_path.open(encoding="utf-8") as f:
save_format_dict = _save_format_dict_adapter.validate_python(json.load(f))
result: dict[str, UserDictWord] = {}
for word_uuid, word in json.load(f).items():
# cost2priorityで変換を行う際にcontext_idが必要となるが、
# 0.12以前の辞書は、context_idがハードコーディングされていたためにユーザー辞書内に保管されていない
# ハードコーディングされていたcontext_idは固有名詞を意味するものなので、固有名詞のcontext_idを補完する
if word.get("context_id") is None:
word["context_id"] = part_of_speech_data[
WordTypes.PROPER_NOUN
].context_id
word["priority"] = cost2priority(word["context_id"], word["cost"])
del word["cost"]
result[str(UUID(word_uuid))] = UserDictWord(**word)

for word_uuid, word in save_format_dict.items():
result[str(UUID(word_uuid))] = convert_from_save_format(word)
return result

def import_user_dict(
Expand Down
70 changes: 70 additions & 0 deletions voicevox_engine/user_dict/user_dict_word.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,73 @@ def priority2cost(context_id: int, priority: int) -> int:
assert USER_DICT_MIN_PRIORITY <= priority <= USER_DICT_MAX_PRIORITY
cost_candidates = _search_cost_candidates(context_id)
return cost_candidates[USER_DICT_MAX_PRIORITY - priority]


@dataclass
class SaveFormatUserDictWord:
"""保存用の単語の型"""

surface: str
cost: int # `UserDictWord.priority` と対応
part_of_speech: str
part_of_speech_detail_1: str
part_of_speech_detail_2: str
part_of_speech_detail_3: str
inflectional_type: str
inflectional_form: str
stem: str
yomi: str
pronunciation: str
accent_type: int
accent_associative_rule: str
context_id: int | None = None # v0.12 以前の辞書でのみ `None`
mora_count: int | None = None


def convert_to_save_format(word: UserDictWord) -> SaveFormatUserDictWord:
"""単語を保存用に変換する。"""
cost = priority2cost(word.context_id, word.priority)
return SaveFormatUserDictWord(
surface=word.surface,
cost=cost,
context_id=word.context_id,
part_of_speech=word.part_of_speech,
part_of_speech_detail_1=word.part_of_speech_detail_1,
part_of_speech_detail_2=word.part_of_speech_detail_2,
part_of_speech_detail_3=word.part_of_speech_detail_3,
inflectional_type=word.inflectional_type,
inflectional_form=word.inflectional_form,
stem=word.stem,
yomi=word.yomi,
pronunciation=word.pronunciation,
accent_type=word.accent_type,
mora_count=word.mora_count,
accent_associative_rule=word.accent_associative_rule,
)


def convert_from_save_format(word: SaveFormatUserDictWord) -> UserDictWord:
"""単語を保存用から変換する。"""
context_id_p_noun = part_of_speech_data[WordTypes.PROPER_NOUN].context_id
# cost2priorityで変換を行う際にcontext_idが必要となるが、
# 0.12以前の辞書は、context_idがハードコーディングされていたためにユーザー辞書内に保管されていない
# ハードコーディングされていたcontext_idは固有名詞を意味するものなので、固有名詞のcontext_idを補完する
context_id = context_id_p_noun if word.context_id is None else word.context_id
priority = cost2priority(context_id, word.cost)
return UserDictWord(
surface=word.surface,
priority=priority,
context_id=context_id,
part_of_speech=word.part_of_speech,
part_of_speech_detail_1=word.part_of_speech_detail_1,
part_of_speech_detail_2=word.part_of_speech_detail_2,
part_of_speech_detail_3=word.part_of_speech_detail_3,
inflectional_type=word.inflectional_type,
inflectional_form=word.inflectional_form,
stem=word.stem,
yomi=word.yomi,
pronunciation=word.pronunciation,
accent_type=word.accent_type,
mora_count=word.mora_count,
accent_associative_rule=word.accent_associative_rule,
)

0 comments on commit 00fb774

Please sign in to comment.