Skip to content

Commit

Permalink
APIのstyle_idの型をStyleId型に (#966)
Browse files (browse the repository at this point in the history)
* run.pyのとこStyleIdに

* API引数をStyleId化

* 間違えて追加してしまっていた

* 漏れ

* StyleIdの場所変更

* pysen

* 自動import箇所が意図とあってなさそうだった

* ワークアラウンドなことをコメント、FIXMEをコメント
  • Loading branch information
Hiroshiba authored Jan 3, 2024
1 parent 0881648 commit 938f8ae
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 51 deletions.
64 changes: 31 additions & 33 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from voicevox_engine.engine_manifest import EngineManifestLoader
from voicevox_engine.engine_manifest.EngineManifest import EngineManifest
from voicevox_engine.library_manager import LibraryManager
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.metas.MetasStore import MetasStore, construct_lookup
from voicevox_engine.model import (
AccentPhrase,
Expand All @@ -47,7 +48,6 @@
ParseKanaError,
Speaker,
SpeakerInfo,
StyleId,
StyleIdNotFoundError,
SupportedDevicesInfo,
UserDictWord,
Expand Down Expand Up @@ -93,17 +93,17 @@


def get_style_id_from_deprecated(
style_id: int | None, speaker_id: int | None
style_id: StyleId | None, speaker_id: StyleId | None
) -> StyleId:
"""
style_idとspeaker_id両方ともNoneかNoneでないかをチェックし、
どちらか片方しかNoneが存在しなければstyle_idを返す
"""
if speaker_id is not None and style_id is None:
warnings.warn("speakerは非推奨です。style_idを利用してください。", stacklevel=1)
return StyleId(speaker_id)
return speaker_id
elif style_id is not None and speaker_id is None:
return StyleId(style_id)
return style_id
raise HTTPException(
status_code=400, detail="speakerとstyle_idが両方とも存在しないか、両方とも存在しています。"
)
Expand Down Expand Up @@ -282,8 +282,8 @@ def get_core(core_version: Optional[str]) -> CoreAdapter:
)
def audio_query(
text: str,
style_id: int | None = Query(default=None), # noqa: B008
speaker: int | None = Query(default=None, deprecated=True), # noqa: B008
style_id: StyleId | None = Query(default=None), # noqa: B008
speaker: StyleId | None = Query(default=None, deprecated=True), # noqa: B008
core_version: str | None = None,
) -> AudioQuery:
"""
Expand Down Expand Up @@ -333,9 +333,7 @@ def audio_query_from_preset(
else:
raise HTTPException(status_code=422, detail="該当するプリセットIDが見つかりません")

accent_phrases = engine.create_accent_phrases(
text, StyleId(selected_preset.style_id)
)
accent_phrases = engine.create_accent_phrases(text, selected_preset.style_id)
return AudioQuery(
accent_phrases=accent_phrases,
speedScale=selected_preset.speedScale,
Expand Down Expand Up @@ -363,8 +361,8 @@ def audio_query_from_preset(
)
def accent_phrases(
text: str,
style_id: int | None = Query(default=None), # noqa: B008
speaker: int | None = Query(default=None, deprecated=True), # noqa: B008
style_id: StyleId | None = Query(default=None), # noqa: B008
speaker: StyleId | None = Query(default=None, deprecated=True), # noqa: B008
is_kana: bool = False,
core_version: str | None = None,
) -> list[AccentPhrase]:
Expand Down Expand Up @@ -401,8 +399,8 @@ def accent_phrases(
)
def mora_data(
accent_phrases: list[AccentPhrase],
style_id: int | None = Query(default=None), # noqa: B008
speaker: int | None = Query(default=None, deprecated=True), # noqa: B008
style_id: StyleId | None = Query(default=None), # noqa: B008
speaker: StyleId | None = Query(default=None, deprecated=True), # noqa: B008
core_version: str | None = None,
) -> list[AccentPhrase]:
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
Expand All @@ -417,8 +415,8 @@ def mora_data(
)
def mora_length(
accent_phrases: list[AccentPhrase],
style_id: int | None = Query(default=None), # noqa: B008
speaker: int | None = Query(default=None, deprecated=True), # noqa: B008
style_id: StyleId | None = Query(default=None), # noqa: B008
speaker: StyleId | None = Query(default=None, deprecated=True), # noqa: B008
core_version: str | None = None,
) -> list[AccentPhrase]:
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
Expand All @@ -433,8 +431,8 @@ def mora_length(
)
def mora_pitch(
accent_phrases: list[AccentPhrase],
style_id: int | None = Query(default=None), # noqa: B008
speaker: int | None = Query(default=None, deprecated=True), # noqa: B008
style_id: StyleId | None = Query(default=None), # noqa: B008
speaker: StyleId | None = Query(default=None, deprecated=True), # noqa: B008
core_version: str | None = None,
) -> list[AccentPhrase]:
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
Expand All @@ -456,8 +454,8 @@ def mora_pitch(
)
def synthesis(
query: AudioQuery,
style_id: int | None = Query(default=None), # noqa: B008
speaker: int | None = Query(default=None, deprecated=True), # noqa: B008
style_id: StyleId | None = Query(default=None), # noqa: B008
speaker: StyleId | None = Query(default=None, deprecated=True), # noqa: B008
enable_interrogative_upspeak: bool = Query( # noqa: B008
default=True,
description="疑問系のテキストが与えられたら語尾を自動調整する",
Expand Down Expand Up @@ -497,8 +495,8 @@ def synthesis(
def cancellable_synthesis(
query: AudioQuery,
request: Request,
style_id: int | None = Query(default=None), # noqa: B008
speaker: int | None = Query(default=None, deprecated=True), # noqa: B008
style_id: StyleId | None = Query(default=None), # noqa: B008
speaker: StyleId | None = Query(default=None, deprecated=True), # noqa: B008
core_version: str | None = None,
) -> FileResponse:
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
Expand Down Expand Up @@ -536,8 +534,8 @@ def cancellable_synthesis(
)
def multi_synthesis(
queries: list[AudioQuery],
style_id: int | None = Query(default=None), # noqa: B008
speaker: int | None = Query(default=None, deprecated=True), # noqa: B008
style_id: StyleId | None = Query(default=None), # noqa: B008
speaker: StyleId | None = Query(default=None, deprecated=True), # noqa: B008
core_version: str | None = None,
) -> FileResponse:
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
Expand Down Expand Up @@ -576,7 +574,7 @@ def multi_synthesis(
summary="指定した話者に対してエンジン内の話者がモーフィングが可能か判定する",
)
def morphable_targets(
base_speakers: list[int],
base_speakers: list[int], # FIXME: StyleId型にする
core_version: str | None = None,
) -> list[dict[str, MorphableTargetInfo]]:
"""
Expand Down Expand Up @@ -617,7 +615,7 @@ def morphable_targets(
)
def _synthesis_morphing(
query: AudioQuery,
base_speaker: int,
base_speaker: int, # FIXME: StyleId型にする
target_speaker: int,
morph_rate: float = Query(..., ge=0.0, le=1.0), # noqa: B008
core_version: str | None = None,
Expand Down Expand Up @@ -1001,7 +999,7 @@ def uninstall_library(library_uuid: str) -> Response:

@app.post("/initialize_style_id", status_code=204, tags=["その他"])
def initialize_style_id(
style_id: int,
style_id: StyleId,
skip_reinit: bool = Query( # noqa: B008
False, description="既に初期化済みのスタイルの再初期化をスキップするかどうか"
),
Expand All @@ -1012,23 +1010,23 @@ def initialize_style_id(
実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。
"""
core = get_core(core_version)
core.initialize_style_id_synthesis(StyleId(style_id), skip_reinit=skip_reinit)
core.initialize_style_id_synthesis(style_id, skip_reinit=skip_reinit)
return Response(status_code=204)

@app.get("/is_initialized_style_id", response_model=bool, tags=["その他"])
def is_initialized_style_id(
style_id: int,
style_id: StyleId,
core_version: str | None = None,
) -> bool:
"""
指定されたstyle_idのスタイルが初期化されているかどうかを返します。
"""
core = get_core(core_version)
return core.is_initialized_style_id_synthesis(StyleId(style_id))
return core.is_initialized_style_id_synthesis(style_id)

@app.post("/initialize_speaker", status_code=204, tags=["その他"], deprecated=True)
def initialize_speaker(
speaker: int,
speaker: StyleId,
skip_reinit: bool = Query( # noqa: B008
False, description="既に初期化済みの話者の再初期化をスキップするかどうか"
),
Expand All @@ -1044,14 +1042,14 @@ def initialize_speaker(
stacklevel=1,
)
return initialize_style_id(
StyleId(speaker), skip_reinit=skip_reinit, core_version=core_version
speaker, skip_reinit=skip_reinit, core_version=core_version
)

@app.get(
"/is_initialized_speaker", response_model=bool, tags=["その他"], deprecated=True
)
def is_initialized_speaker(
speaker: int,
speaker: StyleId,
core_version: str | None = None,
) -> bool:
"""
Expand All @@ -1062,7 +1060,7 @@ def is_initialized_speaker(
"使用しているAPI(/is_initialize_speaker)は非推奨です。/is_initialized_style_idを利用してください。",
stacklevel=1,
)
return is_initialized_style_id(StyleId(speaker), core_version=core_version)
return is_initialized_style_id(speaker, core_version=core_version)

@app.get("/user_dict", response_model=dict[str, UserDictWord], tags=["ユーザー辞書"])
def get_user_dict_words() -> dict[str, UserDictWord]:
Expand Down
3 changes: 2 additions & 1 deletion test/test_mock_tts_engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from unittest import TestCase

from voicevox_engine.dev.tts_engine import MockTTSEngine
from voicevox_engine.model import AccentPhrase, AudioQuery, Mora, StyleId
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.model import AccentPhrase, AudioQuery, Mora
from voicevox_engine.tts_pipeline.kana_converter import create_kana


Expand Down
3 changes: 2 additions & 1 deletion test/test_tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import numpy

from voicevox_engine.model import AccentPhrase, AudioQuery, Mora, StyleId
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.model import AccentPhrase, AudioQuery, Mora
from voicevox_engine.tts_pipeline import TTSEngine
from voicevox_engine.tts_pipeline.acoustic_feature_extractor import Phoneme
from voicevox_engine.tts_pipeline.tts_engine import (
Expand Down
3 changes: 2 additions & 1 deletion test/test_tts_engine_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from unittest import TestCase

from voicevox_engine.dev.core.mock import MockCoreWrapper
from voicevox_engine.model import AccentPhrase, Mora, StyleId
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.model import AccentPhrase, Mora
from voicevox_engine.tts_pipeline import TTSEngine
from voicevox_engine.tts_pipeline.tts_engine import (
apply_interrogative_upspeak, # FIXME: この関数を使うテストをTTSEngine用のテストに移動する
Expand Down
3 changes: 2 additions & 1 deletion voicevox_engine/cancellable_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from fastapi import HTTPException, Request

from .core_initializer import initialize_cores
from .model import AudioQuery, StyleId
from .metas.Metas import StyleId
from .model import AudioQuery
from .tts_pipeline import make_tts_engines_from_cores
from .utility import get_latest_core_version

Expand Down
2 changes: 1 addition & 1 deletion voicevox_engine/core_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numpy import ndarray

from .core_wrapper import CoreWrapper, OldCoreError
from .model import StyleId
from .metas.Metas import StyleId


class CoreAdapter:
Expand Down
3 changes: 2 additions & 1 deletion voicevox_engine/dev/tts_engine/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from pyopenjtalk import tts
from soxr import resample

from ...model import AudioQuery, StyleId
from ...metas.Metas import StyleId
from ...model import AudioQuery
from ...tts_pipeline import TTSEngine
from ...tts_pipeline.tts_engine import to_flatten_moras
from ..core.mock import MockCoreWrapper
Expand Down
10 changes: 7 additions & 3 deletions voicevox_engine/metas/Metas.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from enum import Enum
from typing import List, Optional
from typing import List, NewType, Optional

from pydantic import BaseModel, Field

# NOTE: 循環importを防ぐためにとりあえずここに書いている
# FIXME: 他のmodelに依存せず、全modelから参照できる場所に配置する
StyleId = NewType("StyleId", int)


class SpeakerStyle(BaseModel):
"""
話者のスタイル情報
"""

name: str = Field(title="スタイル名")
id: int = Field(title="スタイルID")
id: StyleId = Field(title="スタイルID")


class SpeakerSupportPermittedSynthesisMorphing(str, Enum):
Expand Down Expand Up @@ -67,7 +71,7 @@ class StyleInfo(BaseModel):
スタイルの追加情報
"""

id: int = Field(title="スタイルID")
id: StyleId = Field(title="スタイルID")
icon: str = Field(title="当該スタイルのアイコンをbase64エンコードしたもの")
portrait: Optional[str] = Field(title="当該スタイルのportrait.pngをbase64エンコードしたもの")
voice_samples: List[str] = Field(title="voice_sampleのwavファイルをbase64エンコードしたもの")
Expand Down
5 changes: 1 addition & 4 deletions voicevox_engine/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from re import findall, fullmatch
from typing import Any, Dict, List, NewType, Optional
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, StrictStr, validator

Expand Down Expand Up @@ -45,9 +45,6 @@ def __hash__(self):
return hash(tuple(sorted(items)))


StyleId = NewType("StyleId", int)


class AudioQuery(BaseModel):
"""
音声合成用のクエリ
Expand Down
11 changes: 8 additions & 3 deletions voicevox_engine/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@
from soxr import resample

from .core_adapter import CoreAdapter
from .metas.Metas import Speaker, SpeakerStyle, SpeakerSupportPermittedSynthesisMorphing
from .metas.Metas import (
Speaker,
SpeakerStyle,
SpeakerSupportPermittedSynthesisMorphing,
StyleId,
)
from .metas.MetasStore import construct_lookup
from .model import AudioQuery, MorphableTargetInfo, StyleId, StyleIdNotFoundError
from .model import AudioQuery, MorphableTargetInfo, StyleIdNotFoundError
from .tts_pipeline import TTSEngine


Expand Down Expand Up @@ -52,7 +57,7 @@ def create_morphing_parameter(
def get_morphable_targets(
speakers: List[Speaker],
base_speakers: List[int],
) -> List[Dict[int, MorphableTargetInfo]]:
) -> List[Dict[StyleId, MorphableTargetInfo]]:
"""
speakers: 全話者の情報
base_speakers: モーフィング可能か判定したいベースの話者リスト(スタイルID)
Expand Down
4 changes: 3 additions & 1 deletion voicevox_engine/preset/Preset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pydantic import BaseModel, Field

from voicevox_engine.metas.Metas import StyleId


class Preset(BaseModel):
"""
Expand All @@ -9,7 +11,7 @@ class Preset(BaseModel):
id: int = Field(title="プリセットID")
name: str = Field(title="プリセット名")
speaker_uuid: str = Field(title="話者のUUID")
style_id: int = Field(title="スタイルID")
style_id: StyleId = Field(title="スタイルID")
speedScale: float = Field(title="全体の話速")
pitchScale: float = Field(title="全体の音高")
intonationScale: float = Field(title="全体の抑揚")
Expand Down
3 changes: 2 additions & 1 deletion voicevox_engine/tts_pipeline/tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from ..core_adapter import CoreAdapter
from ..core_wrapper import CoreWrapper
from ..model import AccentPhrase, AudioQuery, Mora, StyleId
from ..metas.Metas import StyleId
from ..model import AccentPhrase, AudioQuery, Mora
from .acoustic_feature_extractor import Phoneme
from .mora_list import openjtalk_mora2text
from .text_analyzer import text_to_accent_phrases
Expand Down

0 comments on commit 938f8ae

Please sign in to comment.