From 393e1f04781ea88d4946a3af573fb174cb1d500f Mon Sep 17 00:00:00 2001
From: Louis Beaumont
Date: Mon, 6 Jan 2025 14:27:21 -0800
Subject: [PATCH] fix: unit tests

---
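Notes: prepare_segments now takes the raw sample buffer (&audio_input.data)
plus an extra string argument, and stt returns the transcript String directly
instead of a (transcript, path) tuple, so the output_path / save_text_files
plumbing disappears from the tests and examples. On the vision side, the
ignored/included window lists are folded into a shared WindowFilters value,
get_apple_languages becomes pub so tests can build Apple-native language
lists, and the now-redundant window-filtering example is deleted.

A minimal sketch of the audio call shape the updated tests assume, lifted
from the hunks below. The surrounding setup (whisper model, VAD engine,
embedding state) is elided, and reading "default" as a stream/device key is
an assumption, since this patch never names that parameter:

    let mut segments = prepare_segments(
        &audio_input.data,        // raw samples now, not the whole AudioInput
        vad_engine.clone(),
        &segmentation_model_path,
        embedding_manager,
        embedding_extractor,
        "default",                // new argument introduced by this patch
    )
    .await
    .unwrap();

    while let Some(segment) = segments.recv().await {
        // stt now yields the transcript String directly; the old &output_path
        // argument and `true` save flag are gone along with the tuple return.
        let transcript = stt(
            &segment.samples,
            audio_input.sample_rate,
            &audio_input.device.to_string(),
            &mut whisper_model_guard,
            Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo),
            None,
            vec![Language::English],
        )
        .await
        .unwrap();
    }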
 screenpipe-audio/examples/stt.rs             |  9 +-
 screenpipe-audio/tests/accuracy_test.rs      |  2 -
 screenpipe-audio/tests/core_tests.rs         | 16 ++--
 screenpipe-vision/examples/websocket.rs      | 11 ++-
 .../examples/window-filtering.rs             | 82 -------------------
 screenpipe-vision/src/core.rs                |  2 +-
 screenpipe-vision/tests/apple_vision_test.rs | 30 +++++--
 7 files changed, 38 insertions(+), 114 deletions(-)
 delete mode 100644 screenpipe-vision/examples/window-filtering.rs

diff --git a/screenpipe-audio/examples/stt.rs b/screenpipe-audio/examples/stt.rs
index 849d967b6..3e128675c 100644
--- a/screenpipe-audio/examples/stt.rs
+++ b/screenpipe-audio/examples/stt.rs
@@ -51,7 +51,6 @@ async fn main() {
     ));
     let vad_engine: Arc<Mutex<Box<dyn VadEngine + Send>>> =
         Arc::new(Mutex::new(Box::new(SileroVad::new().await.unwrap())));
-    let output_path = Arc::new(PathBuf::from("test_output"));
 
     let project_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
     let segmentation_model_path = project_dir
@@ -74,7 +73,6 @@ async fn main() {
     for (audio_file, expected_transcription) in test_cases {
         let whisper_model = Arc::clone(&whisper_model);
         let vad_engine = Arc::clone(&vad_engine);
-        let output_path = Arc::clone(&output_path);
         let segmentation_model_path = segmentation_model_path.clone();
         let embedding_extractor = Arc::clone(&embedding_extractor);
         let embedding_manager = embedding_manager.clone();
@@ -92,11 +90,12 @@ async fn main() {
         };
 
         let mut segments = prepare_segments(
-            &audio_input,
+            &audio_input.data,
             vad_engine.clone(),
             &segmentation_model_path,
             embedding_manager,
             embedding_extractor,
+            "default",
         )
         .await
         .unwrap();
@@ -104,15 +103,13 @@ async fn main() {
         let mut transcription = String::new();
 
         while let Some(segment) = segments.recv().await {
-            let (transcript, _) = stt(
+            let transcript = stt(
                 &segment.samples,
                 audio_input.sample_rate,
                 &audio_input.device.to_string(),
                 &mut whisper_model_guard,
                 Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo),
                 None,
-                &output_path,
-                true,
                 vec![Language::English],
             )
             .await

diff --git a/screenpipe-audio/tests/accuracy_test.rs b/screenpipe-audio/tests/accuracy_test.rs
index 6d93c8e0f..5dae1c491 100644
--- a/screenpipe-audio/tests/accuracy_test.rs
+++ b/screenpipe-audio/tests/accuracy_test.rs
@@ -1,4 +1,3 @@
-use candle_transformers::models::metavoice::transformer;
 use candle_transformers::models::whisper;
 use futures::future::join_all;
 use screenpipe_audio::pyannote::embedding::EmbeddingExtractor;
@@ -54,7 +53,6 @@ async fn test_transcription_accuracy() {
     ));
     let vad_engine: Arc<Mutex<Box<dyn VadEngine + Send>>> =
         Arc::new(Mutex::new(Box::new(SileroVad::new().await.unwrap())));
-    let output_path = Arc::new(PathBuf::from("test_output"));
 
     let mut tasks = Vec::new();
 

diff --git a/screenpipe-audio/tests/core_tests.rs b/screenpipe-audio/tests/core_tests.rs
index 89c05a678..06fc7d0ce 100644
--- a/screenpipe-audio/tests/core_tests.rs
+++ b/screenpipe-audio/tests/core_tests.rs
@@ -299,7 +299,6 @@ mod tests {
         let vad_engine: Arc<tokio::sync::Mutex<Box<dyn VadEngine + Send>>> = Arc::new(
             tokio::sync::Mutex::new(Box::new(SileroVad::new().await.unwrap())),
         );
-        let output_path = Arc::new(PathBuf::from("test_output"));
 
         let audio_data = screenpipe_audio::pcm_decode("test_data/Arifi.wav")
             .expect("Failed to decode audio file");
@@ -310,7 +309,6 @@ mod tests {
             device: Arc::new(screenpipe_audio::default_input_device().unwrap()),
         };
 
-        // Create the missing parameters
         let project_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
         let segmentation_model_path = project_dir
@@ -334,11 +332,12 @@ mod tests {
         let embedding_manager = EmbeddingManager::new(usize::MAX);
 
         let mut segments = prepare_segments(
-            &audio_input,
+            &audio_input.data,
             vad_engine.clone(),
             &segmentation_model_path,
             embedding_manager,
             embedding_extractor,
+            "default",
         )
         .await
         .unwrap();
@@ -346,15 +345,13 @@ mod tests {
         let mut transcription_result = String::new();
 
         while let Some(segment) = segments.recv().await {
-            let (transcript, _) = stt(
+            let transcript = stt(
                 &segment.samples,
                 audio_input.sample_rate,
                 &audio_input.device.to_string(),
                 &mut whisper_model_guard,
                 Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo),
                 None,
-                &output_path,
-                true,
                 vec![Language::English],
             )
             .await
@@ -436,26 +433,25 @@ mod tests {
         let start_time = Instant::now();
 
         let mut segments = prepare_segments(
-            &audio_input,
+            &audio_input.data,
             vad_engine.clone(),
             &segmentation_model_path,
             embedding_manager,
             embedding_extractor,
+            "default",
         )
         .await
         .unwrap();
 
         let mut transcription = String::new();
         while let Some(segment) = segments.recv().await {
-            let (transcript, _) = stt(
+            let transcript = stt(
                 &segment.samples,
                 audio_input.sample_rate,
                 &audio_input.device.to_string(),
                 &mut whisper_model,
                 Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo),
                 None,
-                &output_path,
-                true,
                 vec![Language::English],
             )
             .await

diff --git a/screenpipe-vision/examples/websocket.rs b/screenpipe-vision/examples/websocket.rs
index 6bab23b35..06f6ed1fc 100644
--- a/screenpipe-vision/examples/websocket.rs
+++ b/screenpipe-vision/examples/websocket.rs
@@ -3,6 +3,7 @@ use base64::{engine::general_purpose, Engine as _};
 use clap::Parser;
 use futures_util::{SinkExt, StreamExt};
 use image::ImageEncoder;
+use screenpipe_vision::capture_screenshot_by_window::WindowFilters;
 use screenpipe_vision::{
     continuous_capture, monitor::get_default_monitor, CaptureResult, OcrEngine,
 };
@@ -76,17 +77,19 @@ async fn main() -> Result<()> {
     let (result_tx, result_rx) = channel(512);
 
-    let save_text_files = cli.save_text_files;
     let ws_port = cli.ws_port;
 
     let monitor = get_default_monitor().await;
     let id = monitor.id();
+    let window_filters = Arc::new(WindowFilters::new(
+        &cli.ignored_windows,
+        &cli.included_windows,
+    ));
 
     tokio::spawn(async move {
         continuous_capture(
             result_tx,
             Duration::from_secs_f64(1.0 / cli.fps),
-            save_text_files,
             // if apple use apple otherwise if windows use windows native otherwise use tesseract
             if cfg!(target_os = "macos") {
                 OcrEngine::AppleNative
             } else if cfg!(target_os = "windows") {
                 OcrEngine::WindowsNative
             } else {
                 OcrEngine::Tesseract
             },
             id,
-            &cli.ignored_windows,
-            &cli.included_windows,
+            window_filters,
             vec![],
+            false,
         )
         .await
     });
diff --git a/screenpipe-vision/examples/window-filtering.rs b/screenpipe-vision/examples/window-filtering.rs
deleted file mode 100644
index ea35da282..000000000
--- a/screenpipe-vision/examples/window-filtering.rs
+++ /dev/null
@@ -1,82 +0,0 @@
-use anyhow::Result;
-use clap::Parser;
-use screenpipe_vision::{
-    continuous_capture, monitor::get_default_monitor, CaptureResult, OcrEngine,
-};
-use std::time::Duration;
-use tokio::sync::mpsc::channel;
-use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter};
-
-#[derive(Parser)]
-#[command(author, version, about, long_about = None)]
-struct Cli {
-    /// Windows to ignore (can be specified multiple times)
-    #[arg(long)]
-    ignore: Vec<String>,
-    /// Windows to include (can be specified multiple times)
-    #[arg(long)]
-    include: Vec<String>,
-}
-
-#[tokio::main]
-async fn main() -> Result<()> {
-    tracing_subscriber::fmt()
-        .with_env_filter(
-            EnvFilter::from_default_env()
-                .add_directive(tracing::Level::DEBUG.into())
-                .add_directive("tokenizers=error".parse().unwrap()),
-        )
-        .with_span_events(FmtSpan::CLOSE)
-        .init();
-
-    let cli = Cli::parse();
-
-    let (result_tx, mut result_rx) = channel(512);
-
-    let monitor = get_default_monitor().await;
-    let id = monitor.id();
-
-    tokio::spawn(async move {
-        continuous_capture(
-            result_tx,
-            Duration::from_secs(1),
-            false,
-            // if apple use apple otherwise if windows use windows native otherwise use tesseract
-            if cfg!(target_os = "macos") {
-                OcrEngine::AppleNative
-            } else if cfg!(target_os = "windows") {
-                OcrEngine::WindowsNative
-            } else {
-                OcrEngine::Tesseract
-            },
-            id,
-            &cli.ignore,
-            &cli.include,
-            vec![],
-        )
-        .await
-    });
-
-    // Stream OCR results to logs
-    while let Some(result) = result_rx.recv().await {
-        log_capture_result(&result);
-    }
-
-    Ok(())
-}
-
-fn log_capture_result(result: &CaptureResult) {
-    for window in &result.window_ocr_results {
-        tracing::info!(
-            "Window: '{}' (App: '{}', Focused: {})",
-            window.window_name,
-            window.app_name,
-            window.focused
-        );
-        tracing::info!("Text: {}", window.text);
-        tracing::info!("Confidence: {:.2}", window.confidence);
-        tracing::info!("---");
-    }
-    tracing::info!("Timestamp: {:?}", result.timestamp);
-    tracing::info!("=====================================");
-}

diff --git a/screenpipe-vision/src/core.rs b/screenpipe-vision/src/core.rs
index add9b1545..516665881 100644
--- a/screenpipe-vision/src/core.rs
+++ b/screenpipe-vision/src/core.rs
@@ -310,7 +310,7 @@ pub fn trigger_screen_capture_permission() -> Result<()> {
 }
 
 #[cfg(target_os = "macos")]
-fn get_apple_languages(languages: Vec<Language>) -> Vec<String> {
+pub fn get_apple_languages(languages: Vec<Language>) -> Vec<String> {
     let map = APPLE_LANGUAGE_MAP.get_or_init(|| {
         let mut m = HashMap::new();
         m.insert(Language::English, "en-US");

diff --git a/screenpipe-vision/tests/apple_vision_test.rs b/screenpipe-vision/tests/apple_vision_test.rs
index c988b93af..c750b85a5 100644
--- a/screenpipe-vision/tests/apple_vision_test.rs
+++ b/screenpipe-vision/tests/apple_vision_test.rs
@@ -1,8 +1,10 @@
 #[cfg(target_os = "macos")]
 #[cfg(test)]
 mod tests {
+    use cidre::ns;
     use image::GenericImageView;
-    use screenpipe_vision::perform_ocr_apple;
+    use screenpipe_core::Language;
+    use screenpipe_vision::{core::get_apple_languages, perform_ocr_apple};
     use std::path::PathBuf;
 
     #[tokio::test]
@@ -25,13 +27,14 @@ mod tests {
         let rgb_image = image.to_rgb8();
         println!("RGB image dimensions: {:?}", rgb_image.dimensions());
 
-        let result = perform_ocr_apple(&image, vec![]);
+        let (ocr_text, _, _) =
+            perform_ocr_apple(&image, &ns::ArrayMut::<ns::String>::with_capacity(0));
 
-        println!("OCR text: {:?}", result);
+        println!("OCR text: {:?}", ocr_text);
         assert!(
-            result.contains("receiver_count"),
+            ocr_text.contains("receiver_count"),
             "OCR failed: {:?}",
-            result
+            ocr_text
         );
     }
     // # Chinese test
@@ -45,13 +48,22 @@ mod tests {
         let image = image::open(&path).expect("Failed to open Chinese test image");
         println!("Image dimensions: {:?}", image.dimensions());
 
-        let result = perform_ocr_apple(&image, vec![]);
+        let languages_slice = {
+            use ns;
+            let apple_languages = get_apple_languages(vec![Language::Chinese]);
+            let mut slice = ns::ArrayMut::with_capacity(apple_languages.len());
+            apple_languages.iter().for_each(|language| {
+                slice.push(&ns::String::with_str(language.as_str()));
+            });
+            slice
+        };
+        let (ocr_text, _, _) = perform_ocr_apple(&image, &languages_slice);
 
-        println!("OCR text: {:?}", result);
+        println!("OCR text: {:?}", ocr_text);
         assert!(
-            result.contains("管理分支"), // replace with the actual Chinese text in your test image
+            ocr_text.contains("管理分支"),
             "OCR failed to recognize Chinese text: {:?}",
-            result
+            ocr_text
         );
     }
 }
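
To exercise the updated suites, assuming the package names match the crate
directories shown in the diffstat (the Apple-native OCR tests only run on
macOS, per the #[cfg(target_os = "macos")] gate above):

    cargo test -p screenpipe-audio --test core_tests
    cargo test -p screenpipe-audio --test accuracy_test
    cargo test -p screenpipe-vision --test apple_vision_test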