Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: エンジンが見つからない例外を追加 #1309

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions test/e2e/test_missing_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""エンジン取得失敗のテスト"""

from fastapi.testclient import TestClient
from syrupy.assertion import SnapshotAssertion


def test_missing_tts_engine_422(
client: TestClient, snapshot_json: SnapshotAssertion
) -> None:
"""存在しないコアを指定するとエラーを返す。"""
response = client.post(
"/audio_query", params={"text": "あ", "speaker": 1, "core_version": "4.0.4"}
)
assert response.status_code == 422
assert snapshot_json == response.json()
6 changes: 2 additions & 4 deletions test/unit/tts_pipeline/test_tts_engines.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
""" `TTSEngineManager` クラスのテスト"""

import pytest
from fastapi import HTTPException

from voicevox_engine.dev.tts_engine.mock import MockTTSEngine
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager, TTSEngineNotFound


def test_tts_engines_register_engine() -> None:
Expand Down Expand Up @@ -56,9 +55,8 @@ def test_tts_engines_get_engine_missing() -> None:
tts_engine2 = MockTTSEngine()
tts_engines.register_engine(tts_engine1, "0.0.1")
tts_engines.register_engine(tts_engine2, "0.0.2")

# Test
with pytest.raises(HTTPException) as _:
with pytest.raises(TTSEngineNotFound):
tts_engines.get_engine("0.0.3")


Expand Down
20 changes: 20 additions & 0 deletions voicevox_engine/app/global_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from fastapi.responses import JSONResponse

from voicevox_engine.core.core_initializer import CoreNotFound
from voicevox_engine.tts_pipeline.tts_engine import (
MockTTSEngineNotFound,
TTSEngineNotFound,
)


def configure_global_exception_handlers(app: FastAPI) -> FastAPI:
Expand All @@ -14,4 +18,20 @@ def configure_global_exception_handlers(app: FastAPI) -> FastAPI:
async def cnf_exception_handler(request: Request, e: CoreNotFound) -> JSONResponse:
return JSONResponse(status_code=422, content={"message": f"{str(e)}"})

# 指定されたエンジンが見つからないエラー
@app.exception_handler(TTSEngineNotFound)
async def no_engine_exception_handler(
request: Request, e: TTSEngineNotFound
) -> JSONResponse:
msg = f"バージョン {e.version} のコアが見つかりません。"
return JSONResponse(status_code=422, content={"message": msg})

# 指定されたモック版エンジンが見つからないエラー
@app.exception_handler(MockTTSEngineNotFound)
async def no_mock_exception_handler(
request: Request, e: MockTTSEngineNotFound
) -> JSONResponse:
msg = "モックが見つかりません。エンジンの起動引数 `--enable_mock` を確認してください。"
return JSONResponse(status_code=422, content={"message": msg})

return app
24 changes: 20 additions & 4 deletions voicevox_engine/tts_pipeline/tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import copy
import math
from typing import Any

import numpy as np
from fastapi import HTTPException
from numpy.typing import NDArray
from soxr import resample

from ..core.core_adapter import CoreAdapter
from ..core.core_initializer import CoreManager
from ..core.core_initializer import MOCK_VER, CoreManager
from ..core.core_wrapper import CoreWrapper
from ..metas.Metas import StyleId
from ..model import AudioQuery
Expand Down Expand Up @@ -691,6 +691,19 @@ def frame_synthsize_wave(
return wave


class TTSEngineNotFound(Exception):
"""TTSEngine が見つからないエラー"""

def __init__(self, *args: list[Any], version: str, **kwargs: dict[str, Any]):
"""TTSEngine のバージョン番号を用いてインスタンス化する。"""
super().__init__(*args, **kwargs)
self.version = version


class MockTTSEngineNotFound(Exception):
"""モック TTSEngine が見つからないエラー"""


class TTSEngineManager:
"""TTS エンジンの集まりを一括管理するマネージャー"""

Expand All @@ -706,11 +719,14 @@ def register_engine(self, engine: TTSEngine, version: str) -> None:
self._engines[version] = engine

def get_engine(self, version: str) -> TTSEngine:
"""指定バージョンのエンジンを取得する"""
"""指定バージョンのエンジンを取得する"""
if version in self._engines:
return self._engines[version]

raise HTTPException(status_code=422, detail="不明なバージョンです")
if version == MOCK_VER:
raise MockTTSEngineNotFound()
else:
raise TTSEngineNotFound(version=version)

def has_engine(self, version: str) -> bool:
"""指定バージョンのエンジンが登録されているか否かを返す。"""
Expand Down