Skip to content

Commit

Permalink
anyhow::Errorを抱えたエラー型は"{:#?}"を比較する (VOICEVOX#443)
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip authored Apr 10, 2023
1 parent 8655283 commit 44c4582
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 90 deletions.
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/engine/full_context_label.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::*;
use once_cell::sync::Lazy;
use regex::Regex;

#[derive(thiserror::Error, Debug, PartialEq)]
#[derive(thiserror::Error, Debug)]
pub enum FullContextLabelError {
#[error("label parse error label:{label}")]
LabelParse { label: String },
Expand Down
37 changes: 3 additions & 34 deletions crates/voicevox_core/src/engine/open_jtalk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,6 @@ pub enum OpenJtalkError {
},
}

impl PartialEq for OpenJtalkError {
fn eq(&self, other: &Self) -> bool {
return match (self, other) {
(
Self::Load {
mecab_dict_dir: mecab_dict_dir1,
},
Self::Load {
mecab_dict_dir: mecab_dict_dir2,
},
) => mecab_dict_dir1 == mecab_dict_dir2,
(
Self::ExtractFullContext {
text: text1,
source: source1,
},
Self::ExtractFullContext {
text: text2,
source: source2,
},
) => (text1, by_display(source1)) == (text2, by_display(source2)),
_ => false,
};

fn by_display(source: &Option<anyhow::Error>) -> impl PartialEq {
source.as_ref().map(|e| e.to_string())
}
}
}

pub type Result<T> = std::result::Result<T, OpenJtalkError>;

pub struct OpenJtalk {
Expand Down Expand Up @@ -131,10 +101,9 @@ impl OpenJtalk {
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use test_util::OPEN_JTALK_DIC_DIR;

use crate::*;
use crate::{macros::tests::assert_debug_fmt_eq, *};

fn testdata_hello_hiho() -> Vec<String> {
// こんにちは、ヒホです。の期待値
Expand Down Expand Up @@ -230,7 +199,7 @@ mod tests {
let mut open_jtalk = OpenJtalk::initialize();
open_jtalk.load(OPEN_JTALK_DIC_DIR).unwrap();
let result = open_jtalk.extract_fullcontext(text);
assert_eq!(expected, result);
assert_debug_fmt_eq!(expected, result);
}

#[rstest]
Expand All @@ -243,7 +212,7 @@ mod tests {
open_jtalk.load(OPEN_JTALK_DIC_DIR).unwrap();
for _ in 0..10 {
let result = open_jtalk.extract_fullcontext(text);
assert_eq!(expected, result);
assert_debug_fmt_eq!(expected, result);
}
}
}
6 changes: 3 additions & 3 deletions crates/voicevox_core/src/engine/synthesis_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -647,18 +647,18 @@ mod tests {
use pretty_assertions::assert_eq;
use test_util::OPEN_JTALK_DIC_DIR;

use crate::*;
use crate::{macros::tests::assert_debug_fmt_eq, *};

#[rstest]
fn load_openjtalk_dict_works() {
let core = InferenceCore::new(false, None);
let mut synthesis_engine = SynthesisEngine::new(core, OpenJtalk::initialize());

let result = synthesis_engine.load_openjtalk_dict(OPEN_JTALK_DIC_DIR);
assert_eq!(result, Ok(()));
assert_debug_fmt_eq!(result, Ok(()));

let result = synthesis_engine.load_openjtalk_dict("");
assert_eq!(result, Err(Error::NotLoadedOpenjtalkDict));
assert_debug_fmt_eq!(result, Err(Error::NotLoadedOpenjtalkDict));
}

#[rstest]
Expand Down
44 changes: 0 additions & 44 deletions crates/voicevox_core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,50 +68,6 @@ pub enum Error {
ParseKana(#[from] KanaParseError),
}

impl PartialEq for Error {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::NotLoadedOpenjtalkDict, Self::NotLoadedOpenjtalkDict)
| (Self::GpuSupport, Self::GpuSupport)
| (Self::UninitializedStatus, Self::UninitializedStatus)
| (Self::InferenceFailed, Self::InferenceFailed) => true,
(
Self::LoadModel {
path: path1,
source: source1,
},
Self::LoadModel {
path: path2,
source: source2,
},
) => (path1, source1.to_string()) == (path2, source2.to_string()),
(Self::LoadMetas(e1), Self::LoadMetas(e2))
| (Self::GetSupportedDevices(e1), Self::GetSupportedDevices(e2)) => {
e1.to_string() == e2.to_string()
}
(
Self::InvalidSpeakerId {
speaker_id: speaker_id1,
},
Self::InvalidSpeakerId {
speaker_id: speaker_id2,
},
) => speaker_id1 == speaker_id2,
(
Self::InvalidModelIndex {
model_index: model_index1,
},
Self::InvalidModelIndex {
model_index: model_index2,
},
) => model_index1 == model_index2,
(Self::ExtractFullContextLabel(e1), Self::ExtractFullContextLabel(e2)) => e1 == e2,
(Self::ParseKana(e1), Self::ParseKana(e2)) => e1 == e2,
_ => false,
}
}
}

fn base_error_message(result_code: VoicevoxResultCode) -> &'static str {
let c_message: &'static str = crate::error_result_to_message(result_code);
&c_message[..(c_message.len() - 1)]
Expand Down
1 change: 1 addition & 0 deletions crates/voicevox_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
/// cbindgen:ignore
mod engine;
mod error;
mod macros;
mod numerics;
mod publish;
mod result;
Expand Down
45 changes: 45 additions & 0 deletions crates/voicevox_core/src/macros.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#[cfg(test)]
pub(crate) mod tests {
use std::fmt::{self, Debug};

use pretty_assertions::StrComparison;

/// 2つの`"{:#?}"`が等しいかを、pretty\_assertions風にassertする。
///
/// `io::Error`や`anyhow::Error`を抱えていて`PartialEq`実装が難しい型に使う。
///
/// # Panics
///
/// 2つの`"{:#?}"`が等しくないとき、assertの失敗としてパニックする。
macro_rules! assert_debug_fmt_eq {
($left:expr, $right:expr $(,)?) => {{
crate::macros::tests::__assert_debug_fmt(&$left, &$right, None)
}};
($left:expr, $right:expr, $($arg:tt)*) => {{
crate::macros::tests::__assert_debug_fmt(&$left, &$right, Some(format_args!($($arg)*)))
}};
}
pub(crate) use assert_debug_fmt_eq;

#[track_caller]
pub(crate) fn __assert_debug_fmt<T: Debug>(
left: &T,
right: &T,
args: Option<fmt::Arguments<'_>>,
) {
let (left, right) = (format!("{left:#?}"), format!("{right:#?}"));
if left != right {
panic!(
r#"assertion failed: `("{{left:#?}}" == "{{right:#?}}")`{}
{}
"#,
match args {
Some(args) => format!(": {args}"),
None => "".to_owned(),
},
StrComparison::new(&left, &right),
);
}
}
}
12 changes: 7 additions & 5 deletions crates/voicevox_core/src/publish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ fn list_windows_video_cards() {
#[cfg(test)]
mod tests {
use super::*;
use crate::macros::tests::assert_debug_fmt_eq;
use pretty_assertions::assert_eq;
use test_util::OPEN_JTALK_DIC_DIR;

Expand All @@ -661,7 +662,7 @@ mod tests {
.lock()
.unwrap()
.initialize(InitializeOptions::default());
assert_eq!(Ok(()), result);
assert_debug_fmt_eq!(Ok(()), result);
internal.lock().unwrap().finalize();
assert_eq!(
false,
Expand Down Expand Up @@ -695,7 +696,7 @@ mod tests {
) {
let internal = VoicevoxCore::new_with_mutex();
let result = internal.lock().unwrap().load_model(speaker_id);
assert_eq!(expected_result_at_uninitialized, result);
assert_debug_fmt_eq!(expected_result_at_uninitialized, result);

internal
.lock()
Expand All @@ -706,9 +707,10 @@ mod tests {
})
.unwrap();
let result = internal.lock().unwrap().load_model(speaker_id);
assert_eq!(
expected_result_at_initialized, result,
"got load_model result"
assert_debug_fmt_eq!(
expected_result_at_initialized,
result,
"got load_model result",
);
}

Expand Down
7 changes: 4 additions & 3 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ impl Status {
mod tests {

use super::*;
use crate::macros::tests::assert_debug_fmt_eq;
use pretty_assertions::assert_eq;

#[rstest]
Expand Down Expand Up @@ -391,7 +392,7 @@ mod tests {
fn status_load_metas_works() {
let mut status = Status::new(true, 0);
let result = status.load_metas();
assert_eq!(Ok(()), result);
assert_debug_fmt_eq!(Ok(()), result);
let expected = BTreeSet::from([0, 1, 2, 3]);
assert_eq!(expected, status.supported_styles);
}
Expand All @@ -407,7 +408,7 @@ mod tests {
fn status_load_model_works() {
let mut status = Status::new(false, 0);
let result = status.load_model(0);
assert_eq!(Ok(()), result);
assert_debug_fmt_eq!(Ok(()), result);
assert_eq!(1, status.models.predict_duration.len());
assert_eq!(1, status.models.predict_intonation.len());
assert_eq!(1, status.models.decode.len());
Expand All @@ -422,7 +423,7 @@ mod tests {
"model should not be loaded"
);
let result = status.load_model(model_index);
assert_eq!(Ok(()), result);
assert_debug_fmt_eq!(Ok(()), result);
assert!(
status.is_model_loaded(model_index),
"model should be loaded"
Expand Down

0 comments on commit 44c4582

Please sign in to comment.