Skip to content

Commit

Permalink
fix: unit tests
Browse files: browse the repository at this point in the history
  • Loading branch information
louis030195 committed Jan 6, 2025
1 parent e147b2e commit 393e1f0
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 114 deletions.
9 changes: 3 additions & 6 deletions screenpipe-audio/examples/stt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ async fn main() {
));
let vad_engine: Arc<Mutex<Box<dyn VadEngine + Send>>> =
Arc::new(Mutex::new(Box::new(SileroVad::new().await.unwrap())));
let output_path = Arc::new(PathBuf::from("test_output"));

let project_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let segmentation_model_path = project_dir
Expand All @@ -74,7 +73,6 @@ async fn main() {
for (audio_file, expected_transcription) in test_cases {
let whisper_model = Arc::clone(&whisper_model);
let vad_engine = Arc::clone(&vad_engine);
let output_path = Arc::clone(&output_path);
let segmentation_model_path = segmentation_model_path.clone();
let embedding_extractor = Arc::clone(&embedding_extractor);
let embedding_manager = embedding_manager.clone();
Expand All @@ -92,27 +90,26 @@ async fn main() {
};

let mut segments = prepare_segments(
&audio_input,
&audio_input.data,
vad_engine.clone(),
&segmentation_model_path,
embedding_manager,
embedding_extractor,
"default",
)
.await
.unwrap();
let mut whisper_model_guard = whisper_model.lock().await;

let mut transcription = String::new();
while let Some(segment) = segments.recv().await {
let (transcript, _) = stt(
let transcript = stt(
&segment.samples,
audio_input.sample_rate,
&audio_input.device.to_string(),
&mut whisper_model_guard,
Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo),
None,
&output_path,
true,
vec![Language::English],
)
.await
Expand Down
2 changes: 0 additions & 2 deletions screenpipe-audio/tests/accuracy_test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use candle_transformers::models::metavoice::transformer;
use candle_transformers::models::whisper;
use futures::future::join_all;
use screenpipe_audio::pyannote::embedding::EmbeddingExtractor;
Expand Down Expand Up @@ -54,7 +53,6 @@ async fn test_transcription_accuracy() {
));
let vad_engine: Arc<Mutex<Box<dyn VadEngine + Send>>> =
Arc::new(Mutex::new(Box::new(SileroVad::new().await.unwrap())));
let output_path = Arc::new(PathBuf::from("test_output"));

let mut tasks = Vec::new();

Expand Down
16 changes: 6 additions & 10 deletions screenpipe-audio/tests/core_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ mod tests {
let vad_engine: Arc<tokio::sync::Mutex<Box<dyn VadEngine + Send>>> = Arc::new(
tokio::sync::Mutex::new(Box::new(SileroVad::new().await.unwrap())),
);
let output_path = Arc::new(PathBuf::from("test_output"));
let audio_data = screenpipe_audio::pcm_decode("test_data/Arifi.wav")
.expect("Failed to decode audio file");

Expand All @@ -310,7 +309,6 @@ mod tests {
device: Arc::new(screenpipe_audio::default_input_device().unwrap()),
};


// Create the missing parameters
let project_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let segmentation_model_path = project_dir
Expand All @@ -334,27 +332,26 @@ mod tests {
let embedding_manager = EmbeddingManager::new(usize::MAX);

let mut segments = prepare_segments(
&audio_input,
&audio_input.data,
vad_engine.clone(),
&segmentation_model_path,
embedding_manager,
embedding_extractor,
"default",
)
.await
.unwrap();
let mut whisper_model_guard = whisper_model.lock().await;

let mut transcription_result = String::new();
while let Some(segment) = segments.recv().await {
let (transcript, _) = stt(
let transcript = stt(
&segment.samples,
audio_input.sample_rate,
&audio_input.device.to_string(),
&mut whisper_model_guard,
Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo),
None,
&output_path,
true,
vec![Language::English],
)
.await
Expand Down Expand Up @@ -436,26 +433,25 @@ mod tests {
let start_time = Instant::now();

let mut segments = prepare_segments(
&audio_input,
&audio_input.data,
vad_engine.clone(),
&segmentation_model_path,
embedding_manager,
embedding_extractor,
"default",
)
.await
.unwrap();

let mut transcription = String::new();
while let Some(segment) = segments.recv().await {
let (transcript, _) = stt(
let transcript = stt(
&segment.samples,
audio_input.sample_rate,
&audio_input.device.to_string(),
&mut whisper_model,
Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo),
None,
&output_path,
true,
vec![Language::English],
)
.await
Expand Down
11 changes: 7 additions & 4 deletions screenpipe-vision/examples/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use base64::{engine::general_purpose, Engine as _};
use clap::Parser;
use futures_util::{SinkExt, StreamExt};
use image::ImageEncoder;
use screenpipe_vision::capture_screenshot_by_window::WindowFilters;
use screenpipe_vision::{
continuous_capture, monitor::get_default_monitor, CaptureResult, OcrEngine,
};
Expand Down Expand Up @@ -76,17 +77,19 @@ async fn main() -> Result<()> {

let (result_tx, result_rx) = channel(512);

let save_text_files = cli.save_text_files;
let ws_port = cli.ws_port;

let monitor = get_default_monitor().await;
let id = monitor.id();
let window_filters = Arc::new(WindowFilters::new(
&cli.ignored_windows,
&cli.included_windows,
));

tokio::spawn(async move {
continuous_capture(
result_tx,
Duration::from_secs_f64(1.0 / cli.fps),
save_text_files,
// if apple use apple otherwise if windows use windows native otherwise use tesseract
if cfg!(target_os = "macos") {
OcrEngine::AppleNative
Expand All @@ -96,9 +99,9 @@ async fn main() -> Result<()> {
OcrEngine::Tesseract
},
id,
&cli.ignored_windows,
&cli.included_windows,
window_filters,
vec![],
false,
)
.await
});
Expand Down
82 changes: 0 additions & 82 deletions screenpipe-vision/examples/window-filtering.rs

This file was deleted.

2 changes: 1 addition & 1 deletion screenpipe-vision/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ pub fn trigger_screen_capture_permission() -> Result<()> {
}

#[cfg(target_os = "macos")]
fn get_apple_languages(languages: Vec<screenpipe_core::Language>) -> Vec<String> {
pub fn get_apple_languages(languages: Vec<screenpipe_core::Language>) -> Vec<String> {
let map = APPLE_LANGUAGE_MAP.get_or_init(|| {
let mut m = HashMap::new();
m.insert(Language::English, "en-US");
Expand Down
30 changes: 21 additions & 9 deletions screenpipe-vision/tests/apple_vision_test.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#[cfg(target_os = "macos")]
#[cfg(test)]
mod tests {
use cidre::ns;
use image::GenericImageView;
use screenpipe_vision::perform_ocr_apple;
use screenpipe_core::Language;
use screenpipe_vision::{core::get_apple_languages, perform_ocr_apple};
use std::path::PathBuf;

#[tokio::test]
Expand All @@ -25,13 +27,14 @@ mod tests {
let rgb_image = image.to_rgb8();
println!("RGB image dimensions: {:?}", rgb_image.dimensions());

let result = perform_ocr_apple(&image, vec![]);
let (ocr_text, _, _) =
perform_ocr_apple(&image, &ns::ArrayMut::<ns::String>::with_capacity(0));

println!("OCR text: {:?}", result);
println!("OCR text: {:?}", ocr_text);
assert!(
result.contains("receiver_count"),
ocr_text.contains("receiver_count"),
"OCR failed: {:?}",
result
ocr_text
);
}
// # 中文测试
Expand All @@ -45,13 +48,22 @@ mod tests {
let image = image::open(&path).expect("Failed to open Chinese test image");
println!("Image dimensions: {:?}", image.dimensions());

let result = perform_ocr_apple(&image, vec![]);
let languages_slice = {
use ns;
let apple_languages = get_apple_languages(vec![Language::Chinese]);
let mut slice = ns::ArrayMut::<ns::String>::with_capacity(apple_languages.len());
apple_languages.iter().for_each(|language| {
slice.push(&ns::String::with_str(language.as_str()));
});
slice
};
let (ocr_text, _, _) = perform_ocr_apple(&image, &languages_slice);

println!("OCR text: {:?}", result);
println!("OCR text: {:?}", ocr_text);
assert!(
result.contains("管理分支"), // 替换为您的测试图像中的实际中文文本
ocr_text.contains("管理分支"),
"OCR failed to recognize Chinese text: {:?}",
result
ocr_text
);
}
}

0 comments on commit 393e1f0

Please sign in to comment.