diff --git a/crates/voicevox_core/src/engine/synthesis_engine.rs b/crates/voicevox_core/src/engine/synthesis_engine.rs
index 8db50b604..fc91005c4 100644
--- a/crates/voicevox_core/src/engine/synthesis_engine.rs
+++ b/crates/voicevox_core/src/engine/synthesis_engine.rs
@@ -633,12 +633,12 @@ mod tests {
     use ::test_util::OPEN_JTALK_DIC_DIR;
     use pretty_assertions::assert_eq;
 
-    use crate::{infer::runtimes::Onnxruntime, *};
+    use crate::{synthesizer::InferenceRuntimeImpl, *};
 
     #[rstest]
     #[tokio::test]
     async fn is_openjtalk_dict_loaded_works() {
-        let core = InferenceCore::<Onnxruntime>::new(false, 0).unwrap();
+        let core = InferenceCore::<InferenceRuntimeImpl>::new(false, 0).unwrap();
         let synthesis_engine =
             SynthesisEngine::new(core, OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap().into());
 
@@ -648,7 +648,7 @@
     #[rstest]
     #[tokio::test]
     async fn create_accent_phrases_works() {
-        let core = InferenceCore::<Onnxruntime>::new(false, 0).unwrap();
+        let core = InferenceCore::<InferenceRuntimeImpl>::new(false, 0).unwrap();
 
         let model = &VoiceModel::sample().await.unwrap();
         core.load_model(model).await.unwrap();
diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs
index 0bc74b6a9..d5a55cfea 100644
--- a/crates/voicevox_core/src/infer.rs
+++ b/crates/voicevox_core/src/infer.rs
@@ -2,7 +2,7 @@ mod model_file;
 pub(crate) mod runtimes;
 pub(crate) mod signatures;
 
-use std::{fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc};
+use std::{fmt::Debug, marker::PhantomData, sync::Arc};
 
 use derive_new::new;
 use enum_map::{Enum, EnumMap};
@@ -11,9 +11,9 @@
 use thiserror::Error;
 
 use crate::{ErrorRepr, SupportedDevices};
 
-pub(crate) trait InferenceRuntime: Copy + Ord + Hash + Debug + 'static {
+pub(crate) trait InferenceRuntime: 'static {
     type Session: Session;
-    type RunBuilder<'a>: RunBuilder<'a, Runtime = Self>;
+    type RunBuilder<'a>: RunBuilder<'a, Session = Self::Session>;
     fn supported_devices() -> crate::Result<SupportedDevices>;
 }
@@ -24,10 +24,8 @@ pub(crate) trait Session: Sized + Send + 'static {
     ) -> anyhow::Result<Self>;
 }
 
-pub(crate) trait RunBuilder<'a>:
-    From<&'a mut <Self::Runtime as InferenceRuntime>::Session>
-{
-    type Runtime: InferenceRuntime;
+pub(crate) trait RunBuilder<'a>: From<&'a mut Self::Session> {
+    type Session: Session;
     fn input(&mut self, tensor: Array<impl InputScalar, impl Dimension>) -> &mut Self;
 }
 
@@ -36,7 +34,7 @@ pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::OnnxruntimeInputSca
 impl InputScalar for i64 {}
 impl InputScalar for f32 {}
 
-pub(crate) trait Signature: Sized + Send + Sync + 'static {
+pub(crate) trait Signature: Sized + Send + 'static {
     type Kind: Enum;
     type Output;
     const KIND: Self::Kind;
diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
index 636efd91a..ebaa2bcda 100644
--- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
+++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
@@ -107,7 +107,7 @@ impl<'sess> From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>>
 }
 
 impl<'sess> RunBuilder<'sess> for OnnxruntimeInferenceBuilder<'sess> {
-    type Runtime = Onnxruntime;
+    type Session = AssertSend<onnxruntime::session::Session<'static>>;
 
     fn input(&mut self, tensor: Array<impl InputScalar, impl Dimension>) -> &mut Self {
         self.inputs
diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs
index 96f5dcc0a..a4cd4ee84 100644
--- a/crates/voicevox_core/src/status.rs
+++ b/crates/voicevox_core/src/status.rs
@@ -224,8 +224,8 @@ impl LoadedModels {
 mod tests {
     use super::*;
 
-    use crate::infer::runtimes::Onnxruntime;
     use crate::macros::tests::assert_debug_fmt_eq;
+    use crate::synthesizer::InferenceRuntimeImpl;
     use pretty_assertions::assert_eq;
 
     #[rstest]
@@ -237,7 +237,7 @@
     #[case(false, 8)]
     #[case(false, 0)]
     fn status_new_works(#[case] use_gpu: bool, #[case] cpu_num_threads: u16) {
-        let status = Status::<Onnxruntime>::new(use_gpu, cpu_num_threads);
+        let status = Status::<InferenceRuntimeImpl>::new(use_gpu, cpu_num_threads);
         assert_eq!(false, status.light_session_options.use_gpu);
         assert_eq!(use_gpu, status.heavy_session_options.use_gpu);
         assert_eq!(
@@ -254,7 +254,7 @@
     #[rstest]
     #[tokio::test]
     async fn status_load_model_works() {
-        let status = Status::<Onnxruntime>::new(false, 0);
+        let status = Status::<InferenceRuntimeImpl>::new(false, 0);
         let result = status.load_model(&open_default_vvm_file().await).await;
         assert_debug_fmt_eq!(Ok(()), result);
         assert_eq!(1, status.loaded_models.lock().unwrap().0.len());
@@ -263,7 +263,7 @@
     #[rstest]
     #[tokio::test]
     async fn status_is_model_loaded_works() {
-        let status = Status::<Onnxruntime>::new(false, 0);
+        let status = Status::<InferenceRuntimeImpl>::new(false, 0);
         let vvm = open_default_vvm_file().await;
         assert!(
             !status.is_loaded_model(vvm.id()),