Commit 33245ff

R: Propagate `R` up to `InferenceCore` and `SynthesisEngine`

qryxip committed Nov 5, 2023
1 parent 5ff2b59 commit 33245ff
Showing 8 changed files with 96 additions and 69 deletions.
26 changes: 26 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions crates/voicevox_core/Cargo.toml
```diff
@@ -17,6 +17,7 @@ derive-new = "0.5.9"
 derive_more.workspace = true
 duplicate = "1.0.0"
 easy-ext.workspace = true
+educe = "0.4.23"
 fs-err.workspace = true
 futures.workspace = true
 indexmap = { version = "2.0.0", features = ["serde"] }
```

52 changes: 19 additions & 33 deletions crates/voicevox_core/src/engine/synthesis_engine.rs
```diff
@@ -5,6 +5,7 @@ use std::sync::Arc;
 use super::full_context_label::Utterance;
 use super::open_jtalk::OpenJtalk;
 use super::*;
+use crate::infer::{InferenceRuntime, Output};
 use crate::numerics::F32Ext as _;
 use crate::InferenceCore;
 
@@ -15,18 +16,19 @@ const MORA_PHONEME_LIST: &[&str] = &[
 ];
 
 #[derive(new)]
-pub struct SynthesisEngine {
-    inference_core: InferenceCore,
+pub(crate) struct SynthesisEngine<R: InferenceRuntime> {
+    inference_core: InferenceCore<R>,
     open_jtalk: Arc<OpenJtalk>,
 }
 
-#[allow(unsafe_code)]
-unsafe impl Send for SynthesisEngine {}
-
-impl SynthesisEngine {
+impl<R> SynthesisEngine<R>
+where
+    R: InferenceRuntime,
+    (Vec<f32>,): Output<R>,
+{
     pub const DEFAULT_SAMPLING_RATE: u32 = 24000;
 
-    pub fn inference_core(&self) -> &InferenceCore {
+    pub fn inference_core(&self) -> &InferenceCore<R> {
         &self.inference_core
     }
 
@@ -123,7 +125,7 @@ impl SynthesisEngine {
         accent_phrases: &[AccentPhraseModel],
         style_id: StyleId,
     ) -> Result<Vec<AccentPhraseModel>> {
-        let (_, phoneme_data_list) = SynthesisEngine::initial_process(accent_phrases);
+        let (_, phoneme_data_list) = Self::initial_process(accent_phrases);
 
         let (_, _, vowel_indexes_data) = split_mora(&phoneme_data_list);
 
@@ -185,36 +187,20 @@
         accent_phrases: &[AccentPhraseModel],
         style_id: StyleId,
     ) -> Result<Vec<AccentPhraseModel>> {
-        let (_, phoneme_data_list) = SynthesisEngine::initial_process(accent_phrases);
+        let (_, phoneme_data_list) = Self::initial_process(accent_phrases);
 
         let mut base_start_accent_list = vec![0];
         let mut base_end_accent_list = vec![0];
         let mut base_start_accent_phrase_list = vec![0];
         let mut base_end_accent_phrase_list = vec![0];
         for accent_phrase in accent_phrases {
             let mut accent = usize::from(*accent_phrase.accent() != 1);
-            SynthesisEngine::create_one_accent_list(
-                &mut base_start_accent_list,
-                accent_phrase,
-                accent as i32,
-            );
+            Self::create_one_accent_list(&mut base_start_accent_list, accent_phrase, accent as i32);
 
             accent = *accent_phrase.accent() - 1;
-            SynthesisEngine::create_one_accent_list(
-                &mut base_end_accent_list,
-                accent_phrase,
-                accent as i32,
-            );
-            SynthesisEngine::create_one_accent_list(
-                &mut base_start_accent_phrase_list,
-                accent_phrase,
-                0,
-            );
-            SynthesisEngine::create_one_accent_list(
-                &mut base_end_accent_phrase_list,
-                accent_phrase,
-                -1,
-            );
+            Self::create_one_accent_list(&mut base_end_accent_list, accent_phrase, accent as i32);
+            Self::create_one_accent_list(&mut base_start_accent_phrase_list, accent_phrase, 0);
+            Self::create_one_accent_list(&mut base_end_accent_phrase_list, accent_phrase, -1);
         }
         base_start_accent_list.push(0);
         base_end_accent_list.push(0);
 
@@ -328,7 +314,7 @@
             query.accent_phrases().clone()
         };
 
-        let (flatten_moras, phoneme_data_list) = SynthesisEngine::initial_process(&accent_phrases);
+        let (flatten_moras, phoneme_data_list) = Self::initial_process(&accent_phrases);
 
         let mut phoneme_length_list = vec![pre_phoneme_length];
         let mut f0_list = vec![0.];
 
@@ -647,12 +633,12 @@ mod tests {
     use ::test_util::OPEN_JTALK_DIC_DIR;
     use pretty_assertions::assert_eq;
 
-    use crate::*;
+    use crate::{infer::runtimes::Onnxruntime, *};
 
     #[rstest]
     #[tokio::test]
     async fn is_openjtalk_dict_loaded_works() {
-        let core = InferenceCore::new(false, 0).unwrap();
+        let core = InferenceCore::<Onnxruntime>::new(false, 0).unwrap();
         let synthesis_engine =
             SynthesisEngine::new(core, OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap().into());
 
@@ -662,7 +648,7 @@
     #[rstest]
     #[tokio::test]
     async fn create_accent_phrases_works() {
-        let core = InferenceCore::new(false, 0).unwrap();
+        let core = InferenceCore::<Onnxruntime>::new(false, 0).unwrap();
 
         let model = &VoiceModel::sample().await.unwrap();
         core.load_model(model).await.unwrap();
```

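Note (commentary, not part of the commit): the handwritten `unsafe impl Send for SynthesisEngine` can be dropped because `Session` now requires `Send` (see `infer.rs` below), which appears to make every field of the engine `Send`, so the compiler derives the auto trait on its own. A minimal sketch of the mechanism, using stand-in types:

```rust
// Stand-in types, not the crate's real ones: assume sessions are now
// `Send` by construction.
use std::sync::Arc;

struct Session;

struct Engine {
    session: Session,
    shared: Arc<Session>, // `Arc<T>: Send` requires `T: Send + Sync`
}

// No `unsafe impl Send for Engine` is needed: `Send` is an auto trait,
// so it holds automatically once every field is `Send`.
fn assert_send<T: Send>() {}

fn main() {
    assert_send::<Engine>(); // compiles only if `Engine: Send`
}
```
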
6 changes: 3 additions & 3 deletions crates/voicevox_core/src/infer.rs
```diff
@@ -1,18 +1,18 @@
 pub(crate) mod runtimes;
 pub(crate) mod signatures;
 
-use std::{fmt::Debug, marker::PhantomData, sync::Arc};
+use std::{fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc};
 
 use derive_new::new;
 use ndarray::{Array, Dimension, LinalgScalar};
 use thiserror::Error;
 
-pub(crate) trait InferenceRuntime: Copy {
+pub(crate) trait InferenceRuntime: Copy + Ord + Hash + Debug + 'static {
     type Session: Session;
     type RunBuilder<'a>: RunBuilder<'a, Runtime = Self>;
 }
 
-pub(crate) trait Session: Sized + 'static {
+pub(crate) trait Session: Sized + Send + 'static {
     fn new(
         model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
         options: SessionOptions,
```

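Note: the new `Send` supertrait on `Session` lines up with `status.rs`, which runs inference inside `tokio::task::spawn_blocking` and thereby moves the session to a worker thread. A minimal sketch of why the bound is required (`MySession` is hypothetical, not the crate's API; assumes `tokio = { version = "1", features = ["rt-multi-thread", "macros"] }`):

```rust
// `spawn_blocking` moves its closure to another thread, so everything the
// closure captures must be `Send`. This would fail to compile if
// `MySession` were not `Send`.
use std::sync::{Arc, Mutex};

struct MySession; // stand-in for an inference session

#[tokio::main]
async fn main() {
    let sess = Arc::new(Mutex::new(MySession));
    tokio::task::spawn_blocking(move || {
        let _sess = sess.lock().unwrap();
        // ... run inference with `_sess` here ...
    })
    .await
    .unwrap();
}
```
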
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
```diff
@@ -8,7 +8,7 @@ use crate::infer::{
 
 pub(crate) use self::assert_send::AssertSend;
 
-#[derive(Clone, Copy)]
+#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
 pub(crate) enum Onnxruntime {}
 
 impl InferenceRuntime for Onnxruntime {
```

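Note: since `Onnxruntime` is an uninhabited enum, the extra derives are satisfied trivially; there are no values to clone, compare, or hash, so they exist purely to meet the strengthened `InferenceRuntime` supertraits. A sketch with an illustrative `Marker` type:

```rust
// An uninhabited enum is a pure type-level marker: it can never be
// constructed, so value-independent derives come "for free".
use std::fmt::Debug;
use std::hash::Hash;

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
enum Marker {}

// Mimics the strengthened `InferenceRuntime` supertraits.
fn requires_runtime_like<R: Copy + Ord + Hash + Debug + 'static>() {}

fn main() {
    requires_runtime_like::<Marker>(); // the marker satisfies every bound
}
```
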
15 changes: 11 additions & 4 deletions crates/voicevox_core/src/inference_core.rs
```diff
@@ -1,14 +1,21 @@
 use self::status::*;
 use super::*;
-use crate::infer::signatures::{Decode, PredictDuration, PredictIntonation};
+use crate::infer::{
+    signatures::{Decode, PredictDuration, PredictIntonation},
+    InferenceRuntime, Output,
+};
 
 const PHONEME_LENGTH_MINIMAL: f32 = 0.01;
 
-pub struct InferenceCore {
-    status: Status,
+pub(crate) struct InferenceCore<R: InferenceRuntime> {
+    status: Status<R>,
 }
 
-impl InferenceCore {
+impl<R> InferenceCore<R>
+where
+    R: InferenceRuntime,
+    (Vec<f32>,): Output<R>,
+{
     pub(crate) fn new(use_gpu: bool, cpu_num_threads: u16) -> Result<Self> {
         if !use_gpu || Self::can_support_gpu_feature()? {
             let status = Status::new(use_gpu, cpu_num_threads);
```

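Note: the `(Vec<f32>,): Output<R>` where-clause demands only the capability the core actually uses, namely that runtime `R` can produce a single `f32` tensor as its output tuple, which presumably covers all three current signatures (`PredictDuration`, `PredictIntonation`, `Decode`). A guess at the shape of that pattern (simplified, hypothetical definitions; the crate's real traits differ):

```rust
// Simplified sketch of the "output tuple" pattern behind the bound.
trait InferenceRuntime {
    type RawOutput;
}

// "`Self` is a tuple of typed tensors that runtime `R` knows how to produce."
trait Output<R: InferenceRuntime>: Sized {
    fn from_raw(raw: R::RawOutput) -> Self;
}

enum Cpu {} // marker runtime, in the style of `Onnxruntime`

impl InferenceRuntime for Cpu {
    type RawOutput = Vec<f32>;
}

impl Output<Cpu> for (Vec<f32>,) {
    fn from_raw(raw: Vec<f32>) -> Self {
        (raw,)
    }
}

fn main() {
    let (f0,): (Vec<f32>,) = Output::<Cpu>::from_raw(vec![0.0; 4]);
    assert_eq!(f0.len(), 4);
}
```
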
54 changes: 28 additions & 26 deletions crates/voicevox_core/src/status.rs
```diff
@@ -1,10 +1,10 @@
 use super::*;
 use crate::infer::{
-    runtimes::Onnxruntime,
     signatures::{Decode, PredictDuration, PredictIntonation, SessionSet},
-    DecryptModelError, Output, SessionOptions, Signature, TypedSession,
+    DecryptModelError, InferenceRuntime, Output, SessionOptions, Signature, TypedSession,
 };
 use derive_more::Index;
+use educe::Educe;
 use itertools::iproduct;
 use std::path::Path;
 use std::sync::Arc;
 
@@ -13,13 +13,13 @@ mod model_file;
 
 use std::collections::BTreeMap;
 
-pub struct Status {
-    loaded_models: std::sync::Mutex<LoadedModels>,
+pub(crate) struct Status<R: InferenceRuntime> {
+    loaded_models: std::sync::Mutex<LoadedModels<R>>,
     light_session_options: SessionOptions, // used by the light models
     heavy_session_options: SessionOptions, // used by the heavy models
 }
 
-impl Status {
+impl<R: InferenceRuntime> Status<R> {
     pub fn new(use_gpu: bool, cpu_num_threads: u16) -> Self {
         Self {
             loaded_models: Default::default(),
 
@@ -89,13 +89,14 @@ impl Status {
         model: &[u8],
         session_options: &SessionOptions,
         path: impl AsRef<Path>,
-    ) -> LoadModelResult<TypedSession<Onnxruntime, S>> {
-        TypedSession::<Onnxruntime, S>::new(|| model_file::decrypt(model), *session_options)
-            .map_err(|source| LoadModelError {
+    ) -> LoadModelResult<TypedSession<R, S>> {
+        TypedSession::<R, S>::new(|| model_file::decrypt(model), *session_options).map_err(
+            |source| LoadModelError {
                 path: path.as_ref().to_owned(),
                 context: LoadModelErrorKind::InvalidModelData,
                 source: Some(source),
-            })
+            },
+        )
     }
 
     pub fn validate_speaker_id(&self, style_id: StyleId) -> bool {
 
@@ -112,13 +113,12 @@
     ) -> Result<S::Output>
     where
         S: Signature,
-        for<'a> &'a S::SessionSet<Onnxruntime>: From<&'a SessionSet<Onnxruntime>>,
-        S::Output: Output<Onnxruntime>,
+        for<'a> &'a S::SessionSet<R>: From<&'a SessionSet<R>>,
+        S::Output: Output<R>,
     {
-        let sess = S::get_session::<Onnxruntime>(
-            (&self.loaded_models.lock().unwrap()[model_id].session_set).into(),
-        )
-        .clone();
+        let sess =
+            S::get_session::<R>((&self.loaded_models.lock().unwrap()[model_id].session_set).into())
+                .clone();
 
         tokio::task::spawn_blocking(move || {
             let mut sess = sess.lock().unwrap();
 
@@ -133,16 +133,17 @@
 /// Holds the loaded models' `Session`s and their metadata, and provides add/remove/get operations.
 ///
 /// Every method on this struct should complete in an instant.
-#[derive(Default, Index)]
-struct LoadedModels(BTreeMap<VoiceModelId, LoadedModel>);
+#[derive(Educe, Index)]
+#[educe(Default(bound = "R: InferenceRuntime"))]
+struct LoadedModels<R: InferenceRuntime>(BTreeMap<VoiceModelId, LoadedModel<R>>);
 
-struct LoadedModel {
+struct LoadedModel<R: InferenceRuntime> {
     model_inner_ids: BTreeMap<StyleId, ModelInnerId>,
     metas: VoiceModelMeta,
-    session_set: SessionSet<Onnxruntime>,
+    session_set: SessionSet<R>,
 }
 
-impl LoadedModels {
+impl<R: InferenceRuntime> LoadedModels<R> {
     fn metas(&self) -> VoiceModelMeta {
         self.0
             .values()
 
@@ -216,9 +217,9 @@
     fn insert(
         &mut self,
         model: &VoiceModel,
-        predict_duration: TypedSession<Onnxruntime, PredictDuration>,
-        predict_intonation: TypedSession<Onnxruntime, PredictIntonation>,
-        decode: TypedSession<Onnxruntime, Decode>,
+        predict_duration: TypedSession<R, PredictDuration>,
+        predict_intonation: TypedSession<R, PredictIntonation>,
+        decode: TypedSession<R, Decode>,
     ) -> Result<()> {
         self.ensure_acceptable(model)?;
 
@@ -260,6 +261,7 @@
 mod tests {
 
     use super::*;
+    use crate::infer::runtimes::Onnxruntime;
     use crate::macros::tests::assert_debug_fmt_eq;
     use pretty_assertions::assert_eq;
 
@@ -272,7 +274,7 @@
     #[case(false, 8)]
     #[case(false, 0)]
     fn status_new_works(#[case] use_gpu: bool, #[case] cpu_num_threads: u16) {
-        let status = Status::new(use_gpu, cpu_num_threads);
+        let status = Status::<Onnxruntime>::new(use_gpu, cpu_num_threads);
         assert_eq!(false, status.light_session_options.use_gpu);
         assert_eq!(use_gpu, status.heavy_session_options.use_gpu);
         assert_eq!(
 
@@ -289,7 +291,7 @@
     #[rstest]
     #[tokio::test]
     async fn status_load_model_works() {
-        let status = Status::new(false, 0);
+        let status = Status::<Onnxruntime>::new(false, 0);
        let result = status.load_model(&open_default_vvm_file().await).await;
         assert_debug_fmt_eq!(Ok(()), result);
         assert_eq!(1, status.loaded_models.lock().unwrap().0.len());
 
@@ -298,7 +300,7 @@
     #[rstest]
     #[tokio::test]
     async fn status_is_model_loaded_works() {
-        let status = Status::new(false, 0);
+        let status = Status::<Onnxruntime>::new(false, 0);
         let vvm = open_default_vvm_file().await;
         assert!(
             !status.is_loaded_model(vvm.id()),
```

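Note: `educe` replaces std's `#[derive(Default)]` here because the std derive adds an implicit `R: Default` bound on the generated impl, which a marker runtime such as the uninhabited `Onnxruntime` enum can never satisfy; `#[educe(Default(bound = "..."))]` states the bound explicitly instead. A sketch with illustrative names:

```rust
use std::collections::BTreeMap;

use educe::Educe; // educe = "0.4"

trait Runtime {}

enum Marker {} // uninhabited: has no `Default` impl and never will
impl Runtime for Marker {}

// std's derive would generate `impl<R: Runtime + Default> Default for ...`,
// so `Wrapper::<Marker>::default()` would not compile. Educe lets the
// bound be written out, mirroring the commit's
// `#[educe(Default(bound = "R: InferenceRuntime"))]`:
#[derive(Educe)]
#[educe(Default(bound = "R: Runtime"))]
struct Wrapper<R: Runtime>(BTreeMap<u32, Vec<R>>);

fn main() {
    let w = Wrapper::<Marker>::default(); // OK: `BTreeMap::default()` needs no `R: Default`
    assert!(w.0.is_empty());
}
```
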
9 changes: 7 additions & 2 deletions crates/voicevox_core/src/synthesizer.rs
```diff
@@ -1,6 +1,9 @@
 use std::sync::Arc;
 
-use crate::engine::{create_kana, parse_kana, AccentPhraseModel, OpenJtalk, SynthesisEngine};
+use crate::{
+    engine::{create_kana, parse_kana, AccentPhraseModel, OpenJtalk, SynthesisEngine},
+    infer::runtimes::Onnxruntime,
+};
 
 use super::*;
 
@@ -67,9 +70,11 @@
     pub cpu_num_threads: u16,
 }
 
+type SynthesizerInferenceRuntime = Onnxruntime;
+
 /// A voice synthesizer.
 pub struct Synthesizer {
-    synthesis_engine: SynthesisEngine,
+    synthesis_engine: SynthesisEngine<SynthesizerInferenceRuntime>,
     use_gpu: bool,
 }
```

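Note: the private `SynthesizerInferenceRuntime` alias keeps the public `Synthesizer` type non-generic while choosing the backend in exactly one place, so swapping runtimes later is a one-line change that never leaks into the public API. A sketch of the pattern with illustrative names:

```rust
// The internals stay generic over a runtime; a single private alias pins
// the concrete backend for the public type.
use std::marker::PhantomData;

trait InferenceRuntime {}

enum Onnxruntime {}
impl InferenceRuntime for Onnxruntime {}

struct SynthesisEngine<R: InferenceRuntime>(PhantomData<R>);

// The one line to touch if another backend is ever wired in:
type DefaultRuntime = Onnxruntime;

// Callers never see the `R` parameter.
pub struct Synthesizer {
    _engine: SynthesisEngine<DefaultRuntime>,
}

fn main() {
    let _ = Synthesizer {
        _engine: SynthesisEngine(PhantomData),
    };
}
```
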
