From 92c4a36f4bf78b4e2364df848938fe36395ac9bc Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 24 Feb 2024 00:06:36 +0900 Subject: [PATCH] =?UTF-8?q?C=20API=E3=81=A8Python=20API=E3=81=AE=E4=B8=8D?= =?UTF-8?q?=E5=BF=85=E8=A6=81=E3=81=AAUTF-8=E3=81=AE=E8=A6=81=E6=B1=82?= =?UTF-8?q?=E3=82=92=E5=A4=96=E3=81=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core_c_api/src/lib.rs | 6 ++-- crates/voicevox_core_python_api/src/lib.rs | 32 ++++++++++++---------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/crates/voicevox_core_c_api/src/lib.rs b/crates/voicevox_core_c_api/src/lib.rs index d2946f02b..a5da9b6d3 100644 --- a/crates/voicevox_core_c_api/src/lib.rs +++ b/crates/voicevox_core_c_api/src/lib.rs @@ -441,8 +441,10 @@ pub unsafe extern "C" fn voicevox_synthesizer_is_loaded_voice_model( model_id: VoicevoxVoiceModelId, ) -> bool { init_logger_once(); - // FIXME: 不正なUTF-8文字列に対し、正式なエラーとするか黙って`false`を返す - let raw_model_id = ensure_utf8(unsafe { CStr::from_ptr(model_id) }).unwrap(); + let Ok(raw_model_id) = ensure_utf8(unsafe { CStr::from_ptr(model_id) }) else { + // 与えられたIDがUTF-8ではない場合、それに対応する`VoicdModel`は確実に存在しない + return false; + }; synthesizer .synthesizer() .is_loaded_voice_model(&VoiceModelId::new(raw_model_id.into())) diff --git a/crates/voicevox_core_python_api/src/lib.rs b/crates/voicevox_core_python_api/src/lib.rs index 4cde8d711..ea20066b4 100644 --- a/crates/voicevox_core_python_api/src/lib.rs +++ b/crates/voicevox_core_python_api/src/lib.rs @@ -1,4 +1,4 @@ -use std::{marker::PhantomData, sync::Arc}; +use std::{marker::PhantomData, path::PathBuf, sync::Arc}; mod convert; use self::convert::{ @@ -13,7 +13,7 @@ use pyo3::{ create_exception, exceptions::{PyException, PyKeyError, PyValueError}, pyclass, pyfunction, pymethods, pymodule, - types::{IntoPyDict as _, PyBytes, PyDict, PyList, PyModule}, + types::{IntoPyDict as _, PyBytes, PyDict, PyList, PyModule, PyString}, wrap_pyfunction, PyAny, PyObject, PyRef, PyResult, PyTypeInfo, Python, ToPyObject, }; use uuid::Uuid; @@ -114,10 +114,7 @@ fn supported_devices(py: Python<'_>) -> PyResult<&PyAny> { #[pymethods] impl VoiceModel { #[staticmethod] - fn from_path( - py: Python<'_>, - #[pyo3(from_py_with = "from_utf8_path")] path: Utf8PathBuf, - ) -> PyResult<&PyAny> { + fn from_path(py: Python<'_>, path: PathBuf) -> PyResult<&PyAny> { pyo3_asyncio::tokio::future_into_py(py, async move { let model = voicevox_core::tokio::VoiceModel::from_path(path).await; let model = Python::with_gil(|py| model.into_py_result(py))?; @@ -247,7 +244,12 @@ impl Synthesizer { .into_py_result(py) } - fn is_loaded_voice_model(&self, voice_model_id: &str) -> PyResult { + // C APIの挙動と一貫性を持たせる。 + fn is_loaded_voice_model(&self, voice_model_id: &PyString) -> PyResult { + let Ok(voice_model_id) = voice_model_id.to_str() else { + // 与えられたIDがUTF-8ではない場合、それに対応する`VoicdModel`は確実に存在しない + return Ok(false); + }; Ok(self .synthesizer .get()? @@ -636,12 +638,12 @@ impl UserDict { } mod blocking { - use std::sync::Arc; + use std::{path::PathBuf, sync::Arc}; use camino::Utf8PathBuf; use pyo3::{ pyclass, pymethods, - types::{IntoPyDict as _, PyBytes, PyDict, PyList}, + types::{IntoPyDict as _, PyBytes, PyDict, PyList, PyString}, PyAny, PyObject, PyRef, PyResult, Python, }; use uuid::Uuid; @@ -661,10 +663,7 @@ mod blocking { #[pymethods] impl VoiceModel { #[staticmethod] - fn from_path( - py: Python<'_>, - #[pyo3(from_py_with = "crate::convert::from_utf8_path")] path: Utf8PathBuf, - ) -> PyResult { + fn from_path(py: Python<'_>, path: PathBuf) -> PyResult { let model = voicevox_core::blocking::VoiceModel::from_path(path).into_py_result(py)?; Ok(Self { model }) } @@ -786,7 +785,12 @@ mod blocking { .into_py_result(py) } - fn is_loaded_voice_model(&self, voice_model_id: &str) -> PyResult { + // C APIの挙動と一貫性を持たせる。 + fn is_loaded_voice_model(&self, voice_model_id: &PyString) -> PyResult { + let Ok(voice_model_id) = voice_model_id.to_str() else { + // 与えられたIDがUTF-8ではない場合、それに対応する`VoicdModel`は確実に存在しない + return Ok(false); + }; Ok(self .synthesizer .get()?