diff --git a/Cargo.lock b/Cargo.lock index af8c95d..c84896e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1742,6 +1742,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 1431aaa..0ccb03b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ thiserror = "1.0" time = "0.3" tokio = "1.23" tokio-serde = { version = "0.8", features = ["cbor"] } -tokio-stream = "0.1" +tokio-stream = { version = "0.1", features = ["sync"] } tokio-util = { version = "0.7", features = ["codec"] } toml = "0.7" tracing = "0.1" diff --git a/event-broker/Cargo.toml b/event-broker/Cargo.toml index a7195a5..13a8615 100644 --- a/event-broker/Cargo.toml +++ b/event-broker/Cargo.toml @@ -17,7 +17,7 @@ futures.workspace = true inotify.workspace = true libsystemd = { version = "0.7", optional = true } serde_cbor.workspace = true -tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] } tokio-serde.workspace = true tokio-stream.workspace = true tokio-util.workspace = true diff --git a/event-broker/src/main.rs b/event-broker/src/main.rs index 10daaf3..504a30c 100644 --- a/event-broker/src/main.rs +++ b/event-broker/src/main.rs @@ -5,23 +5,23 @@ use anyhow::bail; use anyhow::{Context as _, Result}; use crypto_auditing::types::EventGroup; -use futures::{future, stream::StreamExt, try_join, SinkExt, TryStreamExt}; -use inotify::{EventMask, Inotify, WatchMask}; +use futures::{future, stream::StreamExt, try_join, SinkExt, Stream, TryStreamExt}; +use inotify::{EventMask, EventStream, Inotify, WatchDescriptor, WatchMask}; #[cfg(feature = "libsystemd")] use libsystemd::activation::receive_descriptors; use serde_cbor::de::Deserializer; use std::collections::HashMap; +use std::fs; +use std::marker; use std::os::fd::{AsRawFd, RawFd}; #[cfg(feature = "libsystemd")] use std::os::fd::{FromRawFd, IntoRawFd}; use std::os::unix::net::UnixListener as StdUnixListener; use std::path::{Path, PathBuf}; use std::sync::Arc; -use tokio::net::{unix::OwnedWriteHalf, UnixListener}; -use tokio::sync::{ - mpsc::{self, Receiver, Sender}, - RwLock, -}; +use tokio::net::{unix::OwnedWriteHalf, UnixListener, UnixStream}; +use tokio::signal; +use tokio::sync::{broadcast, mpsc, RwLock}; use tokio_serde::{formats::SymmetricalCbor, SymmetricallyFramed}; use tokio_stream::wrappers::ReceiverStream; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; @@ -32,38 +32,81 @@ mod config; struct Reader { log_file: PathBuf, + watch_descriptor: Option, } impl Reader { fn new(log_file: impl AsRef) -> Self { let log_file = log_file.as_ref().to_path_buf(); - Self { log_file } + Self { + log_file, + watch_descriptor: None, + } + } + + fn enable_monitor(&mut self, stream: &EventStream<&mut [u8; 1024]>) -> Result<()> { + if self.watch_descriptor.is_none() { + let watch_descriptor = stream + .watches() + .add(&self.log_file, WatchMask::MODIFY | WatchMask::CREATE) + .with_context(|| { + format!("unable to start monitoring {}", self.log_file.display()) + })?; + self.watch_descriptor = Some(watch_descriptor); + info!("enabled monitoring of {}", self.log_file.display()); + } + Ok(()) } - async fn read(&self, sender: Sender) -> Result<()> { + fn disable_monitor(&mut self, stream: &EventStream<&mut [u8; 1024]>) -> Result<()> { + if self.watch_descriptor.is_some() { + let watch_descriptor = self.watch_descriptor.take(); + stream + .watches() + .remove(watch_descriptor.unwrap()) + .with_context(|| { + format!("unable to stop monitoring {}", self.log_file.display()) + })?; + info!("disabled monitoring of {}", self.log_file.display()); + } + Ok(()) + } + + async fn read( + &mut self, + event_sender: &mpsc::Sender, + mut subscription_stream: impl Stream + marker::Unpin, + shutdown_receiver: &mut broadcast::Receiver<()>, + ) -> Result<()> { let inotify = Inotify::init().with_context(|| "unable to initialize inotify".to_string())?; - inotify - .watches() - .add(&self.log_file, WatchMask::MODIFY | WatchMask::CREATE) - .with_context(|| format!("unable to monitor {}", self.log_file.display()))?; - let mut file = std::fs::File::open(&self.log_file).ok(); + let mut file = fs::File::open(&self.log_file) + .with_context(|| format!("unable to open {}", self.log_file.display()))?; let mut buffer = [0; 1024]; - let mut stream = inotify.into_event_stream(&mut buffer)?; + let mut inotify_stream = inotify.into_event_stream(&mut buffer)?; - while let Some(event_or_error) = stream.next().await { - let event = event_or_error?; - if event.mask.contains(EventMask::CREATE) { - let new_file = std::fs::File::open(&self.log_file).with_context(|| { - format!("unable to read file `{}`", self.log_file.display()) - })?; - let _old = file.replace(new_file); - } - if let Some(ref file) = file { - for group in Deserializer::from_reader(file).into_iter::() { - sender.send(group?).await? + loop { + tokio::select! { + Some(event_or_error) = inotify_stream.next() => { + let event = event_or_error?; + if event.mask.contains(EventMask::CREATE) { + file = fs::File::open(&self.log_file).with_context(|| { + format!("unable to read file `{}`", self.log_file.display()) + })?; + } + for group in Deserializer::from_reader(&mut file).into_iter::() { + event_sender.send(group?).await? + } + }, + Some(n_subscriptions) = subscription_stream.next() => { + if n_subscriptions > 0 { + self.enable_monitor(&inotify_stream)?; + } else { + self.disable_monitor(&inotify_stream)?; + } } + _ = shutdown_receiver.recv() => break, } } @@ -86,6 +129,7 @@ struct Subscription { struct Publisher { socket_path: PathBuf, subscriptions: Arc>>, + activated: Arc>, } impl Publisher { @@ -94,92 +138,186 @@ impl Publisher { Self { socket_path, subscriptions: Arc::new(RwLock::new(HashMap::new())), + activated: Arc::new(RwLock::new(false)), } } #[cfg(feature = "libsystemd")] - fn get_std_listener(&self) -> Result { - if let Ok(mut descriptors) = receive_descriptors(false) { - if descriptors.len() > 1 { - bail!("too many file descriptors"); - } else if descriptors.is_empty() { - bail!("no file descriptors received"); + async fn get_std_listener(&self) -> Result { + match receive_descriptors(false) { + Ok(mut descriptors) => { + if descriptors.len() > 1 { + bail!("too many file descriptors"); + } else if descriptors.is_empty() { + bail!("no file descriptors received"); + } + let fd = descriptors.pop().unwrap().into_raw_fd(); + let mut activated = self.activated.write().await; + *activated = true; + Ok(unsafe { StdUnixListener::from_raw_fd(fd) }) + } + Err(e) => { + info!(error = %e, "unable to receive file descriptors"); + Ok(StdUnixListener::bind(&self.socket_path)?) } - let fd = descriptors.pop().unwrap().into_raw_fd(); - Ok(unsafe { StdUnixListener::from_raw_fd(fd) }) - } else { - Ok(StdUnixListener::bind(&self.socket_path)?) } } #[cfg(not(feature = "libsystemd"))] - fn get_std_listener(&self) -> Result { + async fn get_std_listener(&self) -> Result { Ok(StdUnixListener::bind(&self.socket_path)?) } - async fn publish(&self, receiver: Receiver) -> Result<()> { - let std_listener = self.get_std_listener()?; - std_listener.set_nonblocking(true)?; - let listener = UnixListener::from_std(std_listener)?; - let subscriptions = self.subscriptions.clone(); + async fn accept_subscriber( + &self, + stream: UnixStream, + subscription_sender: &mpsc::Sender, + ) -> Result<()> { + let subscriber_fd = stream.as_raw_fd(); - tokio::spawn(async move { - while let Ok((stream, _sock_addr)) = listener.accept().await { - let subscriber_fd = stream.as_raw_fd(); + debug!(socket = subscriber_fd, "subscriber connected"); - debug!(socket = subscriber_fd, "subscriber connected"); + let (de, ser) = stream.into_split(); - let (de, ser) = stream.into_split(); + let ser = FramedWrite::new(ser, LengthDelimitedCodec::new()); + let de = FramedRead::new(de, LengthDelimitedCodec::new()); - let ser = FramedWrite::new(ser, LengthDelimitedCodec::new()); - let de = FramedRead::new(de, LengthDelimitedCodec::new()); + let ser = SymmetricallyFramed::new(ser, SymmetricalCbor::::default()); + let mut de = SymmetricallyFramed::new(de, SymmetricalCbor::>::default()); - let ser = SymmetricallyFramed::new(ser, SymmetricalCbor::::default()); - let mut de = - SymmetricallyFramed::new(de, SymmetricalCbor::>::default()); + // Populate the scopes + if let Some(scopes) = de.try_next().await.unwrap() { + let mut subscriptions = self.subscriptions.write().await; + subscriptions.insert( + subscriber_fd, + Subscription { + stream: ser, + scopes, + errored: Default::default(), + }, + ); + subscription_sender.send(subscriptions.len()).await?; + } + Ok(()) + } - // Populate the scopes - if let Some(scopes) = de.try_next().await.unwrap() { - subscriptions.write().await.insert( - subscriber_fd, - Subscription { - stream: ser, - scopes, - errored: Default::default(), - }, - ); - } - } - }); + async fn listen( + &self, + subscription_sender: &mpsc::Sender, + shutdown_receiver: &mut broadcast::Receiver<()>, + ) -> Result<()> { + let std_listener = self.get_std_listener().await?; + std_listener.set_nonblocking(true)?; + let listener = UnixListener::from_std(std_listener)?; - let mut stream = ReceiverStream::new(receiver); - while let Some(group) = stream.next().await { - let mut subscriptions = self.subscriptions.write().await; - let mut publications = Vec::new(); - - for (_, subscription) in subscriptions.iter_mut() { - let mut group = group.clone(); - group.events_filtered(&subscription.scopes); - if !group.events().is_empty() { - publications.push(async move { - if let Err(e) = subscription.stream.send(group).await { - info!(error = %e, "unable to send event"); - subscription.errored = true; + loop { + tokio::select! { + maybe_stream = listener.accept() => { + let stream = match maybe_stream { + Ok((stream, _sock_addr)) => stream, + Err(e) => { + info!(error = %e, "unable to accept connection"); + break; } - }); - } + }; + if let Err(e) = self.accept_subscriber( + stream, + subscription_sender, + ).await { + info!(error = %e, "unable to accept subscriber"); + break; + } + }, + _ = shutdown_receiver.recv() => { + if !*self.activated.read().await { + drop(listener); + if let Err(e) = fs::remove_file(&self.socket_path) { + info!(error = %e, "error removing socket"); + } + } + break; + }, } + } + Ok(()) + } + + async fn publish_event( + &self, + group: &EventGroup, + subscription_sender: &mpsc::Sender, + ) -> Result<()> { + let mut subscriptions = self.subscriptions.write().await; + let mut publications = Vec::new(); + + let n_subscriptions = subscriptions.len(); + + for (_, subscription) in subscriptions.iter_mut() { + let mut group = group.clone(); + group.events_filtered(&subscription.scopes); + if !group.events().is_empty() { + publications.push(async move { + if let Err(e) = subscription.stream.send(group).await { + info!(error = %e, "unable to send event"); + subscription.errored = true; + } + }); + } + } + + future::join_all(publications).await; + + // Remove errored subscriptions + subscriptions.retain(|_, v| !v.errored); - future::join_all(publications).await; + if subscriptions.len() != n_subscriptions { + subscription_sender.send(subscriptions.len()).await?; + } + + Ok(()) + } - // Remove errored subscriptions - subscriptions.retain(|_, v| !v.errored); + async fn publish( + &self, + mut event_stream: impl Stream + marker::Unpin, + subscription_sender: &mpsc::Sender, + shutdown_receiver: &mut broadcast::Receiver<()>, + ) -> Result<()> { + loop { + tokio::select! { + Some(ref group) = event_stream.next() => { + self.publish_event( + group, + subscription_sender, + ).await? + }, + _ = shutdown_receiver.recv() => break, + } } Ok(()) } } +async fn shutdown( + shutdown_receiver: &mut broadcast::Receiver<()>, + shutdown_sender: &broadcast::Sender<()>, +) -> Result<()> { + tokio::select! { + maybe_value = signal::ctrl_c() => { + if let Err(e) = maybe_value { + info!(error = %e, "error receiving ctrl-c") + } + info!("shutting down event broker"); + if let Err(e) = shutdown_sender.send(()) { + info!(error = %e, "unable to send shutdown"); + } + }, + _ = shutdown_receiver.recv() => (), + } + Ok(()) +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let config = config::Config::new()?; @@ -189,9 +327,25 @@ async fn main() -> anyhow::Result<()> { .with(EnvFilter::from_default_env()) .try_init()?; - let reader = Reader::new(&config.log_file); + let mut reader = Reader::new(&config.log_file); let publisher = Publisher::new(&config.socket_path); - let (tx, rx) = mpsc::channel::(10); - try_join!(reader.read(tx), publisher.publish(rx),).map(|_| ()) + let (event_tx, event_rx) = mpsc::channel::(10); + let mut event_rx = ReceiverStream::new(event_rx); + + let (subscription_tx, subscription_rx) = mpsc::channel::(10); + let mut subscription_rx = ReceiverStream::new(subscription_rx); + + let (shutdown_tx, mut shutdown_rx1) = broadcast::channel::<()>(2); + let mut shutdown_rx2 = shutdown_tx.subscribe(); + let mut shutdown_rx3 = shutdown_tx.subscribe(); + let mut shutdown_rx4 = shutdown_tx.subscribe(); + + try_join!( + shutdown(&mut shutdown_rx1, &shutdown_tx), + reader.read(&event_tx, &mut subscription_rx, &mut shutdown_rx2), + publisher.listen(&subscription_tx, &mut shutdown_rx3), + publisher.publish(&mut event_rx, &subscription_tx, &mut shutdown_rx4), + ) + .map(|_| ()) } diff --git a/event-broker/tests/test.rs b/event-broker/tests/test.rs index e46a987..ee23e16 100644 --- a/event-broker/tests/test.rs +++ b/event-broker/tests/test.rs @@ -5,7 +5,6 @@ use crypto_auditing::event_broker::Client; use futures::stream::StreamExt; use std::env; use std::fs; -use std::io::{Read, Write}; use std::path::PathBuf; use std::process::{Child, Command}; use std::thread; @@ -53,7 +52,7 @@ async fn test_event_broker() { let test_dir = tempdir().expect("unable to create temporary directory"); let log_path = test_dir.path().join("agent.log"); - let mut log_file = fs::OpenOptions::new() + let _log_file = fs::OpenOptions::new() .write(true) .create(true) .append(true) @@ -89,19 +88,13 @@ async fn test_event_broker() { let (_handle, mut reader) = client.start().await.expect("unable to start client"); - // Append more data to log file - let mut fixture_file = fs::OpenOptions::new() - .read(true) - .open(&fixture_dir().join("normal").join("output.cborseq")) - .expect("unable to open fixture"); - let mut buffer = Vec::new(); - fixture_file - .read_to_end(&mut buffer) - .expect("unable to read fixture"); - log_file - .write_all(&buffer) - .expect("unable to append fixture"); - log_file.flush().expect("unable to flush fixture"); + // Append more data to log file, from a separate process + let mut child = std::process::Command::new("cp") + .arg(&fixture_dir().join("normal").join("output.cborseq")) + .arg(&log_path) + .spawn() + .expect("unable to spawn cp"); assert!(reader.next().await.is_some()); + child.wait().expect("unable to wait child to complete"); }