Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: エンジンが見つからない例外を追加 #1309

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions test/e2e/test_missing_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""エンジン取得失敗のテスト"""

from fastapi.testclient import TestClient
from syrupy.assertion import SnapshotAssertion


def test_missing_tts_engine_422(
client: TestClient, snapshot_json: SnapshotAssertion
) -> None:
"""存在しないコアを指定するとエラーを返す。"""
response = client.post(
"/audio_query", params={"text": "あ", "speaker": 1, "core_version": "4.0.4"}
)
assert response.status_code == 422
assert snapshot_json == response.json()
6 changes: 2 additions & 4 deletions test/unit/tts_pipeline/test_tts_engines.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
""" `TTSEngineManager` クラスのテスト"""

import pytest
from fastapi import HTTPException

from voicevox_engine.dev.tts_engine.mock import MockTTSEngine
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager, TTSEngineNotFound


def test_tts_engines_register_engine() -> None:
Expand Down Expand Up @@ -88,9 +87,8 @@ def test_tts_engines_get_engine_missing() -> None:
tts_engine2 = MockTTSEngine()
tts_engines.register_engine(tts_engine1, "0.0.1")
tts_engines.register_engine(tts_engine2, "0.0.2")

# Test
with pytest.raises(HTTPException) as _:
with pytest.raises(TTSEngineNotFound):
tts_engines.get_engine("0.0.3")


Expand Down
17 changes: 16 additions & 1 deletion voicevox_engine/app/global_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from voicevox_engine.core.core_initializer import CoreNotFound
from voicevox_engine.core.core_initializer import MOCK_VER, CoreNotFound
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineNotFound


def configure_global_exception_handlers(app: FastAPI) -> FastAPI:
Expand All @@ -14,4 +15,18 @@ def configure_global_exception_handlers(app: FastAPI) -> FastAPI:
async def cnf_exception_handler(request: Request, e: CoreNotFound) -> JSONResponse:
return JSONResponse(status_code=422, content={"message": f"{str(e)}"})

# 指定されたエンジンが見つからないエラー
@app.exception_handler(TTSEngineNotFound)
async def enf_exception_handler(
request: Request, e: TTSEngineNotFound
) -> JSONResponse:
version = e.version
if version == MOCK_VER:
msg = "コアのモックが見つかりません。エンジンの起動引数 `--enable_mock` を確認してください。"
elif version == "latest":
msg = "コアが1つも見つかりません。"
else:
msg = f"バージョン {version} のコアが見つかりません。"
return JSONResponse(status_code=422, content={"message": msg})

tarepan marked this conversation as resolved.
Show resolved Hide resolved
return app
14 changes: 12 additions & 2 deletions voicevox_engine/tts_pipeline/tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import copy
import math
from typing import Any

import numpy as np
from fastapi import HTTPException
from numpy.typing import NDArray
from soxr import resample

Expand Down Expand Up @@ -673,6 +673,15 @@ def frame_synthsize_wave(
return wave


class TTSEngineNotFound(Exception):
"""TTSEngine が見つからないエラー"""

def __init__(self, *args: list[Any], version: str, **kwargs: dict[str, Any]):
"""TTSEngine のバージョン番号を用いてインスタンス化する。"""
super().__init__(*args, **kwargs)
self.version = version


class TTSEngineManager:
"""TTS エンジンの集まりを一括管理するマネージャー"""

Expand All @@ -698,7 +707,8 @@ def get_engine(self, version: str | None = None) -> TTSEngine:
elif version in self._engines:
return self._engines[version]

raise HTTPException(status_code=422, detail="不明なバージョンです")
# TODO: `version` が None 受け入れを辞めたタイミングで falsy を削除する
raise TTSEngineNotFound(version=version or "latest")

def has_engine(self, version: str) -> bool:
"""指定バージョンのエンジンが登録されているか否かを返す。"""
Expand Down