diff --git a/crates/voicevox_core/src/engine/full_context_label.rs b/crates/voicevox_core/src/engine/full_context_label.rs index 9c0e28a25..b8e3cce2d 100644 --- a/crates/voicevox_core/src/engine/full_context_label.rs +++ b/crates/voicevox_core/src/engine/full_context_label.rs @@ -4,7 +4,7 @@ use super::*; use once_cell::sync::Lazy; use regex::Regex; -#[derive(thiserror::Error, Debug, PartialEq)] +#[derive(thiserror::Error, Debug)] pub enum FullContextLabelError { #[error("label parse error label:{label}")] LabelParse { label: String }, diff --git a/crates/voicevox_core/src/engine/open_jtalk.rs b/crates/voicevox_core/src/engine/open_jtalk.rs index c2d28efbe..60c3acbf1 100644 --- a/crates/voicevox_core/src/engine/open_jtalk.rs +++ b/crates/voicevox_core/src/engine/open_jtalk.rs @@ -14,36 +14,6 @@ pub enum OpenJtalkError { }, } -impl PartialEq for OpenJtalkError { - fn eq(&self, other: &Self) -> bool { - return match (self, other) { - ( - Self::Load { - mecab_dict_dir: mecab_dict_dir1, - }, - Self::Load { - mecab_dict_dir: mecab_dict_dir2, - }, - ) => mecab_dict_dir1 == mecab_dict_dir2, - ( - Self::ExtractFullContext { - text: text1, - source: source1, - }, - Self::ExtractFullContext { - text: text2, - source: source2, - }, - ) => (text1, by_display(source1)) == (text2, by_display(source2)), - _ => false, - }; - - fn by_display(source: &Option) -> impl PartialEq { - source.as_ref().map(|e| e.to_string()) - } - } -} - pub type Result = std::result::Result; pub struct OpenJtalk { @@ -131,10 +101,9 @@ impl OpenJtalk { #[cfg(test)] mod tests { use super::*; - use pretty_assertions::assert_eq; use test_util::OPEN_JTALK_DIC_DIR; - use crate::*; + use crate::{macros::tests::assert_debug_fmt_eq, *}; fn testdata_hello_hiho() -> Vec { // こんにちは、ヒホです。の期待値 @@ -230,7 +199,7 @@ mod tests { let mut open_jtalk = OpenJtalk::initialize(); open_jtalk.load(OPEN_JTALK_DIC_DIR).unwrap(); let result = open_jtalk.extract_fullcontext(text); - assert_eq!(expected, result); + assert_debug_fmt_eq!(expected, result); } #[rstest] @@ -243,7 +212,7 @@ mod tests { open_jtalk.load(OPEN_JTALK_DIC_DIR).unwrap(); for _ in 0..10 { let result = open_jtalk.extract_fullcontext(text); - assert_eq!(expected, result); + assert_debug_fmt_eq!(expected, result); } } } diff --git a/crates/voicevox_core/src/engine/synthesis_engine.rs b/crates/voicevox_core/src/engine/synthesis_engine.rs index f7b63c22e..b16271a50 100644 --- a/crates/voicevox_core/src/engine/synthesis_engine.rs +++ b/crates/voicevox_core/src/engine/synthesis_engine.rs @@ -647,7 +647,7 @@ mod tests { use pretty_assertions::assert_eq; use test_util::OPEN_JTALK_DIC_DIR; - use crate::*; + use crate::{macros::tests::assert_debug_fmt_eq, *}; #[rstest] fn load_openjtalk_dict_works() { @@ -655,10 +655,10 @@ mod tests { let mut synthesis_engine = SynthesisEngine::new(core, OpenJtalk::initialize()); let result = synthesis_engine.load_openjtalk_dict(OPEN_JTALK_DIC_DIR); - assert_eq!(result, Ok(())); + assert_debug_fmt_eq!(result, Ok(())); let result = synthesis_engine.load_openjtalk_dict(""); - assert_eq!(result, Err(Error::NotLoadedOpenjtalkDict)); + assert_debug_fmt_eq!(result, Err(Error::NotLoadedOpenjtalkDict)); } #[rstest] diff --git a/crates/voicevox_core/src/error.rs b/crates/voicevox_core/src/error.rs index 3835b678d..f4e3f882b 100644 --- a/crates/voicevox_core/src/error.rs +++ b/crates/voicevox_core/src/error.rs @@ -68,50 +68,6 @@ pub enum Error { ParseKana(#[from] KanaParseError), } -impl PartialEq for Error { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::NotLoadedOpenjtalkDict, Self::NotLoadedOpenjtalkDict) - | (Self::GpuSupport, Self::GpuSupport) - | (Self::UninitializedStatus, Self::UninitializedStatus) - | (Self::InferenceFailed, Self::InferenceFailed) => true, - ( - Self::LoadModel { - path: path1, - source: source1, - }, - Self::LoadModel { - path: path2, - source: source2, - }, - ) => (path1, source1.to_string()) == (path2, source2.to_string()), - (Self::LoadMetas(e1), Self::LoadMetas(e2)) - | (Self::GetSupportedDevices(e1), Self::GetSupportedDevices(e2)) => { - e1.to_string() == e2.to_string() - } - ( - Self::InvalidSpeakerId { - speaker_id: speaker_id1, - }, - Self::InvalidSpeakerId { - speaker_id: speaker_id2, - }, - ) => speaker_id1 == speaker_id2, - ( - Self::InvalidModelIndex { - model_index: model_index1, - }, - Self::InvalidModelIndex { - model_index: model_index2, - }, - ) => model_index1 == model_index2, - (Self::ExtractFullContextLabel(e1), Self::ExtractFullContextLabel(e2)) => e1 == e2, - (Self::ParseKana(e1), Self::ParseKana(e2)) => e1 == e2, - _ => false, - } - } -} - fn base_error_message(result_code: VoicevoxResultCode) -> &'static str { let c_message: &'static str = crate::error_result_to_message(result_code); &c_message[..(c_message.len() - 1)] diff --git a/crates/voicevox_core/src/lib.rs b/crates/voicevox_core/src/lib.rs index 90a2bc93a..88cb39be7 100644 --- a/crates/voicevox_core/src/lib.rs +++ b/crates/voicevox_core/src/lib.rs @@ -3,6 +3,7 @@ /// cbindgen:ignore mod engine; mod error; +mod macros; mod numerics; mod publish; mod result; diff --git a/crates/voicevox_core/src/macros.rs b/crates/voicevox_core/src/macros.rs new file mode 100644 index 000000000..db0a4de83 --- /dev/null +++ b/crates/voicevox_core/src/macros.rs @@ -0,0 +1,45 @@ +#[cfg(test)] +pub(crate) mod tests { + use std::fmt::{self, Debug}; + + use pretty_assertions::StrComparison; + + /// 2つの`"{:#?}"`が等しいかを、pretty\_assertions風にassertする。 + /// + /// `io::Error`や`anyhow::Error`を抱えていて`PartialEq`実装が難しい型に使う。 + /// + /// # Panics + /// + /// 2つの`"{:#?}"`が等しくないとき、assertの失敗としてパニックする。 + macro_rules! assert_debug_fmt_eq { + ($left:expr, $right:expr $(,)?) => {{ + crate::macros::tests::__assert_debug_fmt(&$left, &$right, None) + }}; + ($left:expr, $right:expr, $($arg:tt)*) => {{ + crate::macros::tests::__assert_debug_fmt(&$left, &$right, Some(format_args!($($arg)*))) + }}; + } + pub(crate) use assert_debug_fmt_eq; + + #[track_caller] + pub(crate) fn __assert_debug_fmt( + left: &T, + right: &T, + args: Option>, + ) { + let (left, right) = (format!("{left:#?}"), format!("{right:#?}")); + if left != right { + panic!( + r#"assertion failed: `("{{left:#?}}" == "{{right:#?}}")`{} + +{} +"#, + match args { + Some(args) => format!(": {args}"), + None => "".to_owned(), + }, + StrComparison::new(&left, &right), + ); + } + } +} diff --git a/crates/voicevox_core/src/publish.rs b/crates/voicevox_core/src/publish.rs index 41e1b8805..fd0a409e9 100644 --- a/crates/voicevox_core/src/publish.rs +++ b/crates/voicevox_core/src/publish.rs @@ -651,6 +651,7 @@ fn list_windows_video_cards() { #[cfg(test)] mod tests { use super::*; + use crate::macros::tests::assert_debug_fmt_eq; use pretty_assertions::assert_eq; use test_util::OPEN_JTALK_DIC_DIR; @@ -661,7 +662,7 @@ mod tests { .lock() .unwrap() .initialize(InitializeOptions::default()); - assert_eq!(Ok(()), result); + assert_debug_fmt_eq!(Ok(()), result); internal.lock().unwrap().finalize(); assert_eq!( false, @@ -695,7 +696,7 @@ mod tests { ) { let internal = VoicevoxCore::new_with_mutex(); let result = internal.lock().unwrap().load_model(speaker_id); - assert_eq!(expected_result_at_uninitialized, result); + assert_debug_fmt_eq!(expected_result_at_uninitialized, result); internal .lock() @@ -706,9 +707,10 @@ mod tests { }) .unwrap(); let result = internal.lock().unwrap().load_model(speaker_id); - assert_eq!( - expected_result_at_initialized, result, - "got load_model result" + assert_debug_fmt_eq!( + expected_result_at_initialized, + result, + "got load_model result", ); } diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 309d060b8..4c4704bd2 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -359,6 +359,7 @@ impl Status { mod tests { use super::*; + use crate::macros::tests::assert_debug_fmt_eq; use pretty_assertions::assert_eq; #[rstest] @@ -391,7 +392,7 @@ mod tests { fn status_load_metas_works() { let mut status = Status::new(true, 0); let result = status.load_metas(); - assert_eq!(Ok(()), result); + assert_debug_fmt_eq!(Ok(()), result); let expected = BTreeSet::from([0, 1, 2, 3]); assert_eq!(expected, status.supported_styles); } @@ -407,7 +408,7 @@ mod tests { fn status_load_model_works() { let mut status = Status::new(false, 0); let result = status.load_model(0); - assert_eq!(Ok(()), result); + assert_debug_fmt_eq!(Ok(()), result); assert_eq!(1, status.models.predict_duration.len()); assert_eq!(1, status.models.predict_intonation.len()); assert_eq!(1, status.models.decode.len()); @@ -422,7 +423,7 @@ mod tests { "model should not be loaded" ); let result = status.load_model(model_index); - assert_eq!(Ok(()), result); + assert_debug_fmt_eq!(Ok(()), result); assert!( status.is_model_loaded(model_index), "model should be loaded"