Skip to content

Commit

Permalink
"Domain"と"Operation"に分離
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 14, 2023
1 parent 0998793 commit 75fd7ac
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 58 deletions.
8 changes: 6 additions & 2 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ pub(crate) trait InferenceRuntime: 'static {
fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<OutputTensor>>;
}

pub(crate) trait InferenceDomain: Copy + Enum {
pub(crate) trait InferenceDomain {
type Operation: InferenceOperation;
}

pub(crate) trait InferenceOperation: Copy + Enum {
/// `{InferenceInputSignature,InferenceOutputSignature}::PARAM_INFOS`を集めたもの。
///
/// マクロ(voicevox_core_macros)で実装される前提。
Expand All @@ -54,7 +58,7 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static {
type Domain: InferenceDomain;
type Input: InferenceInputSignature<Signature = Self>;
type Output: InferenceOutputSignature;
const KIND: Self::Domain;
const OPERATION: <Self::Domain as crate::infer::InferenceDomain>::Operation;
}

pub(crate) trait InferenceInputSignature: Send + 'static {
Expand Down
25 changes: 18 additions & 7 deletions crates/voicevox_core/src/infer/domain.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
use enum_map::Enum;
use macros::{InferenceDomain, InferenceInputSignature, InferenceOutputSignature};
use macros::{InferenceInputSignature, InferenceOperation, InferenceOutputSignature};
use ndarray::{Array0, Array1, Array2};

use super::{InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor};
use super::{
InferenceDomain, InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor,
};

#[derive(Clone, Copy, Enum, InferenceDomain)]
pub(crate) enum InferenceKind {
#[inference_domain(
pub(crate) enum InferenceDomainImpl {}

impl InferenceDomain for InferenceDomainImpl {
type Operation = InferenceOperationKind;
}

#[derive(Clone, Copy, Enum, InferenceOperation)]
#[inference_operation(
type Domain = InferenceDomainImpl;
)]
pub(crate) enum InferenceOperationKind {
#[inference_operation(
type Input = PredictDurationInput;
type Output = PredictDurationOutput;
)]
PredictDuration,

#[inference_domain(
#[inference_operation(
type Input = PredictIntonationInput;
type Output = PredictIntonationOutput;
)]
PredictIntonation,

#[inference_domain(
#[inference_operation(
type Input = DecodeInput;
type Output = DecodeOutput;
)]
Expand Down
53 changes: 28 additions & 25 deletions crates/voicevox_core/src/infer/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ use std::{

use anyhow::bail;
use educe::Educe;
use enum_map::EnumMap;
use enum_map::{Enum as _, EnumMap};
use itertools::{iproduct, Itertools as _};

use crate::{
error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult},
infer::ParamInfo,
infer::{InferenceOperation, ParamInfo},
manifest::ModelInnerId,
metas::{SpeakerMeta, StyleId, StyleMeta, VoiceModelMeta},
voice_model::{VoiceModel, VoiceModelId},
Expand All @@ -26,11 +26,11 @@ use super::{

pub(crate) struct Status<R: InferenceRuntime, D: InferenceDomain> {
loaded_models: std::sync::Mutex<LoadedModels<R, D>>,
session_options: EnumMap<D, InferenceSessionOptions>,
session_options: EnumMap<D::Operation, InferenceSessionOptions>,
}

impl<R: InferenceRuntime, D: InferenceDomain> Status<R, D> {
pub fn new(session_options: EnumMap<D, InferenceSessionOptions>) -> Self {
pub fn new(session_options: EnumMap<D::Operation, InferenceSessionOptions>) -> Self {
Self {
loaded_models: Default::default(),
session_options,
Expand All @@ -40,7 +40,7 @@ impl<R: InferenceRuntime, D: InferenceDomain> Status<R, D> {
pub async fn load_model(
&self,
model: &VoiceModel,
model_bytes: &EnumMap<D, Vec<u8>>,
model_bytes: &EnumMap<D::Operation, Vec<u8>>,
) -> Result<()> {
self.loaded_models
.lock()
Expand Down Expand Up @@ -241,30 +241,31 @@ impl<R: InferenceRuntime, D: InferenceDomain> LoadedModels<R, D> {
}

struct SessionSet<R: InferenceRuntime, D: InferenceDomain>(
EnumMap<D, Arc<std::sync::Mutex<R::Session>>>,
EnumMap<D::Operation, Arc<std::sync::Mutex<R::Session>>>,
);

impl<R: InferenceRuntime, D: InferenceDomain> SessionSet<R, D> {
fn new(
model_bytes: &EnumMap<D, Vec<u8>>,
options: &EnumMap<D, InferenceSessionOptions>,
model_bytes: &EnumMap<D::Operation, Vec<u8>>,
options: &EnumMap<D::Operation, InferenceSessionOptions>,
) -> anyhow::Result<Self> {
let mut sessions = model_bytes
.iter()
.map(|(k, m)| {
let (expected_input_param_infos, expected_output_param_infos) = D::PARAM_INFOS[k];
.map(|(op, model_bytes)| {
let (expected_input_param_infos, expected_output_param_infos) =
<D::Operation as InferenceOperation>::PARAM_INFOS[op];

let (sess, actual_input_param_infos, actual_output_param_infos) =
R::new_session(|| model_file::decrypt(m), options[k])?;
R::new_session(|| model_file::decrypt(model_bytes), options[op])?;

check_param_infos(expected_input_param_infos, &actual_input_param_infos)?;
check_param_infos(expected_output_param_infos, &actual_output_param_infos)?;

Ok((k.into_usize(), std::sync::Mutex::new(sess).into()))
Ok((op.into_usize(), std::sync::Mutex::new(sess).into()))
})
.collect::<anyhow::Result<HashMap<_, _>>>()?;

return Ok(Self(EnumMap::<D, _>::from_fn(|k| {
return Ok(Self(EnumMap::<D::Operation, _>::from_fn(|k| {
sessions.remove(&k.into_usize()).expect("should exist")
})));

Expand Down Expand Up @@ -305,7 +306,7 @@ impl<R: InferenceRuntime, D: InferenceDomain> SessionSet<R, D> {
I::Signature: InferenceSignature<Domain = D>,
{
SessionCell {
inner: self.0[I::Signature::KIND].clone(),
inner: self.0[I::Signature::OPERATION].clone(),
marker: PhantomData,
}
}
Expand Down Expand Up @@ -333,8 +334,10 @@ mod tests {
use rstest::rstest;

use crate::{
infer::domain::InferenceKind, macros::tests::assert_debug_fmt_eq,
synthesizer::InferenceRuntimeImpl, test_util::open_default_vvm_file,
infer::domain::{InferenceDomainImpl, InferenceOperationKind},
macros::tests::assert_debug_fmt_eq,
synthesizer::InferenceRuntimeImpl,
test_util::open_default_vvm_file,
};

use super::{super::InferenceSessionOptions, Status};
Expand All @@ -351,23 +354,23 @@ mod tests {
let light_session_options = InferenceSessionOptions::new(cpu_num_threads, false);
let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu);
let session_options = enum_map! {
InferenceKind::PredictDuration
| InferenceKind::PredictIntonation => light_session_options,
InferenceKind::Decode => heavy_session_options,
InferenceOperationKind::PredictDuration
| InferenceOperationKind::PredictIntonation => light_session_options,
InferenceOperationKind::Decode => heavy_session_options,
};
let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(session_options);
let status = Status::<InferenceRuntimeImpl, InferenceDomainImpl>::new(session_options);

assert_eq!(
light_session_options,
status.session_options[InferenceKind::PredictDuration],
status.session_options[InferenceOperationKind::PredictDuration],
);
assert_eq!(
light_session_options,
status.session_options[InferenceKind::PredictIntonation],
status.session_options[InferenceOperationKind::PredictIntonation],
);
assert_eq!(
heavy_session_options,
status.session_options[InferenceKind::Decode],
status.session_options[InferenceOperationKind::Decode],
);

assert!(status.loaded_models.lock().unwrap().0.is_empty());
Expand All @@ -376,7 +379,7 @@ mod tests {
#[rstest]
#[tokio::test]
async fn status_load_model_works() {
let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(
let status = Status::<InferenceRuntimeImpl, InferenceDomainImpl>::new(
enum_map!(_ => InferenceSessionOptions::new(0, false)),
);
let model = &open_default_vvm_file().await;
Expand All @@ -389,7 +392,7 @@ mod tests {
#[rstest]
#[tokio::test]
async fn status_is_model_loaded_works() {
let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(
let status = Status::<InferenceRuntimeImpl, InferenceDomainImpl>::new(
enum_map!(_ => InferenceSessionOptions::new(0, false)),
);
let vvm = open_default_vvm_file().await;
Expand Down
13 changes: 7 additions & 6 deletions crates/voicevox_core/src/inference_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use enum_map::enum_map;

use crate::infer::{
domain::{
DecodeInput, DecodeOutput, InferenceKind, PredictDurationInput, PredictDurationOutput,
PredictIntonationInput, PredictIntonationOutput,
DecodeInput, DecodeOutput, InferenceDomainImpl, InferenceOperationKind,
PredictDurationInput, PredictDurationOutput, PredictIntonationInput,
PredictIntonationOutput,
},
status::Status,
InferenceRuntime, InferenceSessionOptions,
Expand All @@ -14,7 +15,7 @@ use super::*;
const PHONEME_LENGTH_MINIMAL: f32 = 0.01;

pub(crate) struct InferenceCore<R: InferenceRuntime> {
status: Status<R, InferenceKind>,
status: Status<R, InferenceDomainImpl>,
}

impl<R: InferenceRuntime> InferenceCore<R> {
Expand All @@ -27,9 +28,9 @@ impl<R: InferenceRuntime> InferenceCore<R> {
let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu);

let status = Status::new(enum_map! {
InferenceKind::PredictDuration
| InferenceKind::PredictIntonation => light_session_options,
InferenceKind::Decode => heavy_session_options,
InferenceOperationKind::PredictDuration
| InferenceOperationKind::PredictIntonation => light_session_options,
InferenceOperationKind::Decode => heavy_session_options,
});
Ok(Self { status })
} else {
Expand Down
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use futures::future::join3;
use serde::{de::DeserializeOwned, Deserialize};

use super::*;
use crate::infer::domain::InferenceKind;
use crate::infer::domain::InferenceOperationKind;
use std::{
collections::{BTreeMap, HashMap},
io,
Expand Down Expand Up @@ -40,7 +40,7 @@ pub struct VoiceModel {
impl VoiceModel {
pub(crate) async fn read_inference_models(
&self,
) -> LoadModelResult<EnumMap<InferenceKind, Vec<u8>>> {
) -> LoadModelResult<EnumMap<InferenceOperationKind, Vec<u8>>> {
let reader = VvmEntryReader::open(&self.path).await?;
let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) =
join3(
Expand Down
41 changes: 34 additions & 7 deletions crates/voicevox_core_macros/src/inference_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,41 @@ use syn::{
ItemType, Type, Variant,
};

pub(crate) fn derive_inference_domain(
pub(crate) fn derive_inference_operation(
input: &DeriveInput,
) -> syn::Result<proc_macro2::TokenStream> {
let DeriveInput {
attrs,
vis,
ident: domain_name,
ident: operation_ty_name,
generics,
data,
..
} = input;

deny_generics(generics)?;

let AssocTypeDomain(domain_ty) = attrs
.iter()
.find(|a| a.path().is_ident("inference_operation"))
.ok_or_else(|| {
syn::Error::new(
proc_macro2::Span::call_site(),
"missing `#[inference_operation(…)]`",
)
})?
.parse_args()?;

let variants = unit_enum_variants(data)?
.into_iter()
.map(|(attrs, variant_name)| {
let AssocTypes { input, output } = attrs
.iter()
.find(|a| a.path().is_ident("inference_domain"))
.find(|a| a.path().is_ident("inference_operation"))
.ok_or_else(|| {
syn::Error::new(
proc_macro2::Span::call_site(),
"missing `#[inference_domain(…)]`",
"missing `#[inference_operation(…)]`",
)
})?
.parse_args()?;
Expand All @@ -47,16 +59,18 @@ pub(crate) fn derive_inference_domain(
#vis enum #variant_name {}

impl crate::infer::InferenceSignature for #variant_name {
type Domain = #domain_name;
type Domain = #domain_ty;
type Input = #input_ty;
type Output = #output_ty;
const KIND: Self::Domain = #domain_name :: #variant_name;

const OPERATION: <Self::Domain as crate::infer::InferenceDomain>::Operation =
#operation_ty_name :: #variant_name;
}
}
});

return Ok(quote! {
impl crate::infer::InferenceDomain for #domain_name {
impl crate::infer::InferenceOperation for #operation_ty_name {
const PARAM_INFOS: ::enum_map::EnumMap<
Self,
(
Expand All @@ -74,6 +88,19 @@ pub(crate) fn derive_inference_domain(
#(#signatures)*
});

struct AssocTypeDomain(Type);

impl Parse for AssocTypeDomain {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let ItemType { ident, ty, .. } = input.parse()?;

if ident != "Domain" {
return Err(syn::Error::new(ident.span(), "expected `Domain`"));
}
Ok(Self(*ty))
}
}

struct AssocTypes {
input: Type,
output: Type,
Expand Down
Loading

0 comments on commit 75fd7ac

Please sign in to comment.