From 8aeb858f937e5240f87f7546754fda861685990a Mon Sep 17 00:00:00 2001 From: Daiki Ueno Date: Sun, 26 Nov 2023 17:09:32 +0900 Subject: [PATCH] event-broker: Shut down activated process where possible When event-broker is launched through systemd activation and no subscribers remain, we don't need to poll inotify and incoming connections until a new subscriber is online. This also adds handling of Ctrl-C to stop the deamon if it is not activated through systemd. Signed-off-by: Daiki Ueno --- Cargo.lock | 1 + Cargo.toml | 2 +- event-broker/Cargo.toml | 2 +- event-broker/src/main.rs | 249 +++++++++++++++++++++++++++------------ 4 files changed, 176 insertions(+), 78 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c58489e..889b935 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1731,6 +1731,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 4bbce47..12d10d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ thiserror = "1.0" time = "0.3" tokio = "1.23" tokio-serde = { version = "0.8", features = ["cbor"] } -tokio-stream = "0.1" +tokio-stream = { version = "0.1", features = ["sync"] } tokio-util = { version = "0.7", features = ["codec"] } toml = "0.7" tracing = "0.1" diff --git a/event-broker/Cargo.toml b/event-broker/Cargo.toml index a7195a5..13a8615 100644 --- a/event-broker/Cargo.toml +++ b/event-broker/Cargo.toml @@ -17,7 +17,7 @@ futures.workspace = true inotify.workspace = true libsystemd = { version = "0.7", optional = true } serde_cbor.workspace = true -tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] } tokio-serde.workspace = true tokio-stream.workspace = true tokio-util.workspace = true diff --git a/event-broker/src/main.rs b/event-broker/src/main.rs index 5b09c2b..1d7790a 100644 --- a/event-broker/src/main.rs +++ b/event-broker/src/main.rs @@ -5,20 +5,23 @@ use anyhow::bail; use anyhow::{Context as _, Result}; use crypto_auditing::types::EventGroup; -use futures::{future, stream::StreamExt, try_join, SinkExt, TryStreamExt}; +use futures::{future, stream::StreamExt, try_join, SinkExt, Stream, TryStreamExt}; use inotify::{EventMask, Inotify, WatchMask}; #[cfg(feature = "libsystemd")] use libsystemd::activation::receive_descriptors; use serde_cbor::de::Deserializer; use std::collections::HashMap; +use std::fs; +use std::marker; use std::os::fd::{AsRawFd, RawFd}; #[cfg(feature = "libsystemd")] use std::os::fd::{FromRawFd, IntoRawFd}; use std::os::unix::net::UnixListener as StdUnixListener; use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock}; -use tokio::net::{unix::OwnedWriteHalf, UnixListener}; -use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::net::{unix::OwnedWriteHalf, UnixListener, UnixStream}; +use tokio::signal; +use tokio::sync::{broadcast, mpsc}; use tokio_serde::{formats::SymmetricalCbor, SymmetricallyFramed}; use tokio_stream::wrappers::ReceiverStream; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; @@ -37,30 +40,37 @@ impl Reader { Self { log_file } } - async fn read(&self, sender: Sender) -> Result<()> { + async fn read( + &self, + event_sender: &mpsc::Sender, + shutdown_receiver: &mut broadcast::Receiver<()>, + ) -> Result<()> { let inotify = Inotify::init().with_context(|| "unable to initialize inotify".to_string())?; inotify .watches() .add(&self.log_file, WatchMask::MODIFY | WatchMask::CREATE) .with_context(|| format!("unable to monitor {}", self.log_file.display()))?; - let mut file = std::fs::File::open(&self.log_file).ok(); + let mut file = fs::File::open(&self.log_file) + .with_context(|| format!("unable to open {}", self.log_file.display()))?; let mut buffer = [0; 1024]; let mut stream = inotify.into_event_stream(&mut buffer)?; - while let Some(event_or_error) = stream.next().await { - let event = event_or_error?; - if event.mask.contains(EventMask::CREATE) { - let new_file = std::fs::File::open(&self.log_file).with_context(|| { - format!("unable to read file `{}`", self.log_file.display()) - })?; - let _old = file.replace(new_file); - } - if let Some(ref file) = file { - for group in Deserializer::from_reader(file).into_iter::() { - sender.send(group?).await? - } + loop { + tokio::select! { + Some(event_or_error) = stream.next() => { + let event = event_or_error?; + if event.mask.contains(EventMask::CREATE) { + file = fs::File::open(&self.log_file).with_context(|| { + format!("unable to read file `{}`", self.log_file.display()) + })?; + } + for group in Deserializer::from_reader(&mut file).into_iter::() { + event_sender.send(group?).await? + } + }, + _ = shutdown_receiver.recv() => break, } } @@ -83,6 +93,7 @@ struct Subscription { struct Publisher { socket_path: PathBuf, subscriptions: Arc>>, + activated: Arc>, } impl Publisher { @@ -91,21 +102,28 @@ impl Publisher { Self { socket_path, subscriptions: Arc::new(RwLock::new(HashMap::new())), + activated: Arc::new(RwLock::new(false)), } } #[cfg(feature = "libsystemd")] fn get_std_listener(&self) -> Result { - if let Ok(mut descriptors) = receive_descriptors(false) { - if descriptors.len() > 1 { - bail!("too many file descriptors"); - } else if descriptors.is_empty() { - bail!("no file descriptors received"); + match receive_descriptors(false) { + Ok(mut descriptors) => { + if descriptors.len() > 1 { + bail!("too many file descriptors"); + } else if descriptors.is_empty() { + bail!("no file descriptors received"); + } + let fd = descriptors.pop().unwrap().into_raw_fd(); + let mut activated = self.activated.write().unwrap(); + *activated = true; + Ok(unsafe { StdUnixListener::from_raw_fd(fd) }) + } + Err(e) => { + info!(error = %e, "unable to receive file descriptors"); + Ok(StdUnixListener::bind(&self.socket_path)?) } - let fd = descriptors.pop().unwrap().into_raw_fd(); - Ok(unsafe { StdUnixListener::from_raw_fd(fd) }) - } else { - Ok(StdUnixListener::bind(&self.socket_path)?) } } @@ -114,65 +132,113 @@ impl Publisher { Ok(StdUnixListener::bind(&self.socket_path)?) } - async fn listen(&self) -> Result<()> { - let std_listener = self.get_std_listener()?; - std_listener.set_nonblocking(true)?; - let listener = UnixListener::from_std(std_listener)?; + async fn accept_subscriber(&self, stream: UnixStream) -> Result<()> { + let subscriber_fd = stream.as_raw_fd(); - while let Ok((stream, _sock_addr)) = listener.accept().await { - let subscriber_fd = stream.as_raw_fd(); + debug!(socket = subscriber_fd, "subscriber connected"); - debug!(socket = subscriber_fd, "subscriber connected"); + let (de, ser) = stream.into_split(); - let (de, ser) = stream.into_split(); + let ser = FramedWrite::new(ser, LengthDelimitedCodec::new()); + let de = FramedRead::new(de, LengthDelimitedCodec::new()); - let ser = FramedWrite::new(ser, LengthDelimitedCodec::new()); - let de = FramedRead::new(de, LengthDelimitedCodec::new()); + let ser = SymmetricallyFramed::new(ser, SymmetricalCbor::::default()); + let mut de = SymmetricallyFramed::new(de, SymmetricalCbor::>::default()); + + // Populate the scopes + if let Some(scopes) = de.try_next().await.unwrap() { + self.subscriptions.write().unwrap().insert( + subscriber_fd, + Subscription { + stream: ser, + scopes, + errored: Default::default(), + }, + ); + } + Ok(()) + } - let ser = SymmetricallyFramed::new(ser, SymmetricalCbor::::default()); - let mut de = - SymmetricallyFramed::new(de, SymmetricalCbor::>::default()); + async fn listen(&self, shutdown_receiver: &mut broadcast::Receiver<()>) -> Result<()> { + let std_listener = self.get_std_listener()?; + std_listener.set_nonblocking(true)?; + let listener = UnixListener::from_std(std_listener)?; - // Populate the scopes - if let Some(scopes) = de.try_next().await.unwrap() { - self.subscriptions.write().unwrap().insert( - subscriber_fd, - Subscription { - stream: ser, - scopes, - errored: Default::default(), - }, - ); + loop { + tokio::select! { + maybe_stream = listener.accept() => { + let stream = match maybe_stream { + Ok((stream, _sock_addr)) => stream, + Err(e) => { + info!(error = %e, "unable to accept connection"); + break; + } + }; + if let Err(e) = self.accept_subscriber(stream).await { + info!(error = %e, "unable to accept subscriber"); + break; + } + }, + _ = shutdown_receiver.recv() => { + if !*self.activated.read().unwrap() { + drop(listener); + if let Err(e) = fs::remove_file(&self.socket_path) { + info!(error = %e, "error removing socket"); + } + } + break; + }, } } Ok(()) } - async fn publish(&self, receiver: Receiver) -> Result<()> { - let mut stream = ReceiverStream::new(receiver); - while let Some(group) = stream.next().await { - let mut subscriptions = self.subscriptions.write().unwrap(); - let mut publications = Vec::new(); - - for (_, subscription) in subscriptions.iter_mut() { - let mut group = group.clone(); - group.events_filtered(&subscription.scopes); - if !group.events().is_empty() { - publications.push(async move { - if let Err(e) = subscription.stream.send(group).await { - info!(error = %e, "unable to send event"); - subscription.errored = true; - } - }); - } + async fn publish_event( + &self, + group: &EventGroup, + shutdown_sender: &broadcast::Sender<()>, + ) -> Result<()> { + let mut subscriptions = self.subscriptions.write().unwrap(); + let mut publications = Vec::new(); + + for (_, subscription) in subscriptions.iter_mut() { + let mut group = group.clone(); + group.events_filtered(&subscription.scopes); + if !group.events().is_empty() { + publications.push(async move { + if let Err(e) = subscription.stream.send(group).await { + info!(error = %e, "unable to send event"); + subscription.errored = true; + } + }); } + } - future::join_all(publications).await; + future::join_all(publications).await; - // Remove errored subscriptions - subscriptions.retain(|_, v| !v.errored); - if subscriptions.is_empty() { - break; + // Remove errored subscriptions + subscriptions.retain(|_, v| !v.errored); + + if *self.activated.read().unwrap() && subscriptions.is_empty() { + info!("shutting down event broker"); + shutdown_sender.send(())?; + } + + Ok(()) + } + + async fn publish( + &self, + mut event_stream: impl Stream + marker::Unpin, + shutdown_receiver: &mut broadcast::Receiver<()>, + shutdown_sender: &broadcast::Sender<()>, + ) -> Result<()> { + loop { + tokio::select! { + Some(ref group) = event_stream.next() => { + self.publish_event(group, shutdown_sender).await? + }, + _ = shutdown_receiver.recv() => break, } } @@ -180,6 +246,28 @@ impl Publisher { } } +async fn shutdown( + shutdown_receiver: &mut broadcast::Receiver<()>, + shutdown_sender: &broadcast::Sender<()>, +) -> Result<()> { + loop { + tokio::select! { + maybe_value = signal::ctrl_c() => { + if let Err(e) = maybe_value { + info!(error = %e, "error receiving ctrl-c") + } + info!("shutting down event broker"); + if let Err(e) = shutdown_sender.send(()) { + info!(error = %e, "unable to send shutdown"); + } + break; + }, + _ = shutdown_receiver.recv() => break, + } + } + Ok(()) +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let config = config::Config::new()?; @@ -192,10 +280,19 @@ async fn main() -> anyhow::Result<()> { let reader = Reader::new(&config.log_file); let publisher = Publisher::new(&config.socket_path); - let (tx, rx) = mpsc::channel::(10); + let (event_tx, event_rx) = mpsc::channel::(10); + let mut event_rx = ReceiverStream::new(event_rx); + + let (shutdown_tx, mut shutdown_rx1) = broadcast::channel::<()>(2); + let mut shutdown_rx2 = shutdown_tx.subscribe(); + let mut shutdown_rx3 = shutdown_tx.subscribe(); + let mut shutdown_rx4 = shutdown_tx.subscribe(); + try_join!( - reader.read(tx), - publisher.listen(), - publisher.publish(rx), - ).map(|_| ()) + shutdown(&mut shutdown_rx1, &shutdown_tx), + reader.read(&event_tx, &mut shutdown_rx2), + publisher.listen(&mut shutdown_rx3), + publisher.publish(&mut event_rx, &mut shutdown_rx4, &shutdown_tx), + ) + .map(|_| ()) }