Skip to content

Commit

Permalink
整理: 合成系のコア直接呼び出しを TTSEngine へ移動 (#1420)
Browse files Browse the repository at this point in the history
* refactor: `default_sampling_rate` を TTSEngine へ移動

* refactor: 初期化をリネーム

* refactor: 初期化に関する docstring を明確化

* refactor: `.supported_devices` を `TTSEngine` へ移動

* refactor: `.get_core()` エラーのチェックに用いる API を変更

* fix: `default_sampling_rate` の移動に追従

* fix: lint

* refactor: サンプリングレートに関する docstring を追加
  • Loading branch information
tarepan authored Jun 25, 2024
1 parent 4965657 commit 75314de
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 20 deletions.
2 changes: 1 addition & 1 deletion test/e2e/test_missing_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@

def test_missing_core_422(client: TestClient, snapshot_json: SnapshotAssertion) -> None:
"""存在しないコアを指定するとエラーを返す。"""
response = client.get("/supported_devices", params={"core_version": "4.0.4"})
response = client.get("/speakers", params={"core_version": "4.0.4"})
assert response.status_code == 422
assert snapshot_json == response.json()
4 changes: 3 additions & 1 deletion voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def generate_app(
generate_library_router(library_manager, verify_mutability_allowed)
)
app.include_router(generate_user_dict_router(user_dict, verify_mutability_allowed))
app.include_router(generate_engine_info_router(core_manager, engine_manifest))
app.include_router(
generate_engine_info_router(core_manager, tts_engines, engine_manifest)
)
app.include_router(
generate_setting_router(
setting_loader, engine_manifest.brand_name, verify_mutability_allowed
Expand Down
7 changes: 5 additions & 2 deletions voicevox_engine/app/routers/engine_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from voicevox_engine.core.core_adapter import DeviceSupport
from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.engine_manifest import EngineManifest
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager


class SupportedDevicesInfo(BaseModel):
Expand All @@ -32,7 +33,9 @@ def generate_from(cls, device_support: DeviceSupport) -> Self:


def generate_engine_info_router(
core_manager: CoreManager, engine_manifest_data: EngineManifest
core_manager: CoreManager,
tts_engine_manager: TTSEngineManager,
engine_manifest_data: EngineManifest,
) -> APIRouter:
"""エンジン情報 API Router を生成する"""
router = APIRouter(tags=["その他"])
Expand All @@ -53,7 +56,7 @@ def supported_devices(
) -> SupportedDevicesInfo:
"""対応デバイスの一覧を取得します。"""
version = core_version or core_manager.latest_version()
supported_devices = core_manager.get_core(version).supported_devices
supported_devices = tts_engine_manager.get_engine(version).supported_devices
if supported_devices is None:
raise HTTPException(status_code=422, detail="非対応の機能です。")
return SupportedDevicesInfo.generate_from(supported_devices)
Expand Down
1 change: 0 additions & 1 deletion voicevox_engine/app/routers/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def _synthesis_morphing(
# 生成したパラメータはキャッシュされる
morph_param = synthesis_morphing_parameter(
engine=engine,
core=core,
query=query,
base_style_id=base_style_id,
target_style_id=target_style_id,
Expand Down
17 changes: 7 additions & 10 deletions voicevox_engine/app/routers/tts_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def audio_query(
"""
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
accent_phrases = engine.create_accent_phrases(text, style_id)
return AudioQuery(
accent_phrases=accent_phrases,
Expand All @@ -99,7 +98,7 @@ def audio_query(
postPhonemeLength=0.1,
pauseLength=None,
pauseLengthScale=1,
outputSamplingRate=core.default_sampling_rate,
outputSamplingRate=engine.default_sampling_rate,
outputStereo=False,
kana=create_kana(accent_phrases),
)
Expand All @@ -119,7 +118,6 @@ def audio_query_from_preset(
"""
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
try:
presets = preset_manager.load_presets()
except PresetInputError as err:
Expand All @@ -146,7 +144,7 @@ def audio_query_from_preset(
postPhonemeLength=selected_preset.postPhonemeLength,
pauseLength=selected_preset.pauseLength,
pauseLengthScale=selected_preset.pauseLengthScale,
outputSamplingRate=core.default_sampling_rate,
outputSamplingRate=engine.default_sampling_rate,
outputStereo=False,
kana=create_kana(accent_phrases),
)
Expand Down Expand Up @@ -378,7 +376,6 @@ def sing_frame_audio_query(
"""
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
try:
phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume(
score, style_id
Expand All @@ -391,7 +388,7 @@ def sing_frame_audio_query(
volume=volume,
phonemes=phonemes,
volumeScale=1,
outputSamplingRate=core.default_sampling_rate,
outputSamplingRate=engine.default_sampling_rate,
outputStereo=False,
)

Expand Down Expand Up @@ -532,8 +529,8 @@ def initialize_speaker(
実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。
"""
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
core.initialize_style_id_synthesis(style_id, skip_reinit=skip_reinit)
engine = tts_engines.get_engine(version)
engine.initialize_synthesis(style_id, skip_reinit=skip_reinit)

@router.get("/is_initialized_speaker", tags=["その他"])
def is_initialized_speaker(
Expand All @@ -544,7 +541,7 @@ def is_initialized_speaker(
指定されたスタイルが初期化されているかどうかを返します。
"""
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
return core.is_initialized_style_id_synthesis(style_id)
engine = tts_engines.get_engine(version)
return engine.is_synthesis_initialized(style_id)

return router
4 changes: 1 addition & 3 deletions voicevox_engine/morphing/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from voicevox_engine.morphing.model import MorphableTargetInfo

from ..core.core_adapter import CoreAdapter
from ..metas.Metas import Speaker, StyleId
from ..model import AudioQuery
from ..tts_pipeline.tts_engine import TTSEngine
Expand Down Expand Up @@ -98,15 +97,14 @@ def is_morphable(

def synthesis_morphing_parameter(
engine: TTSEngine,
core: CoreAdapter,
query: AudioQuery,
base_style_id: StyleId,
target_style_id: StyleId,
) -> _MorphingParameter:
query = deepcopy(query)

# 不具合回避のためデフォルトのサンプリングレートでWORLDに掛けた後に指定のサンプリングレートに変換する
query.outputSamplingRate = core.default_sampling_rate
query.outputSamplingRate = engine.default_sampling_rate

# WORLDに掛けるため合成はモノラルで行う
query.outputStereo = False
Expand Down
21 changes: 19 additions & 2 deletions voicevox_engine/tts_pipeline/tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from numpy.typing import NDArray
from soxr import resample

from ..core.core_adapter import CoreAdapter
from ..core.core_adapter import CoreAdapter, DeviceSupport
from ..core.core_initializer import CoreManager
from ..core.core_wrapper import CoreWrapper
from ..metas.Metas import StyleId
Expand Down Expand Up @@ -438,7 +438,16 @@ class TTSEngine:
def __init__(self, core: CoreWrapper):
super().__init__()
self._core = CoreAdapter(core)
# NOTE: self._coreは将来的に消す予定

@property
def default_sampling_rate(self) -> int:
"""合成される音声波形のデフォルトサンプリングレートを取得する。"""
return self._core.default_sampling_rate

@property
def supported_devices(self) -> DeviceSupport | None:
"""合成時に各デバイスが利用可能か否かの一覧を取得する。"""
return self._core.supported_devices

def update_length(
self, accent_phrases: list[AccentPhrase], style_id: StyleId
Expand Down Expand Up @@ -574,6 +583,14 @@ def synthesize_wave(
wave = raw_wave_to_output_wave(query, raw_wave, sr_raw_wave)
return wave

def initialize_synthesis(self, style_id: StyleId, skip_reinit: bool) -> None:
"""指定されたスタイル ID に関する合成機能を初期化する。既に初期化されていた場合は引数に応じて再初期化する。"""
self._core.initialize_style_id_synthesis(style_id, skip_reinit=skip_reinit)

def is_synthesis_initialized(self, style_id: StyleId) -> bool:
"""指定されたスタイル ID に関する合成機能が初期化済みか否かを取得する。"""
return self._core.is_initialized_style_id_synthesis(style_id)

# FIXME: sing用のエンジンに移すかクラス名変える
# 返す値の総称を考え、関数名を変更する
def create_sing_phoneme_and_f0_and_volume(
Expand Down

0 comments on commit 75314de

Please sign in to comment.