Commit 20db67a

Minor refactor
qryxip committed Nov 6, 2023
1 parent 192417f commit 20db67a
Showing 4 changed files with 14 additions and 16 deletions.

crates/voicevox_core/src/engine/synthesis_engine.rs (3 additions & 3 deletions)

@@ -633,12 +633,12 @@ mod tests {
     use ::test_util::OPEN_JTALK_DIC_DIR;
     use pretty_assertions::assert_eq;

-    use crate::{infer::runtimes::Onnxruntime, *};
+    use crate::{synthesizer::InferenceRuntimeImpl, *};

     #[rstest]
     #[tokio::test]
     async fn is_openjtalk_dict_loaded_works() {
-        let core = InferenceCore::<Onnxruntime>::new(false, 0).unwrap();
+        let core = InferenceCore::<InferenceRuntimeImpl>::new(false, 0).unwrap();
         let synthesis_engine =
             SynthesisEngine::new(core, OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap().into());

@@ -648,7 +648,7 @@
     #[rstest]
     #[tokio::test]
     async fn create_accent_phrases_works() {
-        let core = InferenceCore::<Onnxruntime>::new(false, 0).unwrap();
+        let core = InferenceCore::<InferenceRuntimeImpl>::new(false, 0).unwrap();

         let model = &VoiceModel::sample().await.unwrap();
         core.load_model(model).await.unwrap();
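The tests above now obtain the runtime through `InferenceRuntimeImpl` from the `synthesizer` module instead of naming `Onnxruntime` directly, so the backend choice lives in one place. The `synthesizer` side is not part of this diff; a minimal sketch of what such an alias could look like (the exact form is an assumption, not shown in this commit):

```rust
// Hypothetical sketch of crates/voicevox_core/src/synthesizer.rs (not shown in
// this diff): re-export the concrete backend under a runtime-neutral name so
// tests and callers never spell out `Onnxruntime`.
pub(crate) use crate::infer::runtimes::Onnxruntime as InferenceRuntimeImpl;
```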

crates/voicevox_core/src/infer.rs (6 additions & 8 deletions)

@@ -2,7 +2,7 @@ mod model_file;
 pub(crate) mod runtimes;
 pub(crate) mod signatures;

-use std::{fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc};
+use std::{fmt::Debug, marker::PhantomData, sync::Arc};

 use derive_new::new;
 use enum_map::{Enum, EnumMap};
@@ -11,9 +11,9 @@ use thiserror::Error;

 use crate::{ErrorRepr, SupportedDevices};

-pub(crate) trait InferenceRuntime: Copy + Ord + Hash + Debug + 'static {
+pub(crate) trait InferenceRuntime: 'static {
     type Session: Session;
-    type RunBuilder<'a>: RunBuilder<'a, Runtime = Self>;
+    type RunBuilder<'a>: RunBuilder<'a, Session = Self::Session>;
     fn supported_devices() -> crate::Result<SupportedDevices>;
 }

@@ -24,10 +24,8 @@ pub(crate) trait Session: Sized + Send + 'static {
     ) -> anyhow::Result<Self>;
 }

-pub(crate) trait RunBuilder<'a>:
-    From<&'a mut <Self::Runtime as InferenceRuntime>::Session>
-{
-    type Runtime: InferenceRuntime;
+pub(crate) trait RunBuilder<'a>: From<&'a mut Self::Session> {
+    type Session: Session;
     fn input(&mut self, tensor: Array<impl InputScalar, impl Dimension + 'static>) -> &mut Self;
 }

@@ -36,7 +34,7 @@ pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::OnnxruntimeInputSca

 impl InputScalar for i64 {}
 impl InputScalar for f32 {}

-pub(crate) trait Signature: Sized + Send + Sync + 'static {
+pub(crate) trait Signature: Sized + Send + 'static {
     type Kind: Enum;
     type Output;
     const KIND: Self::Kind;
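The heart of the refactor is in these traits. Previously `RunBuilder` located its session by reaching back through a `Runtime` associated type (`From<&'a mut <Self::Runtime as InferenceRuntime>::Session>`), which in turn forced `InferenceRuntime` to carry `Copy + Ord + Hash + Debug` bounds. Now the builder depends only on its `Session`, and the runtime ties the two together with an equality constraint on its generic associated type. A self-contained sketch of the new shape, using hypothetical `Dummy*` placeholders instead of a real backend and omitting the `input`/`supported_devices` methods:

```rust
// Sketch of the refactored trait layout; `DummySession`, `DummyBuilder`, and
// `DummyRuntime` are placeholders for illustration, not part of the crate.
trait Session: Sized + Send + 'static {}

// The builder now names only the session it borrows from...
trait RunBuilder<'a>: From<&'a mut Self::Session> {
    type Session: Session;
}

// ...and the runtime, with the old `Copy + Ord + Hash + Debug` bounds gone,
// just constrains its builder to run over its own session type.
trait InferenceRuntime: 'static {
    type Session: Session;
    type RunBuilder<'a>: RunBuilder<'a, Session = Self::Session>;
}

struct DummySession;
impl Session for DummySession {}

struct DummyBuilder<'a>(&'a mut DummySession);

impl<'a> From<&'a mut DummySession> for DummyBuilder<'a> {
    fn from(sess: &'a mut DummySession) -> Self {
        Self(sess)
    }
}

impl<'a> RunBuilder<'a> for DummyBuilder<'a> {
    type Session = DummySession;
}

// An uninhabited marker type suffices now that nothing requires runtime values.
enum DummyRuntime {}

impl InferenceRuntime for DummyRuntime {
    type Session = DummySession;
    type RunBuilder<'a> = DummyBuilder<'a>;
}

fn main() {
    let mut sess = DummySession;
    // A builder is obtained from a mutable session borrow, exactly as the
    // `From<&mut Session>` supertrait prescribes.
    let _builder: <DummyRuntime as InferenceRuntime>::RunBuilder<'_> = (&mut sess).into();
}
```

The `Sync` removal on `Signature` reads as the same kind of loosening: presumably signatures only ever need to be sent between threads, not shared.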

crates/voicevox_core/src/infer/runtimes/onnxruntime.rs (1 addition & 1 deletion)

@@ -107,7 +107,7 @@ impl<'sess> From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>>
 }

 impl<'sess> RunBuilder<'sess> for OnnxruntimeInferenceBuilder<'sess> {
-    type Runtime = Onnxruntime;
+    type Session = AssertSend<onnxruntime::session::Session<'static>>;

     fn input(&mut self, tensor: Array<impl InputScalar, impl Dimension + 'static>) -> &mut Self {
         self.inputs
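A design note on the hunk above: by naming `type Session = AssertSend<onnxruntime::session::Session<'static>>` directly, this impl and the `From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>>` conversion it builds on agree on one concrete session type without routing through `Onnxruntime` and the `InferenceRuntime` trait, which is what lets the `Runtime` associated type disappear from `RunBuilder`.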

crates/voicevox_core/src/status.rs (4 additions & 4 deletions)

@@ -224,8 +224,8 @@ impl<R: InferenceRuntime> LoadedModels<R> {
 mod tests {

     use super::*;
-    use crate::infer::runtimes::Onnxruntime;
     use crate::macros::tests::assert_debug_fmt_eq;
+    use crate::synthesizer::InferenceRuntimeImpl;
     use pretty_assertions::assert_eq;

     #[rstest]
@@ -237,7 +237,7 @@
     #[case(false, 8)]
     #[case(false, 0)]
     fn status_new_works(#[case] use_gpu: bool, #[case] cpu_num_threads: u16) {
-        let status = Status::<Onnxruntime>::new(use_gpu, cpu_num_threads);
+        let status = Status::<InferenceRuntimeImpl>::new(use_gpu, cpu_num_threads);
         assert_eq!(false, status.light_session_options.use_gpu);
         assert_eq!(use_gpu, status.heavy_session_options.use_gpu);
         assert_eq!(
@@ -254,7 +254,7 @@
     #[rstest]
     #[tokio::test]
     async fn status_load_model_works() {
-        let status = Status::<Onnxruntime>::new(false, 0);
+        let status = Status::<InferenceRuntimeImpl>::new(false, 0);
         let result = status.load_model(&open_default_vvm_file().await).await;
         assert_debug_fmt_eq!(Ok(()), result);
         assert_eq!(1, status.loaded_models.lock().unwrap().0.len());
@@ -263,7 +263,7 @@
     #[rstest]
     #[tokio::test]
     async fn status_is_model_loaded_works() {
-        let status = Status::<Onnxruntime>::new(false, 0);
+        let status = Status::<InferenceRuntimeImpl>::new(false, 0);
         let vvm = open_default_vvm_file().await;
         assert!(
             !status.is_loaded_model(vvm.id()),
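`Status` takes the runtime as a type parameter (the same `R: InferenceRuntime` seen on `LoadedModels` in the hunk header above), so these tests change only which concrete runtime they name; the assertions themselves are untouched.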
