diff --git a/Cargo.lock b/Cargo.lock index 9f372f919..be83d777f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -533,16 +533,16 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.31" +version = "0.4.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" +checksum = "41daef31d7a747c5c847246f36de49ced6f7403b4cdabc807a97b5cc184cda7a" dependencies = [ "android-tzdata", "iana-time-zone", "js-sys", "num-traits", "wasm-bindgen", - "windows-targets 0.48.5", + "windows-targets 0.52.0", ] [[package]] @@ -2086,6 +2086,7 @@ name = "network-scanner" version = "0.0.0" dependencies = [ "anyhow", + "futures", "network-scanner-net", "network-scanner-proto", "socket2", @@ -2100,11 +2101,13 @@ version = "0.0.0" dependencies = [ "anyhow", "crossbeam", + "futures", "parking_lot", "polling", "socket2", "thiserror", "tokio", + "tokio-stream", "tracing", "tracing-subscriber", ] @@ -2687,8 +2690,7 @@ checksum = "626dec3cac7cc0e1577a2ec3fc496277ec2baa084bebad95bb6fdbfae235f84c" [[package]] name = "polling" version = "3.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "545c980a3880efd47b2e262f6a4bb6daad6555cf3367aa9c4e52895f69537a41" +source = "git+https://github.com/irvingoujAtDevolution/polling.git#159ab1b540b75a86f7735bcb57449b4b766e49cf" dependencies = [ "cfg-if", "concurrent-queue", diff --git a/crates/network-scanner-net/Cargo.toml b/crates/network-scanner-net/Cargo.toml index f1be3be94..c86396883 100644 --- a/crates/network-scanner-net/Cargo.toml +++ b/crates/network-scanner-net/Cargo.toml @@ -9,10 +9,12 @@ publish = false [dependencies] anyhow = "1.0.79" crossbeam = { version = "0.8.4", features = ["crossbeam-channel"] } +futures = "0.3.30" parking_lot = "0.12.1" -polling = "3.3.2" +polling = {git = "https://github.com/irvingoujAtDevolution/polling.git"} socket2 = { version = "0.5.5", features = ["all"] } thiserror = "1.0.56" +tokio-stream = "0.1.14" tracing = "0.1.40" [dev-dependencies] diff --git a/crates/network-scanner-net/examples/drop_runtime.rs b/crates/network-scanner-net/examples/drop_runtime.rs new file mode 100644 index 000000000..026fdaa79 --- /dev/null +++ b/crates/network-scanner-net/examples/drop_runtime.rs @@ -0,0 +1,92 @@ +use std::{net::SocketAddr, sync::Arc}; + +#[tokio::main] +pub async fn main() -> anyhow::Result<()> { + utils::start_server(); + tracing_subscriber::fmt::SubscriberBuilder::default() + .with_max_level(tracing::Level::TRACE) + .with_thread_names(true) + .init(); + { + let async_runtime = network_scanner_net::runtime::Socket2Runtime::new(None)?; + { + let good_socket = async_runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + let bad_socket = async_runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + + tracing::info!("good_socket: {:?}", good_socket); + tracing::info!("bad_socket: {:?}", bad_socket); + + let available_port = 8080; + let non_available_port = 12345; + + let available_addr = SocketAddr::from(([127, 0, 0, 1], available_port)); + let non_available_addr = SocketAddr::from(([127, 0, 0, 1], non_available_port)); + + let handle = + tokio::task::spawn(async move { good_socket.connect(&socket2::SockAddr::from(available_addr)).await }); + + let handle2 = + tokio::task::spawn( + async move { bad_socket.connect(&socket2::SockAddr::from(non_available_addr)).await }, + ); + + let (a, b) = tokio::join!(handle, handle2); + // remove the outer error from tokio task + let a = a?; + let b = b?; + tracing::info!("should connect: {:?}", &a); + tracing::info!("should not connect: {:?}", &b); + assert!(a.is_ok()); + assert!(b.is_err()); + } + tracing::info!("runtime arc count: {}", Arc::strong_count(&async_runtime)); + assert!(Arc::strong_count(&async_runtime) == 1); + } + tracing::info!("runtime should be dropped here"); + Ok(()) +} + +mod utils { + use std::io::Read; + use std::io::Write; + use std::net::{TcpListener, TcpStream}; + use std::thread; + + fn handle_client(mut stream: TcpStream) { + // read 20 bytes at a time from stream echoing back to stream + loop { + let mut read = [0; 1028]; + match stream.read(&mut read) { + Ok(n) => { + if n == 0 { + // connection was closed + break; + } + let _ = stream.write(&read[0..n]).unwrap(); + } + Err(_) => { + return; + } + } + } + } + + pub(super) fn start_server() { + thread::spawn(|| { + let listener = TcpListener::bind("127.0.0.1:8080").unwrap(); + + for stream in listener.incoming() { + match stream { + Ok(stream) => { + thread::spawn(move || { + handle_client(stream); + }); + } + Err(_) => { + // println!("Error"); + } + } + } + }); + } +} diff --git a/crates/network-scanner-net/examples/intense_tcp.rs b/crates/network-scanner-net/examples/intense_tcp.rs new file mode 100644 index 000000000..2a9215541 --- /dev/null +++ b/crates/network-scanner-net/examples/intense_tcp.rs @@ -0,0 +1,107 @@ +use std::{ + mem::MaybeUninit, + net::{SocketAddr, SocketAddrV4}, +}; + +use network_scanner_net::socket::AsyncRawSocket; +use socket2::SockAddr; +use tokio::task::JoinHandle; + +/// This example needs to be run with a echo server running on a different process +/// ``` +/// use tokio::net::TcpListener; +/// use tokio::io::{AsyncReadExt, AsyncWriteExt}; +/// use std::env; +/// #[tokio::main] +/// async fn main() -> Result<(), Box> { +/// let addr = env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string()); +/// let listener = TcpListener::bind(&addr).await?; +/// println!("Listening on: {}", addr); +/// loop { +/// let (mut socket, _) = listener.accept().await?; +/// tokio::spawn(async move { +/// let mut buf = vec![0; 1024]; +/// // In a loop, read data from the socket and write the data back. +/// loop { +/// let n = match socket.read(&mut buf).await { +/// // socket closed +/// Ok(n) if n == 0 => return, +/// Ok(n) => n, +/// Err(e) => { +/// eprintln!("Failed to read from socket; err = {:?}", e); +/// return; +/// } +/// }; +/// println!("Received {} bytes", n); +/// // Write the data back +/// if let Err(e) = socket.write_all(&buf[0..n]).await { +/// eprintln!("Failed to write to socket; err = {:?}", e); +/// return; +/// } +/// } +/// }); +/// } +/// } + +/// ``` +#[tokio::main(flavor = "multi_thread", worker_threads = 12)] +pub async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt::SubscriberBuilder::default() + .with_max_level(tracing::Level::INFO) + .with_thread_names(true) + .init(); + + let async_runtime = network_scanner_net::runtime::Socket2Runtime::new(None)?; + let mut socket_arr = vec![]; + for _ in 0..100 { + tracing::info!("Creating socket"); + let socket = async_runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + socket_arr.push(socket); + } + + fn connect_and_write_something( + mut socket: AsyncRawSocket, + addr: std::net::SocketAddr, + ) -> JoinHandle> { + tokio::task::spawn(async move { + let addr: SockAddr = addr.into(); + socket.connect(&addr).await?; + let mut buffer = [MaybeUninit::uninit(); 1024]; + for i in 0..1000 { + let data = format!("hello world {} times", i); + tracing::info!("Sending: {} from socket {:?}", &data, &socket); + let write_future = socket.send(data.as_bytes()); + + let size = tokio::time::timeout(std::time::Duration::from_secs(1), write_future).await??; + + if size == 0 { + return Ok(()); + } + + let recv_future = socket.recv(&mut buffer); + tokio::time::timeout(std::time::Duration::from_secs(1), recv_future).await??; + let received = buffer[..size] + .iter() + .map(|x| unsafe { x.assume_init() }) + .collect::>(); + assert_eq!(received, data.as_bytes()); + tracing::debug!("Received: {}", std::str::from_utf8(&received)?); + } + Ok(()) + }) + } + + let mut futures = vec![]; + for socket in socket_arr { + let addr: SocketAddr = SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 8080).into(); + + let future = connect_and_write_something(socket, addr); + futures.push(future); + } + + for future in futures { + future.await??; + } + + Ok(()) +} diff --git a/crates/network-scanner-net/examples/tcp_fail.rs b/crates/network-scanner-net/examples/tcp_fail.rs new file mode 100644 index 000000000..bb688a817 --- /dev/null +++ b/crates/network-scanner-net/examples/tcp_fail.rs @@ -0,0 +1,83 @@ +use std::net::SocketAddr; + +#[tokio::main] +pub async fn main() -> anyhow::Result<()> { + utils::start_server(); + tracing_subscriber::fmt::SubscriberBuilder::default() + .with_max_level(tracing::Level::TRACE) + .with_thread_names(true) + .init(); + + let async_runtime = network_scanner_net::runtime::Socket2Runtime::new(None)?; + let good_socket = async_runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + let bad_socket = async_runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + + tracing::info!("good_socket: {:?}", good_socket); + tracing::info!("bad_socket: {:?}", bad_socket); + + let available_port = 8080; + let non_available_port = 12345; + + let available_addr = SocketAddr::from(([127, 0, 0, 1], available_port)); + let non_available_addr = SocketAddr::from(([127, 0, 0, 1], non_available_port)); + + let handle = tokio::task::spawn(async move { good_socket.connect(&socket2::SockAddr::from(available_addr)).await }); + + let handle2 = + tokio::task::spawn(async move { bad_socket.connect(&socket2::SockAddr::from(non_available_addr)).await }); + + let (a, b) = tokio::join!(handle, handle2); + // remove the outer error from tokio task + let a = a?; + let b = b?; + tracing::info!("should connect: {:?}", &a); + tracing::info!("should not connect: {:?}", &b); + assert!(a.is_ok()); + assert!(b.is_err()); + Ok(()) +} + +mod utils { + use std::io::Read; + use std::io::Write; + use std::net::{TcpListener, TcpStream}; + use std::thread; + + fn handle_client(mut stream: TcpStream) { + // read 20 bytes at a time from stream echoing back to stream + loop { + let mut read = [0; 1028]; + match stream.read(&mut read) { + Ok(n) => { + if n == 0 { + // connection was closed + break; + } + let _ = stream.write(&read[0..n]).unwrap(); + } + Err(_) => { + return; + } + } + } + } + + pub(super) fn start_server() { + thread::spawn(|| { + let listener = TcpListener::bind("127.0.0.1:8080").unwrap(); + + for stream in listener.incoming() { + match stream { + Ok(stream) => { + thread::spawn(move || { + handle_client(stream); + }); + } + Err(_) => { + // println!("Error"); + } + } + } + }); + } +} diff --git a/crates/network-scanner-net/src/runtime.rs b/crates/network-scanner-net/src/runtime.rs index 85c1d1861..cd81e125f 100644 --- a/crates/network-scanner-net/src/runtime.rs +++ b/crates/network-scanner-net/src/runtime.rs @@ -1,5 +1,6 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, + hash::Hash, num::NonZeroUsize, sync::{ atomic::{AtomicBool, AtomicUsize, Ordering}, @@ -10,6 +11,7 @@ use std::{ use anyhow::Context; use crossbeam::channel::{Receiver, Sender}; +use parking_lot::Mutex; use polling::{Event, Events}; use socket2::Socket; @@ -17,16 +19,18 @@ use crate::{socket::AsyncRawSocket, ScannnerNetError}; #[derive(Debug)] pub struct Socket2Runtime { - poller: polling::Poller, + poller: Arc, next_socket_id: AtomicUsize, - is_terminated: AtomicBool, - sender: Sender, + is_terminated: Arc, + register_sender: Sender, + event_receiver: Receiver, + event_cache: Mutex>, } impl Drop for Socket2Runtime { fn drop(&mut self) { - self.is_terminated.store(true, Ordering::SeqCst); tracing::debug!("dropping runtime"); + self.is_terminated.store(true, Ordering::SeqCst); let _ = self // ignore errors, cannot handle it here .poller .notify() @@ -41,15 +45,22 @@ impl Socket2Runtime { /// Create a new runtime with a queue capacity, default is 1024. pub fn new(queue_capacity: Option) -> anyhow::Result> { let poller = polling::Poller::new()?; - let (sender, receiver) = crossbeam::channel::bounded(queue_capacity.unwrap_or(QUEUE_CAPACITY)); + + let (register_sender, register_receiver) = + crossbeam::channel::bounded(queue_capacity.unwrap_or(QUEUE_CAPACITY)); + + let (event_sender, event_receiver) = crossbeam::channel::bounded(queue_capacity.unwrap_or(QUEUE_CAPACITY)); + let runtime = Self { - poller, + poller: Arc::new(poller), next_socket_id: AtomicUsize::new(0), - is_terminated: AtomicBool::new(false), - sender, + is_terminated: Arc::new(AtomicBool::new(false)), + register_sender, + event_receiver, + event_cache: Mutex::new(HashSet::new()), }; let runtime = Arc::new(runtime); - runtime.clone().start_loop(receiver)?; + runtime.start_loop(register_receiver, event_sender)?; Ok(runtime) } @@ -67,37 +78,51 @@ impl Socket2Runtime { Ok(AsyncRawSocket::from_socket(socket, id, self.clone())?) } - pub(crate) fn remove_socket(&self, socket: &socket2::Socket) -> anyhow::Result<()> { + pub(crate) fn remove_socket(&self, socket: &socket2::Socket, id: usize) -> anyhow::Result<()> { self.poller.delete(socket)?; + // remove all events related to this socket + self.event_cache.lock().retain(|event| id == event.0.key); Ok(()) } - fn start_loop(self: Arc, receiver: Receiver) -> anyhow::Result<()> { + fn start_loop( + &self, + register_receiver: Receiver, + event_sender: Sender, + ) -> anyhow::Result<()> { + // we make is_terminated Arc and poller Arc so that we can clone them and move them into the thread + // we cannot hold a Arc in the thread, because it will create a cycle reference and the runtime will never be dropped. + let is_terminated = self.is_terminated.clone(); + let poller = self.poller.clone(); std::thread::Builder::new() .name("[raw-socket]:io-event-loop".to_string()) .spawn(move || { let mut events = Events::with_capacity(NonZeroUsize::new(1024).unwrap()); tracing::debug!("starting io event loop"); + // events registered but not happened yet let mut events_registered = HashMap::new(); - let mut events_happend = HashMap::new(); + + // events happened but not registered yet + let mut events_happened = HashMap::new(); + loop { - if self.is_terminated.load(Ordering::Acquire) { + if is_terminated.load(Ordering::Acquire) { break; } tracing::debug!("polling events"); - if let Err(e) = self.poller.wait(&mut events, None) { + if let Err(e) = poller.wait(&mut events, None) { tracing::error!(error = ?e, "failed to poll events"); - self.is_terminated.store(true, Ordering::SeqCst); + is_terminated.store(true, Ordering::SeqCst); break; }; for event in events.iter() { - tracing::trace!(?event, "event happend"); - events_happend.insert(event.key, event); + tracing::trace!(?event, "event happened"); + events_happened.insert(event.key, event); } events.clear(); - while let Ok(event) = receiver.try_recv() { + while let Ok(event) = register_receiver.try_recv() { match event { RegisterEvent::Register { id, waker } => { events_registered.insert(id, waker); @@ -108,25 +133,79 @@ impl Socket2Runtime { } } - let intersection = events_happend + let intersection = events_happened .keys() .filter(|key| events_registered.contains_key(key)) .cloned() .collect::>(); intersection.into_iter().for_each(|ref key| { - let event = events_happend.remove(key).unwrap(); + let event = events_happened.remove(key).unwrap(); let waker = events_registered.remove(key).unwrap(); + let _ = event_sender.try_send(event); waker.wake_by_ref(); tracing::trace!(?event, "waking up waker"); }); } + tracing::info!("io event loop terminated"); }) .with_context(|| "failed to spawn io event loop thread")?; Ok(()) } + /// Ideally, we should have a dedicated thread to handle events we received, but we don't really want to spawn a second thread + /// Alternatively, we can have all socket futures call this function to check if there is any event for them. + /// The number of times the socket futures is polled is almost guaranteed to be more than the number of registration we received. + /// hence the event receiver will not be blocked. + pub(crate) fn check_event(&self, event: Event, remove: bool) -> Option { + let mut event_cache = self.event_cache.lock(); + while let Ok(event) = self.event_receiver.try_recv() { + event_cache.insert(event.into()); + } + tracing::debug!("checking event cache {:?}", event_cache); + + let event = if remove { + event_cache.take(&event.into()) + } else if event_cache.contains(&event.into()) { + Some(event.into()) + } else { + None + }; + + event.map(|event| event.into_inner()) + } + + pub(crate) fn check_event_with_id(&self, id: usize, remove: bool) -> Vec { + let mut event_cache = self.event_cache.lock(); + while let Ok(event) = self.event_receiver.try_recv() { + event_cache.insert(event.into()); + } + let event_interested = vec![ + Event::readable(id), + Event::writable(id), + Event::all(id), + Event::none(id), + ]; + let mut res = vec![]; + + if remove { + event_interested.into_iter().for_each(|event| { + if let Some(event) = event_cache.take(&event.into()) { + res.push(event.into_inner()); + } + }); + } else { + event_interested.into_iter().for_each(|event| { + if event_cache.contains(&event.into()) { + res.push(event); + } + }); + } + + res + } + pub(crate) fn register(&self, socket: &Socket, event: Event, waker: Waker) -> anyhow::Result<()> { if self.is_terminated.load(Ordering::Acquire) { Err(ScannnerNetError::AsyncRuntimeError("runtime is terminated".to_string()))?; @@ -138,7 +217,7 @@ impl Socket2Runtime { // Use try_send instead of send, in case some io events blocked the queue completely, // it would be better to drop the register event then block the worker thread or main thread. // as the worker thread is shared for the entire application. - self.sender + self.register_sender .try_send(RegisterEvent::Register { id: event.key, waker }) .with_context(|| "failed to send register event to register loop") } @@ -147,7 +226,7 @@ impl Socket2Runtime { if self.is_terminated.load(Ordering::Acquire) { Err(ScannnerNetError::AsyncRuntimeError("runtime is terminated".to_string()))?; } - self.sender + self.register_sender .try_send(RegisterEvent::Unregister { id }) .with_context(|| "failed to send unregister event to register loop") } @@ -158,3 +237,34 @@ enum RegisterEvent { Register { id: usize, waker: Waker }, Unregister { id: usize }, } + +#[derive(Debug)] +pub struct EventWrapper(Event); + +impl Hash for EventWrapper { + fn hash(&self, state: &mut H) { + self.0.key.hash(state); + self.0.readable.hash(state); + self.0.writable.hash(state); + } +} + +impl PartialEq for EventWrapper { + fn eq(&self, other: &Self) -> bool { + self.0.key == other.0.key && self.0.readable == other.0.readable && self.0.writable == other.0.writable + } +} + +impl Eq for EventWrapper {} + +impl From for EventWrapper { + fn from(event: Event) -> Self { + Self(event) + } +} + +impl EventWrapper { + pub(crate) fn into_inner(self) -> Event { + self.0 + } +} diff --git a/crates/network-scanner-net/src/socket.rs b/crates/network-scanner-net/src/socket.rs index e78371d19..61f305c89 100644 --- a/crates/network-scanner-net/src/socket.rs +++ b/crates/network-scanner-net/src/socket.rs @@ -1,5 +1,5 @@ use polling::Event; -use std::{future::Future, mem::MaybeUninit, sync::Arc, usize}; +use std::{fmt::Debug, future::Future, mem::MaybeUninit, sync::Arc, usize}; use socket2::{SockAddr, Socket}; use std::result::Result::Ok; @@ -13,12 +13,21 @@ pub struct AsyncRawSocket { id: usize, } +impl Debug for AsyncRawSocket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AsyncRawSocket") + .field("socket", &self.socket) + .field("id", &self.id) + .finish() + } +} + impl Drop for AsyncRawSocket { fn drop(&mut self) { tracing::trace!(id = %self.id,socket = ?self.socket, "drop socket"); let _ = self // We ignore errors here, avoid crashing the thread .runtime - .remove_socket(&self.socket) + .remove_socket(&self.socket, self.id) .map_err(|e| tracing::error!("failed to remove socket from poller: {:?}", e)); } } @@ -98,7 +107,6 @@ impl<'a> AsyncRawSocket { runtime: self.runtime.clone(), addr, id: self.id, - is_first_poll: true, } } @@ -133,6 +141,9 @@ struct RecvFromFuture<'a> { impl Future for RecvFromFuture<'_> { type Output = std::io::Result<(usize, SockAddr)>; fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + // By checking event at every call, it removes excessive events excessive in the cache and the channel + self.runtime.check_event(Event::readable(self.id), true); + let socket = &self.socket.clone(); // avoid borrow checker error match socket.recv_from(self.buf) { Ok(a) => std::task::Poll::Ready(Ok(a)), @@ -153,6 +164,7 @@ impl<'a> Future for SendToFuture<'a> { type Output = std::io::Result; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + self.runtime.check_event(Event::writable(self.id), true); match self.socket.send_to(self.data, self.addr) { Ok(a) => std::task::Poll::Ready(Ok(a)), Err(e) => resolve(e, &self.socket, &self.runtime, Event::writable(self.id), cx.waker()), @@ -170,6 +182,7 @@ impl Future for AcceptFuture { type Output = std::io::Result<(AsyncRawSocket, SockAddr)>; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + self.runtime.check_event(Event::readable(self.id), true); match self.socket.accept() { Ok((socket, addr)) => { let socket = AsyncRawSocket::from_socket(socket, self.id, self.runtime.clone())?; @@ -184,50 +197,53 @@ struct ConnectFuture<'a> { runtime: Arc, id: usize, addr: &'a socket2::SockAddr, - is_first_poll: bool, } impl<'a> Future for ConnectFuture<'a> { type Output = std::io::Result<()>; - fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - if self.is_first_poll { - tracing::trace!("first poll connect future"); - // cannot call connect twice - self.is_first_poll = false; - let err = match self.socket.connect(self.addr) { - Ok(a) => { - return std::task::Poll::Ready(Ok(a)); - } - Err(e) => e, - }; - - // code 115, EINPROGRESS, only for linux - // reference: https://linux.die.net/man/2/connect - // it is the same as WouldBlock but for connect(2) only - #[cfg(target_os = "linux")] - let in_progress = err.kind() == std::io::ErrorKind::WouldBlock || err.raw_os_error() == Some(115); - - #[cfg(not(target_os = "linux"))] - let in_progress = err.kind() == std::io::ErrorKind::WouldBlock; - - if in_progress { - tracing::trace!("connect should register"); - if let Err(e) = self - .runtime - .register(&self.socket, Event::all(self.id), cx.waker().clone()) - { - tracing::warn!(?self.socket, ?self.addr, "failed to register socket to poller"); - return std::task::Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - format!("failed to register socket to poller: {}", e), - ))); - } - return std::task::Poll::Pending; + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + let events = self.runtime.check_event_with_id(self.id, true); + if events.iter().any(|e| e.is_connect_failed().unwrap_or(false)) { + tracing::warn!(?self.socket, ?self.addr, "connect failed"); + return std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "connect failed"))); + }; + + if !events.is_empty() { + tracing::trace!(?events, "connect success"); + return std::task::Poll::Ready(Ok(())); + } + + let err = match self.socket.connect(self.addr) { + Ok(a) => { + return std::task::Poll::Ready(Ok(a)); + } + Err(e) => e, + }; + + // code 115, EINPROGRESS, only for linux + // reference: https://linux.die.net/man/2/connect + // it is the same as WouldBlock but for connect(2) only + #[cfg(target_os = "linux")] + let in_progress = err.kind() == std::io::ErrorKind::WouldBlock || err.raw_os_error() == Some(115); + + #[cfg(not(target_os = "linux"))] + let in_progress = err.kind() == std::io::ErrorKind::WouldBlock; + + if in_progress { + tracing::trace!("connect should register"); + if let Err(e) = self + .runtime + .register(&self.socket, Event::all(self.id), cx.waker().clone()) + { + tracing::warn!(?self.socket, ?self.addr, "failed to register socket to poller"); + return std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + format!("failed to register socket to poller: {}", e), + ))); } } - tracing::trace!("second poll connect future"); - std::task::Poll::Ready(Ok(())) + std::task::Poll::Pending } } @@ -242,6 +258,7 @@ impl<'a> Future for SendFuture<'a> { type Output = std::io::Result; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + self.runtime.check_event(Event::writable(self.id), true); match self.socket.send(self.data) { Ok(a) => std::task::Poll::Ready(Ok(a)), Err(e) => resolve(e, &self.socket, &self.runtime, Event::writable(self.id), cx.waker()), @@ -260,6 +277,7 @@ impl<'a> Future for RecvFuture<'a> { type Output = std::io::Result; fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + self.runtime.check_event(Event::readable(self.id), true); let socket = &self.socket.clone(); // avoid borrow checker error match socket.recv(self.buf) { Ok(a) => std::task::Poll::Ready(Ok(a)), diff --git a/crates/network-scanner-net/src/test.rs b/crates/network-scanner-net/src/test.rs index 114c2e0c7..2d4f332ba 100644 --- a/crates/network-scanner-net/src/test.rs +++ b/crates/network-scanner-net/src/test.rs @@ -1,26 +1,21 @@ use std::{ - io::{ErrorKind, Read, Write}, + io::ErrorKind, mem::MaybeUninit, net::{SocketAddr, UdpSocket}, + sync::{atomic::AtomicBool, Arc}, + time::Duration, }; use socket2::SockAddr; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + task::JoinHandle, +}; use crate::socket::AsyncRawSocket; -#[ignore] -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn test_connectivity() -> anyhow::Result<()> { - let addr = local_tcp_server()?; - let runtime = crate::runtime::Socket2Runtime::new(None)?; - let socket = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; - socket.connect(&socket2::SockAddr::from(addr)).await?; - Ok(()) -} - -#[ignore] #[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[ignore = "TODO"] async fn multiple_udp() -> anyhow::Result<()> { let addr = local_udp_server()?; tokio::time::sleep(std::time::Duration::from_millis(200)).await; // wait for the other socket to start @@ -30,21 +25,22 @@ async fn multiple_udp() -> anyhow::Result<()> { let socket2 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::DGRAM, None)?; let socket3 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::DGRAM, None)?; - fn send_to(mut socket: AsyncRawSocket, number: u8, addr: SocketAddr) -> tokio::task::JoinHandle<()> { + fn send_to( + mut socket: AsyncRawSocket, + number: u8, + addr: SocketAddr, + ) -> tokio::task::JoinHandle> { tokio::task::spawn(async move { let msg = format!("hello from socket {}", number); - socket - .send_to(msg.as_bytes(), &SockAddr::from(addr)) - .await - .expect("send_to"); + socket.send_to(msg.as_bytes(), &SockAddr::from(addr)).await?; + let mut buf = [MaybeUninit::::uninit(); 1024]; - let (size, addr) = socket - .recv_from(&mut buf) - .await - .unwrap_or_else(|_| panic!("recv_from: {}", number)); + let (size, addr) = socket.recv_from(&mut buf).await?; + tracing::info!("size: {}, addr: {:?}", size, addr); let back = unsafe { crate::assume_init(&buf[..size]) }; assert_eq!(back, format!("hello from socket {}", number).as_bytes()); + Ok::<(), anyhow::Error>(()) }) } @@ -57,32 +53,59 @@ async fn multiple_udp() -> anyhow::Result<()> { ]; for handle in handles { - handle.await?; + tokio::time::timeout(Duration::from_secs(10), handle).await???; } Ok(()) } -#[ignore] #[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[ignore = "TODO"] +async fn test_connectivity() -> anyhow::Result<()> { + let kill_server = Arc::new(AtomicBool::new(false)); + let (addr, handle) = local_tcp_server(kill_server.clone()).await?; + tokio::time::sleep(std::time::Duration::from_millis(200)).await; // wait for the other socket to start + + let runtime = crate::runtime::Socket2Runtime::new(None)?; + let socket = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + let addr: SockAddr = addr.into(); + socket.connect(&addr).await?; + + // clean up + kill_server.store(true, std::sync::atomic::Ordering::Relaxed); + handle.abort(); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 6)] +#[ignore = "TODO"] async fn multiple_tcp() -> anyhow::Result<()> { - let addr = local_tcp_server()?; + let kill_server = Arc::new(AtomicBool::new(false)); + let (addr, handle) = local_tcp_server(kill_server.clone()).await?; + tokio::time::sleep(std::time::Duration::from_millis(200)).await; // wait for the other socket to start + let runtime = crate::runtime::Socket2Runtime::new(None)?; let socket0 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; let socket1 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; let socket2 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; let socket3 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; - fn connect(mut socket: AsyncRawSocket, number: u8, addr: SocketAddr) -> tokio::task::JoinHandle<()> { + fn connect( + mut socket: AsyncRawSocket, + number: u8, + addr: SocketAddr, + ) -> tokio::task::JoinHandle> { tokio::task::spawn(async move { - socket.connect(&socket2::SockAddr::from(addr)).await.expect("connect"); + socket.connect(&socket2::SockAddr::from(addr)).await?; let msg = format!("hello from socket {}", number); - socket.send(msg.as_bytes()).await.expect("send"); + socket.send(msg.as_bytes()).await?; let mut buf = [MaybeUninit::::uninit(); 1024]; - let size = socket.recv(&mut buf).await.expect("recv"); + let size = socket.recv(&mut buf).await?; tracing::info!("size: {}", size); let back = unsafe { crate::assume_init(&buf[..size]) }; assert_eq!(back, format!("hello from socket {}", number).as_bytes()); + Ok::<(), anyhow::Error>(()) }) } @@ -95,16 +118,23 @@ async fn multiple_tcp() -> anyhow::Result<()> { ]; for handle in handles { - handle.await?; + tokio::time::timeout(Duration::from_secs(5), handle).await???; } + // clean up + kill_server.store(true, std::sync::atomic::Ordering::Relaxed); + handle.abort(); + Ok(()) } -#[ignore] #[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[ignore = "TODO"] async fn work_with_tokio_tcp() -> anyhow::Result<()> { - let addr = local_tcp_server()?; + let kill_server = Arc::new(AtomicBool::new(false)); + let (addr, tcp_handle) = local_tcp_server(kill_server.clone()).await?; + tokio::time::sleep(std::time::Duration::from_millis(200)).await; // wait for the other socket to start + let runtime = crate::runtime::Socket2Runtime::new(None)?; let mut socket = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; @@ -115,7 +145,7 @@ async fn work_with_tokio_tcp() -> anyhow::Result<()> { socket.send(msg.as_bytes()).await?; let mut buf = [MaybeUninit::::uninit(); 1024]; let size = socket.recv(&mut buf).await?; - tracing::info!("size: {}", size); + tracing::debug!("size: {}", size); let back = unsafe { crate::assume_init(&buf[..size]) }; assert_eq!(back, msg.as_bytes()); } @@ -137,9 +167,12 @@ async fn work_with_tokio_tcp() -> anyhow::Result<()> { Ok::<(), anyhow::Error>(()) }); - let (a, b) = tokio::join!(handle, handle2); - a??; - b??; + tokio::time::timeout(Duration::from_secs(5), handle).await???; + tokio::time::timeout(Duration::from_secs(5), handle2).await???; + + // clean up + kill_server.store(true, std::sync::atomic::Ordering::Relaxed); + tcp_handle.abort(); Ok(()) } @@ -151,18 +184,19 @@ fn local_udp_server() -> anyhow::Result { std::thread::spawn(move || { // Create and bind the UDP socket - println!("UDP server listening on {}", socket.local_addr()?); + tracing::debug!("UDP server listening on {}", socket.local_addr()?); let mut buffer = [0u8; 1024]; // A buffer to store incoming data loop { match socket.recv_from(&mut buffer) { Ok((size, src)) => { - println!("Received {} bytes from {}", size, src); - let socket_clone = socket.try_clone().expect("Failed to clone socket"); + tracing::trace!("Received {} bytes from {}", size, src); + let socket_clone = socket.try_clone()?; std::thread::spawn(move || { std::thread::sleep(std::time::Duration::from_millis(200)); // simulate some work - socket_clone.send_to(&buffer[..size], src).expect("Failed to send data") + socket_clone.send_to(&buffer[..size], src)?; + Ok::<(), anyhow::Error>(()) }); } Err(ref e) if e.kind() == ErrorKind::WouldBlock => { @@ -180,42 +214,55 @@ fn local_udp_server() -> anyhow::Result { Ok(res) } -fn handle_client(mut stream: std::net::TcpStream) -> std::io::Result<()> { +async fn handle_client(mut stream: tokio::net::TcpStream, awake: Arc) -> std::io::Result<()> { let mut buffer = [0; 1024]; loop { - // Read data from the stream - let size = stream.read(&mut buffer)?; - println!("Received {} bytes: {:?}", size, &buffer[..size]); + let read_future = stream.read(&mut buffer); + + let size = match tokio::time::timeout(Duration::from_secs(1), read_future).await { + Ok(res) => res?, + Err(_) => { + if awake.load(std::sync::atomic::Ordering::Relaxed) { + return Ok(()); + } + continue; + } + }; + + if size == 0 { + return Ok(()); + } + + tracing::debug!("Received {} bytes: {:?}", size, &buffer[..size]); std::thread::sleep(std::time::Duration::from_millis(200)); // simulate some work - stream.write_all(&buffer[..size])?; // Echo the data back to the client - println!("Echoed back {} bytes", size); + stream.write_all(&buffer[..size]).await?; // Echo the data back to the client } } -fn local_tcp_server() -> anyhow::Result { - // Bind the TCP listener to a local address - let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Could not bind TCP listener"); - println!("TCP server listening on {}", listener.local_addr().unwrap()); - let res = listener.local_addr().unwrap(); - std::thread::spawn(move || { - // Accept incoming connections - for stream in listener.incoming() { - match stream { - Ok(stream) => { - // Spawn a new thread for each connection - std::thread::spawn(move || { - if let Err(e) = handle_client(stream) { - tracing::error!("An error occurred while handling the client: {}", e); - } - }); +async fn local_tcp_server( + awake: Arc, +) -> anyhow::Result<(SocketAddr, JoinHandle>)> { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let res = listener.local_addr()?; + let handle = tokio::task::spawn(async move { + loop { + let listener_future = listener.accept(); + let (stream, _) = match tokio::time::timeout(Duration::from_secs(1), listener_future).await { + Ok(res) => res, + Err(_) => { + if awake.load(std::sync::atomic::Ordering::Relaxed) { + return Ok::<(), anyhow::Error>(()); + } + continue; } - Err(e) => { - tracing::error!("Connection failed: {}", e); - return; + }?; + let awake = awake.clone(); + tokio::task::spawn(async move { + if let Err(e) = handle_client(stream, awake).await { + tracing::error!("An error occurred while handling the client: {}", e); } - } + }); } }); - - Ok(res) + Ok((res, handle)) } diff --git a/crates/network-scanner-proto/src/icmp_v4.rs b/crates/network-scanner-proto/src/icmp_v4.rs index f6d6fd3ed..98e18780b 100644 --- a/crates/network-scanner-proto/src/icmp_v4.rs +++ b/crates/network-scanner-proto/src/icmp_v4.rs @@ -17,6 +17,7 @@ pub enum Icmpv4MessageType { InformationReply = 16, } +#[derive(Debug)] pub enum Icmpv4Message { EchoReply { // type 0 @@ -213,6 +214,7 @@ impl Icmpv4Message { bytes } } +#[derive(Debug)] pub struct Icmpv4Packet { pub code: u8, pub checksum: u16, diff --git a/crates/network-scanner/Cargo.toml b/crates/network-scanner/Cargo.toml index af1413f3b..54d1f17b3 100644 --- a/crates/network-scanner/Cargo.toml +++ b/crates/network-scanner/Cargo.toml @@ -8,10 +8,11 @@ publish = false [dependencies] anyhow = "1.0.79" +futures = { version = "0.3.30" } network-scanner-net ={ path = "../network-scanner-net" } network-scanner-proto ={ path = "../network-scanner-proto" } socket2 = "0.5.5" -tokio = { version = "1.35.1", features = ["io-util"] } +tokio = { version = "1.35.1" } tracing = "0.1.40" [dev-dependencies] diff --git a/crates/network-scanner/examples/broadcast.rs b/crates/network-scanner/examples/broadcast.rs new file mode 100644 index 000000000..e116a0cda --- /dev/null +++ b/crates/network-scanner/examples/broadcast.rs @@ -0,0 +1,42 @@ +use futures::StreamExt; +use network_scanner::broadcast::broadcast; +use std::time::Duration; + +#[tokio::main] +pub async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt::SubscriberBuilder::default() + .with_max_level(tracing::Level::INFO) + .with_thread_names(true) + .init(); + + let ip = std::net::Ipv4Addr::new(192, 168, 1, 255); + let runtime = network_scanner_net::runtime::Socket2Runtime::new(None)?; + { + let socket = runtime.new_socket( + socket2::Domain::IPV4, + socket2::Type::RAW, + Some(socket2::Protocol::ICMPV4), + )?; + let mut stream = broadcast(ip, Some(Duration::from_secs(1)), socket).await?; + + while let Some(result) = stream.next().await { + match result { + Ok(res) => { + tracing::info!("received result {:?}", &res) + } + Err(e) => { + if let Some(e) = e.downcast_ref::() { + // if is timeout, say timeout then break + if let std::io::ErrorKind::TimedOut = e.kind() { + tracing::info!("timed out"); + break; + } + } + return Err(e); + } + } + } + } // drop socket + + Ok(()) +} diff --git a/crates/network-scanner/src/broadcast.rs b/crates/network-scanner/src/broadcast.rs new file mode 100644 index 000000000..13b872614 --- /dev/null +++ b/crates/network-scanner/src/broadcast.rs @@ -0,0 +1,227 @@ +use std::{ + mem::MaybeUninit, + net::{Ipv4Addr, SocketAddr}, + time::Duration, +}; + +use anyhow::Context; +use network_scanner_net::socket::AsyncRawSocket; +use network_scanner_proto::icmp_v4; +use socket2::SockAddr; + +use crate::ping::create_echo_request; + +#[derive(Debug)] +pub struct PingResponse { + pub addr: Ipv4Addr, + pub packet: icmp_v4::Icmpv4Packet, +} + +impl PingResponse { + pub(crate) unsafe fn from_raw( + addr: socket2::SockAddr, + payload: &[MaybeUninit], + size: usize, + ) -> anyhow::Result { + let addr = *addr + .as_socket_ipv4() + .with_context(|| "sock addr is not ipv4".to_string())? + .ip(); // ip is private + + let payload = payload[..size] + .as_ref() + .iter() + .map(|u| unsafe { u.assume_init() }) + .collect::>(); + + let packet = icmp_v4::Icmpv4Packet::parse(payload.as_slice())?; + + Ok(PingResponse { addr, packet }) + } + + pub fn verify(&self, verifier: &[u8]) -> bool { + if let icmp_v4::Icmpv4Message::EchoReply { payload, .. } = &self.packet.message { + payload == verifier + } else { + false + } + } +} + +type StreamReceiver = tokio::sync::mpsc::Receiver>, SockAddr), std::io::Error>>; +pub struct BroadcastStream { + receiver: StreamReceiver, + verifier: Vec, + should_verify: bool, +} + +impl BroadcastStream { + pub fn should_verify(&mut self, should_verify: bool) { + self.should_verify = should_verify; + } +} + +impl futures::stream::Stream for BroadcastStream { + type Item = anyhow::Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let ready_res = match self.receiver.poll_recv(cx) { + std::task::Poll::Ready(res) => res, + std::task::Poll::Pending => return std::task::Poll::Pending, + }; + + let not_none_res = match ready_res { + Some(res) => res, + None => return std::task::Poll::Ready(None), + }; + + let (buf, addr) = match not_none_res { + Ok(res) => res, + Err(e) => return std::task::Poll::Ready(Some(Err(e.into()))), + }; + + let ping_result = unsafe { PingResponse::from_raw(addr, &buf, buf.len()) }; + + let ping_result = match ping_result { + Ok(a) => a, + Err(e) => return std::task::Poll::Ready(Some(Err(e))), + }; + + if self.should_verify && !ping_result.verify(&self.verifier) { + return std::task::Poll::Ready(Some(Err(anyhow::anyhow!("failed to verify ping response")))); + } + + std::task::Poll::Ready(Some(Ok(ping_result))) + } +} + +/// Broadcasts a ping to the given ip address +/// caller need to make sure that the ip address is a broadcast address +pub async fn broadcast( + ip: Ipv4Addr, + read_time_out: Option, + mut socket: AsyncRawSocket, +) -> anyhow::Result { + socket.set_broadcast(true)?; + if let Some(time_out) = read_time_out { + socket.set_read_timeout(time_out)?; + } + let (packet, verifier) = create_echo_request()?; + socket + .send_to(&packet.to_bytes(true), &SockAddr::from(SocketAddr::new(ip.into(), 0))) + .await?; + let (sender, receiver) = tokio::sync::mpsc::channel(255); + + let _handle = tokio::task::spawn(async move { + let mut buffer = [MaybeUninit::uninit(); icmp_v4::ICMPV4_MTU]; + loop { + let future = socket.recv_from(&mut buffer); + + let result = match tokio::time::timeout(read_time_out.unwrap_or(Duration::from_secs(200)), future).await { + Ok(res) => res, + Err(e) => { + sender + .send(Err(std::io::Error::new(std::io::ErrorKind::TimedOut, e))) + .await?; + break; + } + }; + + let (size, addr) = match result { + Ok(res) => res, + Err(e) => { + if sender.send(Err(e)).await.is_err() { + tracing::error!("channel failed, sending Err to receiver"); + } + break; + } + }; + + let buffer_copy = buffer[..size].as_ref().to_vec(); + sender.send(Ok((buffer_copy, addr))).await?; + } + Ok::<(), anyhow::Error>(()) + }); + + Ok(BroadcastStream { + receiver, + verifier, + should_verify: true, + }) +} + +pub struct BorcastBlockStream { + socket: socket2::Socket, + verifier: Vec, + should_verify: bool, +} + +impl Iterator for BorcastBlockStream { + type Item = anyhow::Result; + + fn next(&mut self) -> Option { + let mut buffer = [MaybeUninit::uninit(); icmp_v4::ICMPV4_MTU]; + let res = self.socket.recv_from(&mut buffer); + + let (size, addr) = match res { + Ok(res) => res, + Err(e) => { + return Some(Err(e.into())); + } + }; + + if size == 0 { + return None; + } + + let ping_result = unsafe { PingResponse::from_raw(addr, &buffer, size) }; + + let ping_result = match ping_result { + Ok(a) => a, + Err(e) => return Some(Err(e)), + }; + + if self.should_verify && !ping_result.verify(&self.verifier) { + return Some(Err(anyhow::anyhow!("failed to verify ping response"))); + } + + Some(Ok(ping_result)) + } +} + +impl BorcastBlockStream { + pub fn should_verify(&mut self, should_verify: bool) { + self.should_verify = should_verify; + } +} + +pub fn block_broadcast(ip: Ipv4Addr, read_time_out: Option) -> anyhow::Result { + let socket = socket2::Socket::new( + socket2::Domain::IPV4, + socket2::Type::RAW, + Some(socket2::Protocol::ICMPV4), + )?; + socket.set_broadcast(true)?; + + if let Some(time_out) = read_time_out { + socket.set_read_timeout(Some(time_out))?; + } + + let addr = SocketAddr::new(ip.into(), 0); + + let (packet, verifier) = create_echo_request()?; + + tracing::trace!(?packet, "sending packet"); + socket + .send_to(&packet.to_bytes(true), &addr.into()) + .with_context(|| format!("Failed to send packet to {}", ip))?; + + Ok(BorcastBlockStream { + socket, + verifier, + should_verify: true, + }) +} diff --git a/crates/network-scanner/src/lib.rs b/crates/network-scanner/src/lib.rs index a766209cf..52a06922a 100644 --- a/crates/network-scanner/src/lib.rs +++ b/crates/network-scanner/src/lib.rs @@ -1 +1,2 @@ +pub mod broadcast; pub mod ping; diff --git a/crates/network-scanner/src/ping.rs b/crates/network-scanner/src/ping.rs index c16952b8b..d979a0d58 100644 --- a/crates/network-scanner/src/ping.rs +++ b/crates/network-scanner/src/ping.rs @@ -59,7 +59,7 @@ pub(crate) unsafe fn assume_init(buf: &[MaybeUninit]) -> &[u8] { &*(buf as *const [MaybeUninit] as *const [u8]) } -fn create_echo_request() -> anyhow::Result<(icmp_v4::Icmpv4Packet, Vec)> { +pub(crate) fn create_echo_request() -> anyhow::Result<(icmp_v4::Icmpv4Packet, Vec)> { let time = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .with_context(|| "failed to get current time")?