From 07b53fc021a0dbca79314154460017f10718904e Mon Sep 17 00:00:00 2001 From: David Cermak Date: Fri, 13 Dec 2024 17:32:57 +0100 Subject: [PATCH] fix(mdns): Querier with cond-var WIP --- components/mdns/examples/usage.rs | 5 +- components/mdns/src/lib.rs | 266 +++++++++++--------------- components/mdns/src/querier.rs | 148 ++++++++++++++ components/mdns/src/service/native.rs | 9 +- 4 files changed, 266 insertions(+), 162 deletions(-) create mode 100644 components/mdns/src/querier.rs diff --git a/components/mdns/examples/usage.rs b/components/mdns/examples/usage.rs index 49fa9e9775..366daed45c 100644 --- a/components/mdns/examples/usage.rs +++ b/components/mdns/examples/usage.rs @@ -1,16 +1,17 @@ // examples/basic_usage.rs +// use std::process::Termination; use mdns::*; use std::thread; use std::time::Duration; - +// use libc::__c_anonymous_xsk_tx_metadata_union; fn main() { // Initialize mDNS mdns_init(); mdns_query("david-work.local"); - thread::sleep(Duration::from_millis(500)); + thread::sleep(Duration::from_millis(1500)); // Deinitialize mDNS mdns_deinit(); } diff --git a/components/mdns/src/lib.rs b/components/mdns/src/lib.rs index 4b8c5e6e2e..67c1cdc8f7 100644 --- a/components/mdns/src/lib.rs +++ b/components/mdns/src/lib.rs @@ -1,180 +1,93 @@ -// src/lib.rs mod service; +mod querier; + use service::{Service, NativeService, CService}; +use querier::Querier; + use lazy_static::lazy_static; use std::sync::{Arc, Mutex}; use dns_parser::{Builder, QueryClass, QueryType, Packet}; -use std::time::{Duration, Instant}; - +use std::time::Duration; -#[cfg(not(feature = "ffi"))] -fn build_info() { - println!("Default build"); +lazy_static! { + static ref SERVER: Arc> = Arc::new(Mutex::new(Objects { + service: None, + querier: None, + })); } -#[cfg(not(feature = "ffi"))] -fn create_service(cb: fn(&[u8])) -> Box { - NativeService::init(cb) +struct Objects { + service: Option>, + querier: Option, } -#[cfg(feature = "ffi")] fn build_info() { - println!("FFI build"); + #[cfg(not(feature = "ffi"))] + { + println!("Default build"); + } + #[cfg(feature = "ffi")] + { + println!("FFI build"); + } } -#[cfg(feature = "ffi")] fn create_service(cb: fn(&[u8])) -> Box { - CService::init(cb) + #[cfg(not(feature = "ffi"))] + { + NativeService::init(cb) + } + #[cfg(feature = "ffi")] + { + CService::init(cb) + } +} + +fn read_cb(vec: &[u8]) { + if vec.len() == 0 { + let mut service_guard = SERVER.lock().unwrap(); + if let Some(querier) = &mut service_guard.querier { + println!("querier process {:?}", vec); + let packet = querier.process(); + if packet.is_some() { + if let Some(service) = &service_guard.service { + service.send(packet.unwrap()); + } + } + } + } else { + println!("Received {:?}", vec); + let mut service_guard = SERVER.lock().unwrap(); + if let Some(querier) = &mut service_guard.querier { + querier.parse(&vec).expect("Failed to parse.."); + } + // parse_dns_response(vec).unwrap(); + } } fn parse_dns_response(data: &[u8]) -> Result<(), String> { println!("Parsing DNS response with length 2 : {}", data.len()); - let packet = Packet::parse(&data).unwrap(); + let packet = Packet::parse(data).unwrap(); for answer in packet.answers { + println!("ANSWER:"); println!("{:?}", answer); } for question in packet.questions { + println!("QUESTION:"); println!("{:?}", question); } Ok(()) } - -// pub trait Querier: Send + Sync { -// fn init() -> Box -// where -// Self: Sized; - -// fn deinit(self: Box); -// } - - -#[derive(Debug)] -pub struct Query { - pub name: String, - pub service: String, - pub proto: String, - pub query_type: QueryType, - pub unicast: bool, - pub timeout: Duration, - pub added_at: Instant, // To track when the query was added -} - -pub struct Querier { - queries: Vec, -} - -fn create_querier() -> Box { - NativeService::init(cb) -} - -impl Querier { - pub fn new() -> Self { - Self { - queries: Vec::new(), - } - } - - pub fn init(&mut self) { - println!("Querier initialized"); - } - - pub fn deinit(&mut self) { - self.queries.clear(); - println!("Querier deinitialized"); - } - - pub fn add( - &mut self, - name: String, - service: String, - proto: String, - query_type: QueryType, - unicast: bool, - timeout: Duration, - // semaphore: Option, - ) -> usize { - let query = Query { - name, - service, - proto, - query_type, - unicast, - timeout, - // semaphore, - added_at: Instant::now(), - }; - self.queries.push(query); - self.queries.len() - 1 // Return the ID (index of the query) - } - - pub fn process(&mut self) { - let now = Instant::now(); - self.queries.retain(|query| { - let elapsed = now.duration_since(query.added_at); - if elapsed > query.timeout { - println!("Query timed out: {:?}", query); - // if let Some(semaphore) = &query.semaphore { - // semaphore.add_permits(1); // Release semaphore if waiting - // } - false // Remove the query - } else { - println!("Processing query: {:?}", query); - // Implement retry logic here if needed - true // Keep the query - } - }); - } - - pub async fn wait(&mut self, id: usize) -> Result<(), &'static str> { - if let Some(query) = self.queries.get_mut(id) { - Ok(()) - // if let Some(semaphore) = &query.semaphore { - // semaphore.acquire().await.unwrap(); // Block until the semaphore is released - // Ok(()) - // } else { - // Err("No semaphore set for this query") - // } - } else { - Err("Invalid query ID") - } - } -} - -struct Objects { - service: Option>, - // responder: Option>, - querier: Option, -} - -// lazy_static! { -// static ref SERVICE: Arc>>> = Arc::new(Mutex::new(None)); -// } - -lazy_static! { - static ref SERVER: Arc> = Arc::new(Mutex::new(Objects { - service: None, - querier: None, - })); -} - -fn read_cb(vec: &[u8]) { - println!("Received {:?}", vec); - parse_dns_response(vec).unwrap(); -} - pub fn mdns_init() { build_info(); let mut service_guard = SERVER.lock().unwrap(); if service_guard.service.is_none() { - // Initialize the service only if it hasn't been initialized service_guard.service = Some(create_service(read_cb)); } if service_guard.querier.is_none() { - // Initialize the service only if it hasn't been initialized - service_guard.querier = Some(init()); + service_guard.querier = Some(Querier::new()); } - println!("mdns_init called"); } @@ -186,26 +99,63 @@ pub fn mdns_deinit() { println!("mdns_deinit called"); } -fn create_a_query(name: &str) -> Vec { - let query_type = QueryType::A; // Type A query for IPv4 address - let query_class = QueryClass::IN; // Class IN (Internet) - - // Create a new query with ID and recursion setting - let mut builder = Builder::new_query(0x5555, true); +/* +pub fn mdns_query(name: &str) { + let mut service_guard = SERVER.lock().unwrap(); - // Add the question for "david-work.local" - builder.add_question(name, false, query_type, query_class); + if let Some(querier) = &mut service_guard.querier { + let timeout = Duration::from_secs(5); + let query_id = querier.add( + name.to_string(), + "".to_string(), + "_http._tcp".to_string(), + QueryType::A, + false, + timeout, + ); + querier.wait(query_id).await.unwrap(); + println!("Query added with ID: {}", query_id); + } - // Build and return the query packet - builder.build().unwrap_or_else(|x| x) } + */ + pub fn mdns_query(name: &str) { - let service_guard = SERVER.lock().unwrap(); - if let Some(service) = &service_guard.service { - let packet = create_a_query(name); - service.send(packet); - } else { - println!("Service not initialized"); + // Lock the server to access the querier + let query_id; + let querier_cvar; + + { + let mut service_guard = SERVER.lock().unwrap(); + if let Some(querier) = &mut service_guard.querier { + let timeout = Duration::from_secs(1); + query_id = querier.add( + name.to_string(), + "".to_string(), + "_http._tcp".to_string(), + QueryType::A, + false, + timeout, + ); + querier_cvar = querier.completed_queries.clone(); // Clone the Arc pair + } else { + println!("No querier available"); + return; + } + } // Release the SERVER lock here + + // Wait for the query to complete + let (lock, cvar) = &*querier_cvar; + let mut completed = lock.lock().unwrap(); + while !completed.get(&query_id).copied().unwrap_or(false) { + let result = cvar.wait_timeout(completed, Duration::from_secs(5)).unwrap(); + completed = result.0; // Update the lock guard + if result.1.timed_out() { + println!("Query timed out: ID {}", query_id); + return; + } } + + println!("Query completed!!! ID {}", query_id); } diff --git a/components/mdns/src/querier.rs b/components/mdns/src/querier.rs new file mode 100644 index 0000000000..6ff3e769a0 --- /dev/null +++ b/components/mdns/src/querier.rs @@ -0,0 +1,148 @@ +use dns_parser::{Builder, QueryClass, QueryType, Packet}; +use std::sync::{Arc, Condvar, Mutex}; +use std::collections::HashMap; +use std::time::{Duration, Instant}; + +#[derive(Debug)] +pub struct Query { + pub name: String, + pub service: String, + pub proto: String, + pub query_type: QueryType, + pub unicast: bool, + pub timeout: Duration, + pub added_at: Instant, + pub packet: Vec, + pub id: usize, +} + +pub struct Querier { + queries: Vec, + pub(crate) completed_queries: Arc<(Mutex>, Condvar)>, // Shared state for query completion +} + +impl Querier { + pub fn new() -> Self { + Self { + queries: Vec::new(), + completed_queries: Arc::new((Mutex::new(HashMap::new()), Condvar::new())), + } + } + + pub fn add( + &mut self, + name: String, + service: String, + proto: String, + query_type: QueryType, + unicast: bool, + timeout: Duration, + ) -> usize { + let id = self.queries.len(); + let query = Query { + name: name.clone(), + service, + proto, + query_type, + unicast, + timeout, + added_at: Instant::now(), + packet: create_a_query(&name), + id: id.clone() + }; + self.queries.push(query); + self.completed_queries.0.lock().unwrap().insert(id, false); // Mark as incomplete + id + } + pub fn parse(&mut self, data: &[u8]) -> Result<(), String> { + println!("Parsing DNS response with length 2 : {}", data.len()); + let packet = Packet::parse(data).unwrap(); + for answer in packet.answers { + println!("ANSWER:"); + println!("{:?}", answer); + let name = answer.name.to_string(); + let mut completed_queries = vec![]; + self.queries.retain(|query| { + if query.name == name { + println!("ANSWER: {:?}", answer.data); + completed_queries.push(query.id); // Track for completion + false + } + else { true } + }); + let (lock, cvar) = &*self.completed_queries; + let mut completed = lock.lock().unwrap(); + for query_id in completed_queries { + if let Some(entry) = completed.get_mut(&query_id) { + *entry = true; + } + } + cvar.notify_all(); + + } + for question in packet.questions { + println!("{:?}", question); + } + Ok(()) + + } + + pub fn process(&mut self) -> Option> { + let now = Instant::now(); + let mut packet_to_send: Option> = None; + + // Collect IDs of timed-out queries to mark them as complete + let mut timed_out_queries = vec![]; + self.queries.retain(|query| { + let elapsed = now.duration_since(query.added_at); + if elapsed > query.timeout { + timed_out_queries.push(query.id); // Track for completion + false // Remove the query + } else { + packet_to_send = Some(query.packet.clone()); + true // Keep the query + } + }); + + // Mark timed-out queries as complete and notify waiting threads + let (lock, cvar) = &*self.completed_queries; + let mut completed = lock.lock().unwrap(); + for query_id in timed_out_queries { + if let Some(entry) = completed.get_mut(&query_id) { + *entry = true; + } + } + cvar.notify_all(); + println!("Processing... query"); + + packet_to_send + } + pub fn wait(&self, id: usize) -> Result<(), &'static str> { + let (lock, cvar) = &*self.completed_queries; + + // Wait until the query is marked as complete or timeout expires + let mut completed = lock.lock().unwrap(); + while !completed.get(&id).copied().unwrap_or(false) { + completed = cvar.wait(completed).unwrap(); + } + Ok(()) + } + + fn mark_query_as_complete(&self, query: &Query) { + let (lock, cvar) = &*self.completed_queries; + let mut completed = lock.lock().unwrap(); + if let Some(entry) = completed.get_mut(&(self.queries.len() - 1)) { + *entry = true; + } + cvar.notify_all(); + } +} + +fn create_a_query(name: &str) -> Vec { + let query_type = QueryType::A; // Type A query for IPv4 address + let query_class = QueryClass::IN; // Class IN (Internet) + + let mut builder = Builder::new_query(0, true); + builder.add_question(name, false, query_type, query_class); + builder.build().unwrap_or_else(|x| x) +} diff --git a/components/mdns/src/service/native.rs b/components/mdns/src/service/native.rs index da530836c4..458e024911 100644 --- a/components/mdns/src/service/native.rs +++ b/components/mdns/src/service/native.rs @@ -7,6 +7,7 @@ use nix::unistd::{pipe, read, write, close}; use nix::sys::select::{select, FdSet}; use nix::sys::time::TimeVal; use std::os::fd::AsRawFd; +use std::ptr::null; enum Action { Action1, @@ -53,10 +54,14 @@ impl NativeService read_fds.insert(read_fd); - let mut timeout = TimeVal::new(0, 100_000); + let mut timeout = TimeVal::new(0, 500_000); match select(read_fd.max(socket_fd) + 1, Some(&mut read_fds), None, None, Some(&mut timeout)) { - Ok(0) => println!("ThreadHousekeeper: Performing housekeeping tasks"), + Ok(0) => { + println!("ThreadHousekeeper: Performing housekeeping tasks"); + let buf = vec![]; + local_cb(&buf); + } Ok(_) => { if read_fds.contains(socket_fd) { // let mut buf: [MaybeUninit; 1500] = unsafe { MaybeUninit::uninit().assume_init() };