diff --git a/Cargo.lock b/Cargo.lock index 9f372f919..ee23a4a6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -533,16 +533,16 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.31" +version = "0.4.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" +checksum = "9f13690e35a5e4ace198e7beea2895d29f3a9cc55015fcebe6336bd2010af9eb" dependencies = [ "android-tzdata", "iana-time-zone", "js-sys", "num-traits", "wasm-bindgen", - "windows-targets 0.48.5", + "windows-targets 0.52.0", ] [[package]] @@ -2100,11 +2100,13 @@ version = "0.0.0" dependencies = [ "anyhow", "crossbeam", + "futures", "parking_lot", "polling", "socket2", "thiserror", "tokio", + "tokio-stream", "tracing", "tracing-subscriber", ] @@ -2621,18 +2623,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" +checksum = "0302c4a0442c456bd56f841aee5c3bfd17967563f6fadc9ceb9f9c23cf3807e0" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" +checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690" dependencies = [ "proc-macro2 1.0.78", "quote 1.0.35", @@ -2687,8 +2689,7 @@ checksum = "626dec3cac7cc0e1577a2ec3fc496277ec2baa084bebad95bb6fdbfae235f84c" [[package]] name = "polling" version = "3.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "545c980a3880efd47b2e262f6a4bb6daad6555cf3367aa9c4e52895f69537a41" +source = "git+https://github.com/smol-rs/polling?rev=62430fd56e668559d08ca7071ab13a0e116ba515#62430fd56e668559d08ca7071ab13a0e116ba515" dependencies = [ "cfg-if", "concurrent-queue", @@ -2975,7 +2976,7 @@ checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.4", + "regex-automata 0.4.5", "regex-syntax 0.8.2", ] @@ -2990,9 +2991,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b7fa1134405e2ec9353fd416b17f8dacd46c473d7d3fd1cf202706a14eb792a" +checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" dependencies = [ "aho-corasick", "memchr", @@ -3898,9 +3899,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tls_codec" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d38a1d5fcfa859f0ec2b5e111dc903890bd7dac7f34713232bf9aa4fd7cad7b2" +checksum = "b5e78c9c330f8c85b2bae7c8368f2739157db9991235123aa1b15ef9502bfb6a" dependencies = [ "tls_codec_derive", "zeroize", @@ -3908,9 +3909,9 @@ dependencies = [ [[package]] name = "tls_codec_derive" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8e00e3e7a54e0f1c8834ce72ed49c8487fbd3f801d8cfe1a0ad0640382f8e15" +checksum = "8d9ef545650e79f30233c0003bcc2504d7efac6dad25fca40744de773fe2049c" dependencies = [ "proc-macro2 1.0.78", "quote 1.0.35", @@ -4850,9 +4851,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winnow" -version = "0.5.34" +version = "0.5.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7cf47b659b318dccbd69cc4797a39ae128f533dce7902a1096044d1967b9c16" +checksum = "1931d78a9c73861da0134f453bb1f790ce49b2e30eba8410b4b79bac72b46a2d" dependencies = [ "memchr", ] diff --git a/crates/network-scanner-net/Cargo.toml b/crates/network-scanner-net/Cargo.toml index f1be3be94..cea420bc0 100644 --- a/crates/network-scanner-net/Cargo.toml +++ b/crates/network-scanner-net/Cargo.toml @@ -5,14 +5,15 @@ authors = ["Devolutions Inc. "] edition = "2021" publish = false -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] anyhow = "1.0.79" crossbeam = { version = "0.8.4", features = ["crossbeam-channel"] } +futures = "0.3.30" parking_lot = "0.12.1" -polling = "3.3.2" +polling = { git = "https://github.com/smol-rs/polling", rev = "62430fd56e668559d08ca7071ab13a0e116ba515" } socket2 = { version = "0.5.5", features = ["all"] } thiserror = "1.0.56" +tokio-stream = "0.1.14" tracing = "0.1.40" [dev-dependencies] diff --git a/crates/network-scanner-net/examples/intense_tcp.rs b/crates/network-scanner-net/examples/intense_tcp.rs new file mode 100644 index 000000000..bae4ec456 --- /dev/null +++ b/crates/network-scanner-net/examples/intense_tcp.rs @@ -0,0 +1,137 @@ +use std::{ + mem::MaybeUninit, + net::{SocketAddr, SocketAddrV4}, + sync::atomic::AtomicU32, + time::Instant, +}; + +use network_scanner_net::socket::AsyncRawSocket; +use socket2::SockAddr; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + task::JoinHandle, +}; + +#[tokio::main(flavor = "multi_thread", worker_threads = 12)] +pub async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt::SubscriberBuilder::default() + .with_max_level(tracing::Level::DEBUG) + .with_thread_names(true) + .init(); + // trace info all args + let args: Vec = std::env::args().collect(); + tracing::info!("Args: {:?}", std::env::args().collect::>()); + if args.len() < 4 { + println!("Usage: {} [server|client] -p ", args[0]); + return Ok(()); + } + + let port = args[args.len() - 1].parse::()?; + let addr = format!("127.0.0.1:{}", port); + match args[1].as_str() { + "server" => { + tcp_server(&addr).await?; + } + "client" => { + tcp_client().await?; + } + _ => { + println!("Usage: {} [server|client] -p ", args[0]); + } + } + + Ok(()) +} + +async fn tcp_client() -> anyhow::Result<()> { + let async_runtime = network_scanner_net::runtime::Socket2Runtime::new(None)?; + let mut socket_arr = vec![]; + for _ in 0..100 { + tracing::info!("Creating socket"); + let socket = async_runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + socket_arr.push(socket); + } + + fn connect_and_write_something( + mut socket: AsyncRawSocket, + addr: std::net::SocketAddr, + ) -> JoinHandle> { + tokio::task::spawn(async move { + let addr: SockAddr = addr.into(); + socket.connect(&addr).await?; + let mut buffer = [MaybeUninit::uninit(); 1024]; + for i in 0..1000 { + let data = format!("hello world {} times", i); + tracing::info!("Sending: {} from socket {:?}", &data, &socket); + let write_future = socket.send(data.as_bytes()); + + let size = tokio::time::timeout(std::time::Duration::from_secs(1), write_future).await??; + + if size == 0 { + return Ok(()); + } + + let recv_future = socket.recv(&mut buffer); + tokio::time::timeout(std::time::Duration::from_secs(1), recv_future).await??; + let received = buffer[..size] + .iter() + .map(|x| unsafe { x.assume_init() }) + .collect::>(); + assert_eq!(received, data.as_bytes()); + tracing::debug!("Received: {}", std::str::from_utf8(&received)?); + } + Ok(()) + }) + } + + let mut handles = vec![]; + for socket in socket_arr { + let addr: SocketAddr = SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 8080).into(); + + let handle = connect_and_write_something(socket, addr); + handles.push(handle); + } + + for future in handles { + future.await??; + } + + Ok(()) +} + +async fn tcp_server(addr: &str) -> anyhow::Result<()> { + let listener = tokio::net::TcpListener::bind(addr).await?; + println!("Listening on: {}", addr); + let count = std::sync::Arc::new(AtomicU32::new(0)); + loop { + let (mut socket, _) = listener.accept().await?; + let _now = Instant::now(); + let count = count.clone(); + tokio::spawn(async move { + let mut buf = vec![0; 1024]; + loop { + let n = match socket.read(&mut buf).await { + // socket closed + Ok(n) => { + if n == 0 { + println!("Socket closed"); + return; + } + n + } + Err(e) => { + eprintln!("Failed to read from socket; err = {:?}", e); + return; + } + }; + println!("Received {} bytes", n); + // Write the data back + if let Err(e) = socket.write_all(&buf[0..n]).await { + eprintln!("Failed to write to socket; err = {:?}", e); + return; + } + count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + }); + } +} diff --git a/crates/network-scanner-net/examples/tcp_fail.rs b/crates/network-scanner-net/examples/tcp_fail.rs new file mode 100644 index 000000000..bb688a817 --- /dev/null +++ b/crates/network-scanner-net/examples/tcp_fail.rs @@ -0,0 +1,83 @@ +use std::net::SocketAddr; + +#[tokio::main] +pub async fn main() -> anyhow::Result<()> { + utils::start_server(); + tracing_subscriber::fmt::SubscriberBuilder::default() + .with_max_level(tracing::Level::TRACE) + .with_thread_names(true) + .init(); + + let async_runtime = network_scanner_net::runtime::Socket2Runtime::new(None)?; + let good_socket = async_runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + let bad_socket = async_runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + + tracing::info!("good_socket: {:?}", good_socket); + tracing::info!("bad_socket: {:?}", bad_socket); + + let available_port = 8080; + let non_available_port = 12345; + + let available_addr = SocketAddr::from(([127, 0, 0, 1], available_port)); + let non_available_addr = SocketAddr::from(([127, 0, 0, 1], non_available_port)); + + let handle = tokio::task::spawn(async move { good_socket.connect(&socket2::SockAddr::from(available_addr)).await }); + + let handle2 = + tokio::task::spawn(async move { bad_socket.connect(&socket2::SockAddr::from(non_available_addr)).await }); + + let (a, b) = tokio::join!(handle, handle2); + // remove the outer error from tokio task + let a = a?; + let b = b?; + tracing::info!("should connect: {:?}", &a); + tracing::info!("should not connect: {:?}", &b); + assert!(a.is_ok()); + assert!(b.is_err()); + Ok(()) +} + +mod utils { + use std::io::Read; + use std::io::Write; + use std::net::{TcpListener, TcpStream}; + use std::thread; + + fn handle_client(mut stream: TcpStream) { + // read 20 bytes at a time from stream echoing back to stream + loop { + let mut read = [0; 1028]; + match stream.read(&mut read) { + Ok(n) => { + if n == 0 { + // connection was closed + break; + } + let _ = stream.write(&read[0..n]).unwrap(); + } + Err(_) => { + return; + } + } + } + } + + pub(super) fn start_server() { + thread::spawn(|| { + let listener = TcpListener::bind("127.0.0.1:8080").unwrap(); + + for stream in listener.incoming() { + match stream { + Ok(stream) => { + thread::spawn(move || { + handle_client(stream); + }); + } + Err(_) => { + // println!("Error"); + } + } + } + }); + } +} diff --git a/crates/network-scanner-net/src/runtime.rs b/crates/network-scanner-net/src/runtime.rs index 85c1d1861..cd81e125f 100644 --- a/crates/network-scanner-net/src/runtime.rs +++ b/crates/network-scanner-net/src/runtime.rs @@ -1,5 +1,6 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, + hash::Hash, num::NonZeroUsize, sync::{ atomic::{AtomicBool, AtomicUsize, Ordering}, @@ -10,6 +11,7 @@ use std::{ use anyhow::Context; use crossbeam::channel::{Receiver, Sender}; +use parking_lot::Mutex; use polling::{Event, Events}; use socket2::Socket; @@ -17,16 +19,18 @@ use crate::{socket::AsyncRawSocket, ScannnerNetError}; #[derive(Debug)] pub struct Socket2Runtime { - poller: polling::Poller, + poller: Arc, next_socket_id: AtomicUsize, - is_terminated: AtomicBool, - sender: Sender, + is_terminated: Arc, + register_sender: Sender, + event_receiver: Receiver, + event_cache: Mutex>, } impl Drop for Socket2Runtime { fn drop(&mut self) { - self.is_terminated.store(true, Ordering::SeqCst); tracing::debug!("dropping runtime"); + self.is_terminated.store(true, Ordering::SeqCst); let _ = self // ignore errors, cannot handle it here .poller .notify() @@ -41,15 +45,22 @@ impl Socket2Runtime { /// Create a new runtime with a queue capacity, default is 1024. pub fn new(queue_capacity: Option) -> anyhow::Result> { let poller = polling::Poller::new()?; - let (sender, receiver) = crossbeam::channel::bounded(queue_capacity.unwrap_or(QUEUE_CAPACITY)); + + let (register_sender, register_receiver) = + crossbeam::channel::bounded(queue_capacity.unwrap_or(QUEUE_CAPACITY)); + + let (event_sender, event_receiver) = crossbeam::channel::bounded(queue_capacity.unwrap_or(QUEUE_CAPACITY)); + let runtime = Self { - poller, + poller: Arc::new(poller), next_socket_id: AtomicUsize::new(0), - is_terminated: AtomicBool::new(false), - sender, + is_terminated: Arc::new(AtomicBool::new(false)), + register_sender, + event_receiver, + event_cache: Mutex::new(HashSet::new()), }; let runtime = Arc::new(runtime); - runtime.clone().start_loop(receiver)?; + runtime.start_loop(register_receiver, event_sender)?; Ok(runtime) } @@ -67,37 +78,51 @@ impl Socket2Runtime { Ok(AsyncRawSocket::from_socket(socket, id, self.clone())?) } - pub(crate) fn remove_socket(&self, socket: &socket2::Socket) -> anyhow::Result<()> { + pub(crate) fn remove_socket(&self, socket: &socket2::Socket, id: usize) -> anyhow::Result<()> { self.poller.delete(socket)?; + // remove all events related to this socket + self.event_cache.lock().retain(|event| id == event.0.key); Ok(()) } - fn start_loop(self: Arc, receiver: Receiver) -> anyhow::Result<()> { + fn start_loop( + &self, + register_receiver: Receiver, + event_sender: Sender, + ) -> anyhow::Result<()> { + // we make is_terminated Arc and poller Arc so that we can clone them and move them into the thread + // we cannot hold a Arc in the thread, because it will create a cycle reference and the runtime will never be dropped. + let is_terminated = self.is_terminated.clone(); + let poller = self.poller.clone(); std::thread::Builder::new() .name("[raw-socket]:io-event-loop".to_string()) .spawn(move || { let mut events = Events::with_capacity(NonZeroUsize::new(1024).unwrap()); tracing::debug!("starting io event loop"); + // events registered but not happened yet let mut events_registered = HashMap::new(); - let mut events_happend = HashMap::new(); + + // events happened but not registered yet + let mut events_happened = HashMap::new(); + loop { - if self.is_terminated.load(Ordering::Acquire) { + if is_terminated.load(Ordering::Acquire) { break; } tracing::debug!("polling events"); - if let Err(e) = self.poller.wait(&mut events, None) { + if let Err(e) = poller.wait(&mut events, None) { tracing::error!(error = ?e, "failed to poll events"); - self.is_terminated.store(true, Ordering::SeqCst); + is_terminated.store(true, Ordering::SeqCst); break; }; for event in events.iter() { - tracing::trace!(?event, "event happend"); - events_happend.insert(event.key, event); + tracing::trace!(?event, "event happened"); + events_happened.insert(event.key, event); } events.clear(); - while let Ok(event) = receiver.try_recv() { + while let Ok(event) = register_receiver.try_recv() { match event { RegisterEvent::Register { id, waker } => { events_registered.insert(id, waker); @@ -108,25 +133,79 @@ impl Socket2Runtime { } } - let intersection = events_happend + let intersection = events_happened .keys() .filter(|key| events_registered.contains_key(key)) .cloned() .collect::>(); intersection.into_iter().for_each(|ref key| { - let event = events_happend.remove(key).unwrap(); + let event = events_happened.remove(key).unwrap(); let waker = events_registered.remove(key).unwrap(); + let _ = event_sender.try_send(event); waker.wake_by_ref(); tracing::trace!(?event, "waking up waker"); }); } + tracing::info!("io event loop terminated"); }) .with_context(|| "failed to spawn io event loop thread")?; Ok(()) } + /// Ideally, we should have a dedicated thread to handle events we received, but we don't really want to spawn a second thread + /// Alternatively, we can have all socket futures call this function to check if there is any event for them. + /// The number of times the socket futures is polled is almost guaranteed to be more than the number of registration we received. + /// hence the event receiver will not be blocked. + pub(crate) fn check_event(&self, event: Event, remove: bool) -> Option { + let mut event_cache = self.event_cache.lock(); + while let Ok(event) = self.event_receiver.try_recv() { + event_cache.insert(event.into()); + } + tracing::debug!("checking event cache {:?}", event_cache); + + let event = if remove { + event_cache.take(&event.into()) + } else if event_cache.contains(&event.into()) { + Some(event.into()) + } else { + None + }; + + event.map(|event| event.into_inner()) + } + + pub(crate) fn check_event_with_id(&self, id: usize, remove: bool) -> Vec { + let mut event_cache = self.event_cache.lock(); + while let Ok(event) = self.event_receiver.try_recv() { + event_cache.insert(event.into()); + } + let event_interested = vec![ + Event::readable(id), + Event::writable(id), + Event::all(id), + Event::none(id), + ]; + let mut res = vec![]; + + if remove { + event_interested.into_iter().for_each(|event| { + if let Some(event) = event_cache.take(&event.into()) { + res.push(event.into_inner()); + } + }); + } else { + event_interested.into_iter().for_each(|event| { + if event_cache.contains(&event.into()) { + res.push(event); + } + }); + } + + res + } + pub(crate) fn register(&self, socket: &Socket, event: Event, waker: Waker) -> anyhow::Result<()> { if self.is_terminated.load(Ordering::Acquire) { Err(ScannnerNetError::AsyncRuntimeError("runtime is terminated".to_string()))?; @@ -138,7 +217,7 @@ impl Socket2Runtime { // Use try_send instead of send, in case some io events blocked the queue completely, // it would be better to drop the register event then block the worker thread or main thread. // as the worker thread is shared for the entire application. - self.sender + self.register_sender .try_send(RegisterEvent::Register { id: event.key, waker }) .with_context(|| "failed to send register event to register loop") } @@ -147,7 +226,7 @@ impl Socket2Runtime { if self.is_terminated.load(Ordering::Acquire) { Err(ScannnerNetError::AsyncRuntimeError("runtime is terminated".to_string()))?; } - self.sender + self.register_sender .try_send(RegisterEvent::Unregister { id }) .with_context(|| "failed to send unregister event to register loop") } @@ -158,3 +237,34 @@ enum RegisterEvent { Register { id: usize, waker: Waker }, Unregister { id: usize }, } + +#[derive(Debug)] +pub struct EventWrapper(Event); + +impl Hash for EventWrapper { + fn hash(&self, state: &mut H) { + self.0.key.hash(state); + self.0.readable.hash(state); + self.0.writable.hash(state); + } +} + +impl PartialEq for EventWrapper { + fn eq(&self, other: &Self) -> bool { + self.0.key == other.0.key && self.0.readable == other.0.readable && self.0.writable == other.0.writable + } +} + +impl Eq for EventWrapper {} + +impl From for EventWrapper { + fn from(event: Event) -> Self { + Self(event) + } +} + +impl EventWrapper { + pub(crate) fn into_inner(self) -> Event { + self.0 + } +} diff --git a/crates/network-scanner-net/src/socket.rs b/crates/network-scanner-net/src/socket.rs index e78371d19..61f305c89 100644 --- a/crates/network-scanner-net/src/socket.rs +++ b/crates/network-scanner-net/src/socket.rs @@ -1,5 +1,5 @@ use polling::Event; -use std::{future::Future, mem::MaybeUninit, sync::Arc, usize}; +use std::{fmt::Debug, future::Future, mem::MaybeUninit, sync::Arc, usize}; use socket2::{SockAddr, Socket}; use std::result::Result::Ok; @@ -13,12 +13,21 @@ pub struct AsyncRawSocket { id: usize, } +impl Debug for AsyncRawSocket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AsyncRawSocket") + .field("socket", &self.socket) + .field("id", &self.id) + .finish() + } +} + impl Drop for AsyncRawSocket { fn drop(&mut self) { tracing::trace!(id = %self.id,socket = ?self.socket, "drop socket"); let _ = self // We ignore errors here, avoid crashing the thread .runtime - .remove_socket(&self.socket) + .remove_socket(&self.socket, self.id) .map_err(|e| tracing::error!("failed to remove socket from poller: {:?}", e)); } } @@ -98,7 +107,6 @@ impl<'a> AsyncRawSocket { runtime: self.runtime.clone(), addr, id: self.id, - is_first_poll: true, } } @@ -133,6 +141,9 @@ struct RecvFromFuture<'a> { impl Future for RecvFromFuture<'_> { type Output = std::io::Result<(usize, SockAddr)>; fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + // By checking event at every call, it removes excessive events excessive in the cache and the channel + self.runtime.check_event(Event::readable(self.id), true); + let socket = &self.socket.clone(); // avoid borrow checker error match socket.recv_from(self.buf) { Ok(a) => std::task::Poll::Ready(Ok(a)), @@ -153,6 +164,7 @@ impl<'a> Future for SendToFuture<'a> { type Output = std::io::Result; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + self.runtime.check_event(Event::writable(self.id), true); match self.socket.send_to(self.data, self.addr) { Ok(a) => std::task::Poll::Ready(Ok(a)), Err(e) => resolve(e, &self.socket, &self.runtime, Event::writable(self.id), cx.waker()), @@ -170,6 +182,7 @@ impl Future for AcceptFuture { type Output = std::io::Result<(AsyncRawSocket, SockAddr)>; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + self.runtime.check_event(Event::readable(self.id), true); match self.socket.accept() { Ok((socket, addr)) => { let socket = AsyncRawSocket::from_socket(socket, self.id, self.runtime.clone())?; @@ -184,50 +197,53 @@ struct ConnectFuture<'a> { runtime: Arc, id: usize, addr: &'a socket2::SockAddr, - is_first_poll: bool, } impl<'a> Future for ConnectFuture<'a> { type Output = std::io::Result<()>; - fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - if self.is_first_poll { - tracing::trace!("first poll connect future"); - // cannot call connect twice - self.is_first_poll = false; - let err = match self.socket.connect(self.addr) { - Ok(a) => { - return std::task::Poll::Ready(Ok(a)); - } - Err(e) => e, - }; - - // code 115, EINPROGRESS, only for linux - // reference: https://linux.die.net/man/2/connect - // it is the same as WouldBlock but for connect(2) only - #[cfg(target_os = "linux")] - let in_progress = err.kind() == std::io::ErrorKind::WouldBlock || err.raw_os_error() == Some(115); - - #[cfg(not(target_os = "linux"))] - let in_progress = err.kind() == std::io::ErrorKind::WouldBlock; - - if in_progress { - tracing::trace!("connect should register"); - if let Err(e) = self - .runtime - .register(&self.socket, Event::all(self.id), cx.waker().clone()) - { - tracing::warn!(?self.socket, ?self.addr, "failed to register socket to poller"); - return std::task::Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - format!("failed to register socket to poller: {}", e), - ))); - } - return std::task::Poll::Pending; + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + let events = self.runtime.check_event_with_id(self.id, true); + if events.iter().any(|e| e.is_connect_failed().unwrap_or(false)) { + tracing::warn!(?self.socket, ?self.addr, "connect failed"); + return std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "connect failed"))); + }; + + if !events.is_empty() { + tracing::trace!(?events, "connect success"); + return std::task::Poll::Ready(Ok(())); + } + + let err = match self.socket.connect(self.addr) { + Ok(a) => { + return std::task::Poll::Ready(Ok(a)); + } + Err(e) => e, + }; + + // code 115, EINPROGRESS, only for linux + // reference: https://linux.die.net/man/2/connect + // it is the same as WouldBlock but for connect(2) only + #[cfg(target_os = "linux")] + let in_progress = err.kind() == std::io::ErrorKind::WouldBlock || err.raw_os_error() == Some(115); + + #[cfg(not(target_os = "linux"))] + let in_progress = err.kind() == std::io::ErrorKind::WouldBlock; + + if in_progress { + tracing::trace!("connect should register"); + if let Err(e) = self + .runtime + .register(&self.socket, Event::all(self.id), cx.waker().clone()) + { + tracing::warn!(?self.socket, ?self.addr, "failed to register socket to poller"); + return std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + format!("failed to register socket to poller: {}", e), + ))); } } - tracing::trace!("second poll connect future"); - std::task::Poll::Ready(Ok(())) + std::task::Poll::Pending } } @@ -242,6 +258,7 @@ impl<'a> Future for SendFuture<'a> { type Output = std::io::Result; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + self.runtime.check_event(Event::writable(self.id), true); match self.socket.send(self.data) { Ok(a) => std::task::Poll::Ready(Ok(a)), Err(e) => resolve(e, &self.socket, &self.runtime, Event::writable(self.id), cx.waker()), @@ -260,6 +277,7 @@ impl<'a> Future for RecvFuture<'a> { type Output = std::io::Result; fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + self.runtime.check_event(Event::readable(self.id), true); let socket = &self.socket.clone(); // avoid borrow checker error match socket.recv(self.buf) { Ok(a) => std::task::Poll::Ready(Ok(a)), diff --git a/crates/network-scanner-net/src/test.rs b/crates/network-scanner-net/src/test.rs index 114c2e0c7..20c80368d 100644 --- a/crates/network-scanner-net/src/test.rs +++ b/crates/network-scanner-net/src/test.rs @@ -1,26 +1,21 @@ use std::{ - io::{ErrorKind, Read, Write}, + io::ErrorKind, mem::MaybeUninit, net::{SocketAddr, UdpSocket}, + sync::{atomic::AtomicBool, Arc}, + time::Duration, }; use socket2::SockAddr; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + task::JoinHandle, +}; use crate::socket::AsyncRawSocket; -#[ignore] -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn test_connectivity() -> anyhow::Result<()> { - let addr = local_tcp_server()?; - let runtime = crate::runtime::Socket2Runtime::new(None)?; - let socket = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; - socket.connect(&socket2::SockAddr::from(addr)).await?; - Ok(()) -} - -#[ignore] #[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[ignore = "TODO"] async fn multiple_udp() -> anyhow::Result<()> { let addr = local_udp_server()?; tokio::time::sleep(std::time::Duration::from_millis(200)).await; // wait for the other socket to start @@ -30,21 +25,22 @@ async fn multiple_udp() -> anyhow::Result<()> { let socket2 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::DGRAM, None)?; let socket3 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::DGRAM, None)?; - fn send_to(mut socket: AsyncRawSocket, number: u8, addr: SocketAddr) -> tokio::task::JoinHandle<()> { + fn send_to( + mut socket: AsyncRawSocket, + number: u8, + addr: SocketAddr, + ) -> tokio::task::JoinHandle> { tokio::task::spawn(async move { let msg = format!("hello from socket {}", number); - socket - .send_to(msg.as_bytes(), &SockAddr::from(addr)) - .await - .expect("send_to"); + socket.send_to(msg.as_bytes(), &SockAddr::from(addr)).await?; + let mut buf = [MaybeUninit::::uninit(); 1024]; - let (size, addr) = socket - .recv_from(&mut buf) - .await - .unwrap_or_else(|_| panic!("recv_from: {}", number)); + let (size, addr) = socket.recv_from(&mut buf).await?; + tracing::info!("size: {}, addr: {:?}", size, addr); let back = unsafe { crate::assume_init(&buf[..size]) }; assert_eq!(back, format!("hello from socket {}", number).as_bytes()); + Ok::<(), anyhow::Error>(()) }) } @@ -57,32 +53,59 @@ async fn multiple_udp() -> anyhow::Result<()> { ]; for handle in handles { - handle.await?; + tokio::time::timeout(Duration::from_secs(10), handle).await???; } Ok(()) } -#[ignore] #[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[ignore = "TODO"] +async fn test_connectivity() -> anyhow::Result<()> { + let kill_server = Arc::new(AtomicBool::new(false)); + let (addr, handle) = local_tcp_server(kill_server.clone()).await?; + tokio::time::sleep(std::time::Duration::from_millis(200)).await; // wait for the other socket to start + + let runtime = crate::runtime::Socket2Runtime::new(None)?; + let socket = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + let addr: SockAddr = addr.into(); + socket.connect(&addr).await?; + + // clean up + kill_server.store(true, std::sync::atomic::Ordering::Relaxed); + handle.abort(); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 6)] +#[ignore = "TODO"] async fn multiple_tcp() -> anyhow::Result<()> { - let addr = local_tcp_server()?; + let kill_server = Arc::new(AtomicBool::new(false)); + let (addr, handle) = local_tcp_server(kill_server.clone()).await?; + tokio::time::sleep(std::time::Duration::from_millis(200)).await; // wait for the other socket to start + let runtime = crate::runtime::Socket2Runtime::new(None)?; let socket0 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; let socket1 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; let socket2 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; let socket3 = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; - fn connect(mut socket: AsyncRawSocket, number: u8, addr: SocketAddr) -> tokio::task::JoinHandle<()> { + fn connect( + mut socket: AsyncRawSocket, + number: u8, + addr: SocketAddr, + ) -> tokio::task::JoinHandle> { tokio::task::spawn(async move { - socket.connect(&socket2::SockAddr::from(addr)).await.expect("connect"); + socket.connect(&socket2::SockAddr::from(addr)).await?; let msg = format!("hello from socket {}", number); - socket.send(msg.as_bytes()).await.expect("send"); + socket.send(msg.as_bytes()).await?; let mut buf = [MaybeUninit::::uninit(); 1024]; - let size = socket.recv(&mut buf).await.expect("recv"); + let size = socket.recv(&mut buf).await?; tracing::info!("size: {}", size); let back = unsafe { crate::assume_init(&buf[..size]) }; assert_eq!(back, format!("hello from socket {}", number).as_bytes()); + Ok::<(), anyhow::Error>(()) }) } @@ -95,16 +118,23 @@ async fn multiple_tcp() -> anyhow::Result<()> { ]; for handle in handles { - handle.await?; + tokio::time::timeout(Duration::from_secs(5), handle).await???; } + // clean up + kill_server.store(true, std::sync::atomic::Ordering::Relaxed); + handle.abort(); + Ok(()) } -#[ignore] #[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[ignore = "TODO"] async fn work_with_tokio_tcp() -> anyhow::Result<()> { - let addr = local_tcp_server()?; + let kill_server = Arc::new(AtomicBool::new(false)); + let (addr, tcp_handle) = local_tcp_server(kill_server.clone()).await?; + tokio::time::sleep(std::time::Duration::from_millis(200)).await; // wait for the other socket to start + let runtime = crate::runtime::Socket2Runtime::new(None)?; let mut socket = runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; @@ -115,7 +145,7 @@ async fn work_with_tokio_tcp() -> anyhow::Result<()> { socket.send(msg.as_bytes()).await?; let mut buf = [MaybeUninit::::uninit(); 1024]; let size = socket.recv(&mut buf).await?; - tracing::info!("size: {}", size); + tracing::debug!("size: {}", size); let back = unsafe { crate::assume_init(&buf[..size]) }; assert_eq!(back, msg.as_bytes()); } @@ -137,13 +167,65 @@ async fn work_with_tokio_tcp() -> anyhow::Result<()> { Ok::<(), anyhow::Error>(()) }); - let (a, b) = tokio::join!(handle, handle2); - a??; - b??; + tokio::time::timeout(Duration::from_secs(5), handle).await???; + tokio::time::timeout(Duration::from_secs(5), handle2).await???; + + // clean up + kill_server.store(true, std::sync::atomic::Ordering::Relaxed); + tcp_handle.abort(); Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +pub async fn drop_runtime() -> anyhow::Result<()> { + let kill_server = Arc::new(AtomicBool::new(false)); + let (addr, handle) = local_tcp_server(kill_server.clone()).await?; + tracing_subscriber::fmt::SubscriberBuilder::default() + .with_max_level(tracing::Level::TRACE) + .with_thread_names(true) + .init(); + { + let async_runtime = crate::runtime::Socket2Runtime::new(None)?; + { + let good_socket = async_runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + let bad_socket = async_runtime.new_socket(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + + tracing::info!("good_socket: {:?}", good_socket); + tracing::info!("bad_socket: {:?}", bad_socket); + + let unused_port = 12345; + + let available_addr = addr; + let non_available_addr = SocketAddr::from(([127, 0, 0, 1], unused_port)); + + let handle = + tokio::task::spawn(async move { good_socket.connect(&socket2::SockAddr::from(available_addr)).await }); + + let handle2 = + tokio::task::spawn( + async move { bad_socket.connect(&socket2::SockAddr::from(non_available_addr)).await }, + ); + + let (a, b) = tokio::join!(handle, handle2); + // remove the outer error from tokio task + let a = a?; + let b = b?; + tracing::info!("should connect: {:?}", &a); + tracing::info!("should not connect: {:?}", &b); + assert!(a.is_ok()); + assert!(b.is_err()); + } + tracing::info!("runtime arc count: {}", Arc::strong_count(&async_runtime)); + + assert!(Arc::strong_count(&async_runtime) == 1); + } + tracing::info!("runtime should be dropped here"); + kill_server.store(true, std::sync::atomic::Ordering::Relaxed); + handle.abort(); + Ok(()) +} + fn local_udp_server() -> anyhow::Result { // Spawn a new thread let socket = UdpSocket::bind("127.0.0.1:0")?; @@ -151,18 +233,19 @@ fn local_udp_server() -> anyhow::Result { std::thread::spawn(move || { // Create and bind the UDP socket - println!("UDP server listening on {}", socket.local_addr()?); + tracing::debug!("UDP server listening on {}", socket.local_addr()?); let mut buffer = [0u8; 1024]; // A buffer to store incoming data loop { match socket.recv_from(&mut buffer) { Ok((size, src)) => { - println!("Received {} bytes from {}", size, src); - let socket_clone = socket.try_clone().expect("Failed to clone socket"); + tracing::trace!("Received {} bytes from {}", size, src); + let socket_clone = socket.try_clone()?; std::thread::spawn(move || { std::thread::sleep(std::time::Duration::from_millis(200)); // simulate some work - socket_clone.send_to(&buffer[..size], src).expect("Failed to send data") + socket_clone.send_to(&buffer[..size], src)?; + Ok::<(), anyhow::Error>(()) }); } Err(ref e) if e.kind() == ErrorKind::WouldBlock => { @@ -180,42 +263,55 @@ fn local_udp_server() -> anyhow::Result { Ok(res) } -fn handle_client(mut stream: std::net::TcpStream) -> std::io::Result<()> { +async fn handle_client(mut stream: tokio::net::TcpStream, awake: Arc) -> std::io::Result<()> { let mut buffer = [0; 1024]; loop { - // Read data from the stream - let size = stream.read(&mut buffer)?; - println!("Received {} bytes: {:?}", size, &buffer[..size]); + let read_future = stream.read(&mut buffer); + + let size = match tokio::time::timeout(Duration::from_secs(1), read_future).await { + Ok(res) => res?, + Err(_) => { + if awake.load(std::sync::atomic::Ordering::Relaxed) { + return Ok(()); + } + continue; + } + }; + + if size == 0 { + return Ok(()); + } + + tracing::debug!("Received {} bytes: {:?}", size, &buffer[..size]); std::thread::sleep(std::time::Duration::from_millis(200)); // simulate some work - stream.write_all(&buffer[..size])?; // Echo the data back to the client - println!("Echoed back {} bytes", size); + stream.write_all(&buffer[..size]).await?; // Echo the data back to the client } } -fn local_tcp_server() -> anyhow::Result { - // Bind the TCP listener to a local address - let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Could not bind TCP listener"); - println!("TCP server listening on {}", listener.local_addr().unwrap()); - let res = listener.local_addr().unwrap(); - std::thread::spawn(move || { - // Accept incoming connections - for stream in listener.incoming() { - match stream { - Ok(stream) => { - // Spawn a new thread for each connection - std::thread::spawn(move || { - if let Err(e) = handle_client(stream) { - tracing::error!("An error occurred while handling the client: {}", e); - } - }); +async fn local_tcp_server( + awake: Arc, +) -> anyhow::Result<(SocketAddr, JoinHandle>)> { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let res = listener.local_addr()?; + let handle = tokio::task::spawn(async move { + loop { + let listener_future = listener.accept(); + let (stream, _) = match tokio::time::timeout(Duration::from_secs(1), listener_future).await { + Ok(res) => res, + Err(_) => { + if awake.load(std::sync::atomic::Ordering::Relaxed) { + return Ok::<(), anyhow::Error>(()); + } + continue; } - Err(e) => { - tracing::error!("Connection failed: {}", e); - return; + }?; + let awake = awake.clone(); + tokio::task::spawn(async move { + if let Err(e) = handle_client(stream, awake).await { + tracing::error!("An error occurred while handling the client: {}", e); } - } + }); } }); - - Ok(res) + Ok((res, handle)) }