diff --git a/scripts/run.sh b/scripts/run.sh index 05cb712..8fde642 100644 --- a/scripts/run.sh +++ b/scripts/run.sh @@ -25,10 +25,10 @@ directories=$(find "$script_dir/../res/" -mindepth 1 -type d -printf "%P\n") for dir in $directories; do dir="../res/${dir}/" echo "Running Lazy eval with subsampling using ${dir}" - ../target/release/tbt-segmentation -l -c -u -s ../specification/shiplanding_formula_combined.tbt -f $dir> "${dir}/subsamplingAndLazy_result.txt" + ../target/release/tbt-segmentation -l -c -u -s ../specification/shiplanding_formula_combined.tbt --toml "${dir}/subsamplingAndLazy_result.toml" -f $dir> "${dir}/subsamplingAndLazy_result.txt" python3 infer_parameters_visualization.py "${dir}/subsamplingAndLazy_result.txt" echo "Running eval with subsampling using ${dir}" - ../target/release/tbt-segmentation -c -u -s ../specification/shiplanding_formula_combined.tbt -f $dir> "${dir}/subsampling_result.txt" + ../target/release/tbt-segmentation -c -u -s ../specification/shiplanding_formula_combined.tbt --toml "${dir}/subsampling_result.toml" -f $dir> "${dir}/subsampling_result.txt" python3 infer_parameters_visualization.py "${dir}/subsampling_result.txt" done diff --git a/src/lib.rs b/src/lib.rs index 0a5ad3a..d0d874c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -325,8 +325,14 @@ pub fn get_alternative_segmentation( pub fn generate_toml_output_file( location: String, + number_skipped_entries: usize, best_segmentation: ClonedSegmentation, alternative_segmentation: Option>, ) -> std::io::Result<()> { - generate_toml(location, best_segmentation, alternative_segmentation) + generate_toml( + location, + number_skipped_entries, + best_segmentation, + alternative_segmentation, + ) } diff --git a/src/main.rs b/src/main.rs index 21757b6..690937a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -58,6 +58,7 @@ fn main() { println!("Generating toml file: {}", toml_output_location); match generate_toml_output_file( toml_output_location, + number_skipped_entries, best_segmentation, alternative_segmentations, ) { diff --git a/src/toml_out.rs b/src/toml_out.rs index b914dd6..2afb681 100644 --- a/src/toml_out.rs +++ b/src/toml_out.rs @@ -26,6 +26,7 @@ struct NamedSegmentation(String, TomlSegmentation); #[derive(Serialize)] struct TomlSegmentation { + delta: usize, robustness: f32, segments: Vec, } @@ -37,16 +38,25 @@ struct Segmentations { pub fn generate_toml( location: String, + delta: usize, best_segmentation: ClonedSegmentation, alternative_segmentation: Option>, ) -> std::io::Result<()> { let mut segmentations = Vec::::new(); // Adding best segmentation - segmentations.push(read_segmentation(best_segmentation, "best".to_string())); + segmentations.push(read_segmentation( + best_segmentation, + delta, + "best".to_string(), + )); // Adding the alternative segmentations if let Some(alternatives) = alternative_segmentation { alternatives.into_iter().enumerate().for_each(|(i, seg)| { - segmentations.push(read_segmentation(seg, format!("alternative_{}", i + 1))) + segmentations.push(read_segmentation( + seg, + delta, + format!("alternative_{}", i + 1), + )) }) } let all_segmentations = Segmentations { segmentations }; @@ -59,6 +69,7 @@ pub fn generate_toml( fn read_segmentation( segmentation: (f32, Vec<(crate::behaviortree::TbtNode, usize, usize, f32)>), + delta: usize, name: String, ) -> NamedSegmentation { let mut segments = Vec::::new(); @@ -78,6 +89,7 @@ fn read_segmentation( NamedSegmentation( name, TomlSegmentation { + delta, robustness: segmentation.0, segments, }, diff --git a/tests/tests.rs b/tests/tests.rs index 7451b61..6e9de78 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -43,7 +43,7 @@ fn run_test( ) -> Result<(), String> { for (trace, expected) in traces_with_expected_value { let trace: Trace = (trace.len(), HashMap::from([(signal_name.clone(), trace)])); - let robustness = evaluate( + let ((robustness, _), _) = evaluate( tbt.clone(), trace, SystemTime::now(), @@ -54,10 +54,10 @@ fn run_test( None, false, ); - if robustness.0 == expected { + if robustness == expected { continue; } else { - return Err(format!("Expected {expected} but was {}.", robustness.0)); + return Err(format!("Expected {expected} but was {}.", robustness)); } } Ok(()) @@ -87,7 +87,7 @@ fn test_hardcoded_maneuver(specification: &str, logfile: &str) { let setting_2 = setting.clone(); let start = SystemTime::now(); println!("Run using parsed tbt!"); - let first_run = evaluate( + let (first_run, first_run_alt) = evaluate( new_tbt, trace, start, @@ -99,7 +99,7 @@ fn test_hardcoded_maneuver(specification: &str, logfile: &str) { false, ); println!("Run using hand-coded tbt!"); - let second_run = evaluate( + let (second_run, second_run_alt) = evaluate( old_tbt, trace_2, start, @@ -140,12 +140,15 @@ fn test_hardcoded_maneuver(specification: &str, logfile: &str) { v2.3 ); } - let alternatives_new = first_run.2.unwrap(); - let alternatives_old = second_run.2.unwrap(); + let alternatives_new = first_run_alt.unwrap(); + let alternatives_old = second_run_alt.unwrap(); for j in 0..alternatives_new.len() { - for i in 0..first_run.1.len() { - let v1 = &alternatives_new[j][i]; - let v2 = &alternatives_old[j][i]; + let alt_new = &alternatives_new[j]; + let alt_old = &alternatives_old[j]; + assert_eq!(alt_new.0, alt_old.0); + for i in 0..alt_old.1.len() { + let v1 = &alt_new.1[i]; + let v2 = &alt_old.1[i]; assert!( v1.1 == v2.1, "Entry of alternatives i={i} j={j} is expected to have same lower but wasnt {} != {}",