Skip to content

Commit

Permalink
"model"ではなく"inference"と呼ぶ
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 11, 2023
1 parent 8b4f3b6 commit c40afd5
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 28 deletions.
16 changes: 8 additions & 8 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,26 @@ pub(crate) trait SupportsInferenceOutput<O: Send>: InferenceRuntime {
fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<O>;
}

pub(crate) trait InferenceModelGroup {
pub(crate) trait InferenceGroup {
type Kind: Copy + Enum;
}

pub(crate) trait InferenceSignature: Sized + Send + 'static {
type ModelGroup: InferenceModelGroup;
type Group: InferenceGroup;
type Input: InferenceInputSignature<Signature = Self>;
type Output: Send;
const MODEL: <Self::ModelGroup as InferenceModelGroup>::Kind;
const INFERENCE: <Self::Group as InferenceGroup>::Kind;
}

pub(crate) trait InferenceInputSignature: Send + 'static {
type Signature: InferenceSignature<Input = Self>;
}

pub(crate) struct InferenceSessionSet<G: InferenceModelGroup, R: InferenceRuntime>(
pub(crate) struct InferenceSessionSet<G: InferenceGroup, R: InferenceRuntime>(
EnumMap<G::Kind, Arc<std::sync::Mutex<R::Session>>>,
);

impl<G: InferenceModelGroup, R: InferenceRuntime> InferenceSessionSet<G, R> {
impl<G: InferenceGroup, R: InferenceRuntime> InferenceSessionSet<G, R> {
pub(crate) fn new(
model_bytes: &EnumMap<G::Kind, Vec<u8>>,
mut options: impl FnMut(G::Kind) -> InferenceSessionOptions,
Expand All @@ -105,14 +105,14 @@ impl<G: InferenceModelGroup, R: InferenceRuntime> InferenceSessionSet<G, R> {
}
}

impl<G: InferenceModelGroup, R: InferenceRuntime> InferenceSessionSet<G, R> {
impl<G: InferenceGroup, R: InferenceRuntime> InferenceSessionSet<G, R> {
pub(crate) fn get<I>(&self) -> InferenceSessionCell<R, I>
where
I: InferenceInputSignature,
I::Signature: InferenceSignature<ModelGroup = G>,
I::Signature: InferenceSignature<Group = G>,
{
InferenceSessionCell {
inner: self.0[I::Signature::MODEL].clone(),
inner: self.0[I::Signature::INFERENCE].clone(),
marker: PhantomData,
}
}
Expand Down
22 changes: 11 additions & 11 deletions crates/voicevox_core/src/infer/signatures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@ use enum_map::Enum;
use ndarray::{Array0, Array1, Array2};

use crate::infer::{
InferenceInputSignature, InferenceModelGroup, InferenceSignature, RunContextExt as _,
InferenceGroup, InferenceInputSignature, InferenceSignature, RunContextExt as _,
SupportsInferenceInputSignature, SupportsInferenceInputTensor,
};

pub(crate) enum InferenceModelGroupImpl {}
pub(crate) enum InferenceGroupImpl {}

impl InferenceModelGroup for InferenceModelGroupImpl {
type Kind = InferenceModelKindImpl;
impl InferenceGroup for InferenceGroupImpl {
type Kind = InferencelKindImpl;
}

#[derive(Clone, Copy, Enum)]
pub(crate) enum InferenceModelKindImpl {
pub(crate) enum InferencelKindImpl {
PredictDuration,
PredictIntonation,
Decode,
Expand All @@ -22,10 +22,10 @@ pub(crate) enum InferenceModelKindImpl {
pub(crate) enum PredictDuration {}

impl InferenceSignature for PredictDuration {
type ModelGroup = InferenceModelGroupImpl;
type Group = InferenceGroupImpl;
type Input = PredictDurationInput;
type Output = (Vec<f32>,);
const MODEL: InferenceModelKindImpl = InferenceModelKindImpl::PredictDuration;
const INFERENCE: InferencelKindImpl = InferencelKindImpl::PredictDuration;
}

pub(crate) struct PredictDurationInput {
Expand Down Expand Up @@ -53,10 +53,10 @@ impl<R: SupportsInferenceInputTensor<Array1<i64>>>
pub(crate) enum PredictIntonation {}

impl InferenceSignature for PredictIntonation {
type ModelGroup = InferenceModelGroupImpl;
type Group = InferenceGroupImpl;
type Input = PredictIntonationInput;
type Output = (Vec<f32>,);
const MODEL: InferenceModelKindImpl = InferenceModelKindImpl::PredictIntonation;
const INFERENCE: InferencelKindImpl = InferencelKindImpl::PredictIntonation;
}

pub(crate) struct PredictIntonationInput {
Expand Down Expand Up @@ -96,10 +96,10 @@ impl<R: SupportsInferenceInputTensor<Array0<i64>> + SupportsInferenceInputTensor
pub(crate) enum Decode {}

impl InferenceSignature for Decode {
type ModelGroup = InferenceModelGroupImpl;
type Group = InferenceGroupImpl;
type Input = DecodeInput;
type Output = (Vec<f32>,);
const MODEL: InferenceModelKindImpl = InferenceModelKindImpl::Decode;
const INFERENCE: InferencelKindImpl = InferencelKindImpl::Decode;
}

pub(crate) struct DecodeInput {
Expand Down
14 changes: 7 additions & 7 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::*;
use crate::infer::{
signatures::{InferenceModelGroupImpl, InferenceModelKindImpl},
signatures::{InferenceGroupImpl, InferencelKindImpl},
InferenceInputSignature, InferenceRuntime, InferenceSessionCell, InferenceSessionOptions,
InferenceSessionSet, InferenceSignature, SupportsInferenceInputSignature,
SupportsInferenceOutput,
Expand Down Expand Up @@ -34,10 +34,10 @@ impl<R: InferenceRuntime> Status<R> {
let model_bytes = &model.read_inference_models().await?;

let session_set = InferenceSessionSet::new(model_bytes, |kind| match kind {
InferenceModelKindImpl::PredictDuration | InferenceModelKindImpl::PredictIntonation => {
InferencelKindImpl::PredictDuration | InferencelKindImpl::PredictIntonation => {
self.light_session_options
}
InferenceModelKindImpl::Decode => self.heavy_session_options,
InferencelKindImpl::Decode => self.heavy_session_options,
})
.map_err(|source| LoadModelError {
path: model.path().clone(),
Expand Down Expand Up @@ -89,7 +89,7 @@ impl<R: InferenceRuntime> Status<R> {
) -> Result<<I::Signature as InferenceSignature>::Output>
where
I: InferenceInputSignature,
I::Signature: InferenceSignature<ModelGroup = InferenceModelGroupImpl>,
I::Signature: InferenceSignature<Group = InferenceGroupImpl>,
R: SupportsInferenceInputSignature<I>
+ SupportsInferenceOutput<<I::Signature as InferenceSignature>::Output>,
{
Expand All @@ -111,7 +111,7 @@ struct LoadedModels<R: InferenceRuntime>(BTreeMap<VoiceModelId, LoadedModel<R>>)
struct LoadedModel<R: InferenceRuntime> {
model_inner_ids: BTreeMap<StyleId, ModelInnerId>,
metas: VoiceModelMeta,
session_set: InferenceSessionSet<InferenceModelGroupImpl, R>,
session_set: InferenceSessionSet<InferenceGroupImpl, R>,
}

impl<R: InferenceRuntime> LoadedModels<R> {
Expand Down Expand Up @@ -153,7 +153,7 @@ impl<R: InferenceRuntime> LoadedModels<R> {
fn get<I>(&self, model_id: &VoiceModelId) -> InferenceSessionCell<R, I>
where
I: InferenceInputSignature,
I::Signature: InferenceSignature<ModelGroup = InferenceModelGroupImpl>,
I::Signature: InferenceSignature<Group = InferenceGroupImpl>,
{
self.0[model_id].session_set.get()
}
Expand Down Expand Up @@ -199,7 +199,7 @@ impl<R: InferenceRuntime> LoadedModels<R> {
fn insert(
&mut self,
model: &VoiceModel,
session_set: InferenceSessionSet<InferenceModelGroupImpl, R>,
session_set: InferenceSessionSet<InferenceGroupImpl, R>,
) -> Result<()> {
self.ensure_acceptable(model)?;

Expand Down
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use futures::future::join3;
use serde::{de::DeserializeOwned, Deserialize};

use super::*;
use crate::infer::signatures::InferenceModelKindImpl;
use crate::infer::signatures::InferencelKindImpl;
use std::{
collections::{BTreeMap, HashMap},
io,
Expand Down Expand Up @@ -40,7 +40,7 @@ pub struct VoiceModel {
impl VoiceModel {
pub(crate) async fn read_inference_models(
&self,
) -> LoadModelResult<EnumMap<InferenceModelKindImpl, Vec<u8>>> {
) -> LoadModelResult<EnumMap<InferencelKindImpl, Vec<u8>>> {
let reader = VvmEntryReader::open(&self.path).await?;
let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) =
join3(
Expand Down

0 comments on commit c40afd5

Please sign in to comment.