Skip to content

Commit

Permalink
fix(mdns): Querier with cond-var WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cermak committed Dec 13, 2024
1 parent 1805103 commit 07b53fc
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 162 deletions.
5 changes: 3 additions & 2 deletions components/mdns/examples/usage.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
// examples/basic_usage.rs

// use std::process::Termination;
use mdns::*;
use std::thread;
use std::time::Duration;

// use libc::__c_anonymous_xsk_tx_metadata_union;

fn main() {
// Initialize mDNS
mdns_init();

mdns_query("david-work.local");
thread::sleep(Duration::from_millis(500));
thread::sleep(Duration::from_millis(1500));
// Deinitialize mDNS
mdns_deinit();
}
266 changes: 108 additions & 158 deletions components/mdns/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,180 +1,93 @@
// src/lib.rs
mod service;
mod querier;

use service::{Service, NativeService, CService};
use querier::Querier;

use lazy_static::lazy_static;
use std::sync::{Arc, Mutex};
use dns_parser::{Builder, QueryClass, QueryType, Packet};
use std::time::{Duration, Instant};

use std::time::Duration;

#[cfg(not(feature = "ffi"))]
fn build_info() {
println!("Default build");
lazy_static! {
static ref SERVER: Arc<Mutex<Objects>> = Arc::new(Mutex::new(Objects {
service: None,
querier: None,
}));
}

#[cfg(not(feature = "ffi"))]
fn create_service(cb: fn(&[u8])) -> Box<dyn Service> {
NativeService::init(cb)
struct Objects {
service: Option<Box<dyn Service>>,
querier: Option<Querier>,
}

#[cfg(feature = "ffi")]
fn build_info() {
println!("FFI build");
#[cfg(not(feature = "ffi"))]
{
println!("Default build");
}
#[cfg(feature = "ffi")]
{
println!("FFI build");
}
}

#[cfg(feature = "ffi")]
fn create_service(cb: fn(&[u8])) -> Box<dyn Service> {
CService::init(cb)
#[cfg(not(feature = "ffi"))]
{
NativeService::init(cb)
}
#[cfg(feature = "ffi")]
{
CService::init(cb)
}
}

fn read_cb(vec: &[u8]) {
if vec.len() == 0 {
let mut service_guard = SERVER.lock().unwrap();
if let Some(querier) = &mut service_guard.querier {
println!("querier process {:?}", vec);
let packet = querier.process();
if packet.is_some() {
if let Some(service) = &service_guard.service {
service.send(packet.unwrap());
}
}
}
} else {
println!("Received {:?}", vec);
let mut service_guard = SERVER.lock().unwrap();
if let Some(querier) = &mut service_guard.querier {
querier.parse(&vec).expect("Failed to parse..");
}
// parse_dns_response(vec).unwrap();
}
}

fn parse_dns_response(data: &[u8]) -> Result<(), String> {
println!("Parsing DNS response with length 2 : {}", data.len());
let packet = Packet::parse(&data).unwrap();
let packet = Packet::parse(data).unwrap();
for answer in packet.answers {
println!("ANSWER:");
println!("{:?}", answer);
}
for question in packet.questions {
println!("QUESTION:");
println!("{:?}", question);
}
Ok(())
}


// pub trait Querier: Send + Sync {
// fn init() -> Box<Self>
// where
// Self: Sized;

// fn deinit(self: Box<Self>);
// }


#[derive(Debug)]
pub struct Query {
pub name: String,
pub service: String,
pub proto: String,
pub query_type: QueryType,
pub unicast: bool,
pub timeout: Duration,
pub added_at: Instant, // To track when the query was added
}

pub struct Querier {
queries: Vec<Query>,
}

fn create_querier() -> Box<dyn Querier> {
NativeService::init(cb)
}

impl Querier {
pub fn new() -> Self {
Self {
queries: Vec::new(),
}
}

pub fn init(&mut self) {
println!("Querier initialized");
}

pub fn deinit(&mut self) {
self.queries.clear();
println!("Querier deinitialized");
}

pub fn add(
&mut self,
name: String,
service: String,
proto: String,
query_type: QueryType,
unicast: bool,
timeout: Duration,
// semaphore: Option<tokio::sync::Semaphore>,
) -> usize {
let query = Query {
name,
service,
proto,
query_type,
unicast,
timeout,
// semaphore,
added_at: Instant::now(),
};
self.queries.push(query);
self.queries.len() - 1 // Return the ID (index of the query)
}

pub fn process(&mut self) {
let now = Instant::now();
self.queries.retain(|query| {
let elapsed = now.duration_since(query.added_at);
if elapsed > query.timeout {
println!("Query timed out: {:?}", query);
// if let Some(semaphore) = &query.semaphore {
// semaphore.add_permits(1); // Release semaphore if waiting
// }
false // Remove the query
} else {
println!("Processing query: {:?}", query);
// Implement retry logic here if needed
true // Keep the query
}
});
}

pub async fn wait(&mut self, id: usize) -> Result<(), &'static str> {
if let Some(query) = self.queries.get_mut(id) {
Ok(())
// if let Some(semaphore) = &query.semaphore {
// semaphore.acquire().await.unwrap(); // Block until the semaphore is released
// Ok(())
// } else {
// Err("No semaphore set for this query")
// }
} else {
Err("Invalid query ID")
}
}
}

struct Objects {
service: Option<Box<dyn Service>>,
// responder: Option<Box<dyn Responder>>,
querier: Option<Querier>,
}

// lazy_static! {
// static ref SERVICE: Arc<Mutex<Option<Box<dyn Service>>>> = Arc::new(Mutex::new(None));
// }

lazy_static! {
static ref SERVER: Arc<Mutex<Objects>> = Arc::new(Mutex::new(Objects {
service: None,
querier: None,
}));
}

fn read_cb(vec: &[u8]) {
println!("Received {:?}", vec);
parse_dns_response(vec).unwrap();
}

pub fn mdns_init() {
build_info();
let mut service_guard = SERVER.lock().unwrap();
if service_guard.service.is_none() {
// Initialize the service only if it hasn't been initialized
service_guard.service = Some(create_service(read_cb));
}
if service_guard.querier.is_none() {
// Initialize the service only if it hasn't been initialized
service_guard.querier = Some(init());
service_guard.querier = Some(Querier::new());
}

println!("mdns_init called");
}

Expand All @@ -186,26 +99,63 @@ pub fn mdns_deinit() {
println!("mdns_deinit called");
}

fn create_a_query(name: &str) -> Vec<u8> {
let query_type = QueryType::A; // Type A query for IPv4 address
let query_class = QueryClass::IN; // Class IN (Internet)

// Create a new query with ID and recursion setting
let mut builder = Builder::new_query(0x5555, true);
/*
pub fn mdns_query(name: &str) {
let mut service_guard = SERVER.lock().unwrap();
// Add the question for "david-work.local"
builder.add_question(name, false, query_type, query_class);
if let Some(querier) = &mut service_guard.querier {
let timeout = Duration::from_secs(5);
let query_id = querier.add(
name.to_string(),
"".to_string(),
"_http._tcp".to_string(),
QueryType::A,
false,
timeout,
);
querier.wait(query_id).await.unwrap();
println!("Query added with ID: {}", query_id);
}
// Build and return the query packet
builder.build().unwrap_or_else(|x| x)
}
*/

pub fn mdns_query(name: &str) {
let service_guard = SERVER.lock().unwrap();
if let Some(service) = &service_guard.service {
let packet = create_a_query(name);
service.send(packet);
} else {
println!("Service not initialized");
// Lock the server to access the querier
let query_id;
let querier_cvar;

{
let mut service_guard = SERVER.lock().unwrap();
if let Some(querier) = &mut service_guard.querier {
let timeout = Duration::from_secs(1);
query_id = querier.add(
name.to_string(),
"".to_string(),
"_http._tcp".to_string(),
QueryType::A,
false,
timeout,
);
querier_cvar = querier.completed_queries.clone(); // Clone the Arc<Mutex> pair
} else {
println!("No querier available");
return;
}
} // Release the SERVER lock here

// Wait for the query to complete
let (lock, cvar) = &*querier_cvar;
let mut completed = lock.lock().unwrap();
while !completed.get(&query_id).copied().unwrap_or(false) {
let result = cvar.wait_timeout(completed, Duration::from_secs(5)).unwrap();
completed = result.0; // Update the lock guard
if result.1.timed_out() {
println!("Query timed out: ID {}", query_id);
return;
}
}

println!("Query completed!!! ID {}", query_id);
}
Loading

0 comments on commit 07b53fc

Please sign in to comment.