Skip to content

Commit

Permalink
InferenceOperationKindInferenceOperationImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 15, 2023
1 parent 75fd7ac commit ad222c9
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/infer/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ use super::{
pub(crate) enum InferenceDomainImpl {}

impl InferenceDomain for InferenceDomainImpl {
type Operation = InferenceOperationKind;
type Operation = InferenceOperationImpl;
}

#[derive(Clone, Copy, Enum, InferenceOperation)]
#[inference_operation(
type Domain = InferenceDomainImpl;
)]
pub(crate) enum InferenceOperationKind {
pub(crate) enum InferenceOperationImpl {
#[inference_operation(
type Input = PredictDurationInput;
type Output = PredictDurationOutput;
Expand Down
14 changes: 7 additions & 7 deletions crates/voicevox_core/src/infer/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ mod tests {
use rstest::rstest;

use crate::{
infer::domain::{InferenceDomainImpl, InferenceOperationKind},
infer::domain::{InferenceDomainImpl, InferenceOperationImpl},
macros::tests::assert_debug_fmt_eq,
synthesizer::InferenceRuntimeImpl,
test_util::open_default_vvm_file,
Expand All @@ -354,23 +354,23 @@ mod tests {
let light_session_options = InferenceSessionOptions::new(cpu_num_threads, false);
let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu);
let session_options = enum_map! {
InferenceOperationKind::PredictDuration
| InferenceOperationKind::PredictIntonation => light_session_options,
InferenceOperationKind::Decode => heavy_session_options,
InferenceOperationImpl::PredictDuration
| InferenceOperationImpl::PredictIntonation => light_session_options,
InferenceOperationImpl::Decode => heavy_session_options,
};
let status = Status::<InferenceRuntimeImpl, InferenceDomainImpl>::new(session_options);

assert_eq!(
light_session_options,
status.session_options[InferenceOperationKind::PredictDuration],
status.session_options[InferenceOperationImpl::PredictDuration],
);
assert_eq!(
light_session_options,
status.session_options[InferenceOperationKind::PredictIntonation],
status.session_options[InferenceOperationImpl::PredictIntonation],
);
assert_eq!(
heavy_session_options,
status.session_options[InferenceOperationKind::Decode],
status.session_options[InferenceOperationImpl::Decode],
);

assert!(status.loaded_models.lock().unwrap().0.is_empty());
Expand Down
8 changes: 4 additions & 4 deletions crates/voicevox_core/src/inference_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use enum_map::enum_map;

use crate::infer::{
domain::{
DecodeInput, DecodeOutput, InferenceDomainImpl, InferenceOperationKind,
DecodeInput, DecodeOutput, InferenceDomainImpl, InferenceOperationImpl,
PredictDurationInput, PredictDurationOutput, PredictIntonationInput,
PredictIntonationOutput,
},
Expand All @@ -28,9 +28,9 @@ impl<R: InferenceRuntime> InferenceCore<R> {
let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu);

let status = Status::new(enum_map! {
InferenceOperationKind::PredictDuration
| InferenceOperationKind::PredictIntonation => light_session_options,
InferenceOperationKind::Decode => heavy_session_options,
InferenceOperationImpl::PredictDuration
| InferenceOperationImpl::PredictIntonation => light_session_options,
InferenceOperationImpl::Decode => heavy_session_options,
});
Ok(Self { status })
} else {
Expand Down
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use futures::future::join3;
use serde::{de::DeserializeOwned, Deserialize};

use super::*;
use crate::infer::domain::InferenceOperationKind;
use crate::infer::domain::InferenceOperationImpl;
use std::{
collections::{BTreeMap, HashMap},
io,
Expand Down Expand Up @@ -40,7 +40,7 @@ pub struct VoiceModel {
impl VoiceModel {
pub(crate) async fn read_inference_models(
&self,
) -> LoadModelResult<EnumMap<InferenceOperationKind, Vec<u8>>> {
) -> LoadModelResult<EnumMap<InferenceOperationImpl, Vec<u8>>> {
let reader = VvmEntryReader::open(&self.path).await?;
let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) =
join3(
Expand Down

0 comments on commit ad222c9

Please sign in to comment.