Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: エンジンが見つからない例外を追加 #1309

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions test/e2e/test_missing_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""エンジン取得失敗のテスト"""

from fastapi.testclient import TestClient
from syrupy.assertion import SnapshotAssertion


def test_missing_engine_422(
client: TestClient, snapshot_json: SnapshotAssertion
) -> None:
"""存在しないエンジンを指定するとエラーを返す。"""
response = client.post(
"/audio_query", params={"text": "あ", "speaker": 1, "core_version": "4.0.4"}
)
assert response.status_code == 422
assert snapshot_json == response.json()
8 changes: 4 additions & 4 deletions test/tts_pipeline/test_tts_engines.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
""" `TTSEngineManager` クラスのテスト"""

import pytest
from fastapi import HTTPException

from voicevox_engine.dev.tts_engine.mock import MockTTSEngine
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.tts_pipeline.tts_engine import EngineNotFound, TTSEngineManager


def test_tts_engines_register_engine() -> None:
Expand Down Expand Up @@ -88,9 +87,10 @@ def test_tts_engines_get_engine_missing() -> None:
tts_engine2 = MockTTSEngine()
tts_engines.register_engine(tts_engine1, "0.0.1")
tts_engines.register_engine(tts_engine2, "0.0.2")

# Expects
true_message = "バージョン 0.0.3 のエンジンが見つかりません"
# Test
with pytest.raises(HTTPException) as _:
with pytest.raises(EngineNotFound, match=true_message):
tts_engines.get_engine("0.0.3")


Expand Down
2 changes: 2 additions & 0 deletions voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from voicevox_engine import __version__
from voicevox_engine.app.dependencies import deprecated_mutable_api
from voicevox_engine.app.global_exceptions import register_global_exception_handlers
from voicevox_engine.app.middlewares import configure_middlewares
from voicevox_engine.app.openapi_schema import configure_openapi_schema
from voicevox_engine.app.routers.engine_info import generate_engine_info_router
Expand Down Expand Up @@ -52,6 +53,7 @@ def generate_app(
version=__version__,
)
app = configure_middlewares(app, cors_policy_mode, allow_origin)
app = register_global_exception_handlers(app)

if disable_mutable_api:
deprecated_mutable_api.enable = False
Expand Down
19 changes: 19 additions & 0 deletions voicevox_engine/app/global_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""グローバルな例外ハンドラの定義と登録"""

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from voicevox_engine.tts_pipeline.tts_engine import EngineNotFound


def register_global_exception_handlers(app: FastAPI) -> FastAPI:
"""グローバルな例外ハンドラを app へ登録する。"""

# エンジンは複数 router 内で呼ばれるためグローバルなハンドラが相応しい
@app.exception_handler(EngineNotFound)
async def enf_exception_handler(
request: Request, e: EngineNotFound
) -> JSONResponse:
return JSONResponse(status_code=422, content={"message": f"{str(e)}"})

return app
8 changes: 7 additions & 1 deletion voicevox_engine/tts_pipeline/tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,12 @@ def frame_synthsize_wave(
return wave


class EngineNotFound(Exception):
"""エンジンが見つからないエラー"""

pass


class TTSEngineManager:
"""TTS エンジンの集まりを一括管理するマネージャー"""

Expand All @@ -709,7 +715,7 @@ def get_engine(self, version: str | None = None) -> TTSEngine:
elif version in self._engines:
return self._engines[version]

raise HTTPException(status_code=422, detail="不明なバージョンです")
raise EngineNotFound(f"バージョン {version} のエンジンが見つかりません")
tarepan marked this conversation as resolved.
Show resolved Hide resolved

def has_engine(self, version: str) -> bool:
"""指定バージョンのエンジンが登録されているか否かを返す。"""
Expand Down