From f6cc31b1bad2d4e41e323d9053a03160ded6d20d Mon Sep 17 00:00:00 2001 From: Thomas Braun Date: Sat, 3 Aug 2024 12:10:09 -0400 Subject: [PATCH] Add more implementation details (WIP) --- citadel-internal-service-macros/src/lib.rs | 98 ++++++++++++++++++++++ citadel-internal-service-types/src/lib.rs | 6 +- citadel-messaging/src/lib.rs | 64 +++++++++----- 3 files changed, 142 insertions(+), 26 deletions(-) diff --git a/citadel-internal-service-macros/src/lib.rs b/citadel-internal-service-macros/src/lib.rs index 1bb4af3..f9b2f5b 100644 --- a/citadel-internal-service-macros/src/lib.rs +++ b/citadel-internal-service-macros/src/lib.rs @@ -1,5 +1,6 @@ use proc_macro::TokenStream; use quote::quote; +use syn::spanned::Spanned; use syn::{parse_macro_input, Data, DataEnum, DeriveInput, Ident}; #[proc_macro_derive(IsError)] @@ -12,6 +13,103 @@ pub fn is_notification_derive(input: TokenStream) -> TokenStream { generate_function(input, "Notification", "is_notification") } +// Create a proc macro that generates a function that goes through each enum variant, looks at the first and only item in the variant, and creates a function called request_id(&self) -> Option<&Uuid>, that looks at the field "request_id" in the variant and returns a reference to it if it exists. +#[proc_macro_derive(RequestId)] +pub fn request_id_derive(input: TokenStream) -> TokenStream { + generate_field_function(input, "request_id", "request_id") +} + +fn generate_field_function( + input: TokenStream, + field_name: &str, + function_name: &str, +) -> TokenStream { + // Parse the input tokens into a syntax tree + let input = parse_macro_input!(input as DeriveInput); + + // Extract the identifier and data from the input + let name = &input.ident; + let data = if let Data::Enum(data) = input.data { + data + } else { + // This macro only supports enums + panic!("{function_name} can only be derived for enums"); + }; + + // Convert the function name to a tokenstream + let function_name = Ident::new(function_name, name.span()); + + // Generate match arms for each enum variant + let match_arms = generate_field_match_arms(name, &data, field_name); + + // Generate the implementation of the `is_error` method + let expanded = quote! { + impl #name { + pub fn #function_name(&self) -> Option<&Uuid> { + match self { + #(#match_arms)* + } + } + } + }; + + // Convert into a TokenStream and return it + TokenStream::from(expanded) +} + +fn generate_field_match_arms( + name: &Ident, + data_enum: &DataEnum, + field_name: &str, +) -> Vec { + data_enum + .variants + .iter() + .map(|variant| { + let variant_ident = &variant.ident; + let field = variant.fields.iter().next().unwrap(); + let field_name = field_name.to_string(); + let field_name = Ident::new(&field_name, field.span()); + + // Determine if the enum is of the form Enum::Variant(inner) or Enum::Variant { inner, .. } + let is_tuple_variant = variant + .fields + .iter() + .next() + .map_or(false, |field| field.ident.is_none()); + if is_tuple_variant { + // See if the type is a Uuid or an Option + if let syn::Type::Path(type_path) = &field.ty { + if type_path.path.segments.len() == 1 { + if type_path.path.segments[0].ident == "Uuid" { + return quote! { + #name::#variant_ident(inner) => Some(&inner.#field_name), + }; + } + } + } + + // See if "inner" has a field called "request_id" + if field.ident.is_none() { + return quote! { + #name::#variant_ident(_) => None, + }; + } + + // Match against each variant, ignoring any inner data + quote! { + #name::#variant_ident(inner, ..) => inner.#field_name.as_ref(), + } + } else { + // Match against each variant, ignoring any inner data + quote! { + #name::#variant_ident { #field_name, .. } => Some(#field_name.as_ref()), + } + } + }) + .collect() +} + fn generate_function(input: TokenStream, contains: &str, function_name: &str) -> TokenStream { // Parse the input tokens into a syntax tree let input = parse_macro_input!(input as DeriveInput); diff --git a/citadel-internal-service-types/src/lib.rs b/citadel-internal-service-types/src/lib.rs index 6d92a0d..082dca5 100644 --- a/citadel-internal-service-types/src/lib.rs +++ b/citadel-internal-service-types/src/lib.rs @@ -1,5 +1,5 @@ use bytes::BytesMut; -use citadel_internal_service_macros::{IsError, IsNotification}; +use citadel_internal_service_macros::{IsError, IsNotification, RequestId}; pub use citadel_types::prelude::{ ConnectMode, MemberState, MessageGroupKey, ObjectTransferStatus, SecBuffer, SecurityLevel, SessionSecuritySettings, TransferType, UdpMode, UserIdentifier, VirtualObjectMetadata, @@ -566,7 +566,7 @@ pub struct FileTransferTickNotification { pub status: ObjectTransferStatus, } -#[derive(Serialize, Deserialize, Debug, Clone, IsError, IsNotification)] +#[derive(Serialize, Deserialize, Debug, Clone, IsError, IsNotification, RequestId)] pub enum InternalServiceResponse { ConnectSuccess(ConnectSuccess), ConnectFailure(ConnectFailure), @@ -647,7 +647,7 @@ pub enum InternalServiceResponse { ListRegisteredPeersFailure(ListRegisteredPeersFailure), } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, RequestId)] pub enum InternalServiceRequest { Connect { // A user-provided unique ID that will be returned in the response diff --git a/citadel-messaging/src/lib.rs b/citadel-messaging/src/lib.rs index acbf4d1..3938f09 100644 --- a/citadel-messaging/src/lib.rs +++ b/citadel-messaging/src/lib.rs @@ -3,9 +3,7 @@ use citadel_internal_service_connector::io_interface::IOInterface; use citadel_internal_service_types::messaging_layer::{ CWMessage, MessengerUpdate, OutgoingCWMessage, }; -use citadel_internal_service_types::{ - InternalServicePayload, InternalServiceRequest, InternalServiceResponse, -}; +use citadel_internal_service_types::{InternalServiceRequest, InternalServiceResponse}; use futures::sink::SinkExt; use futures::StreamExt; use serde::{Deserialize, Serialize}; @@ -21,9 +19,9 @@ pub struct Messenger { tx_to_subscriber: UnboundedSender, rx_from_messenger: Option>, internal_service_senders: - Arc>>>, + Arc>>>, internal_service_listeners: - Arc>>>, + Arc>>>, is_running: Arc, } @@ -68,16 +66,21 @@ impl Messenger { while let Some(response) = stream.next().await { let mut lock = interal_service_senders_clone.lock().await; - if let Some(tx) = lock.remove(response.request_id()) { - if let Err(response) = tx.send(response) { - // Send through subscriber as backup - if let Err(_err) = tx_to_subscriber_clone.send(response) {} + + if let Some(uuid) = response.request_id() { + if let Some(tx) = lock.remove(uuid) { + if let Err(response) = tx.send(response) { + // Send through subscriber as backup + if let Err(_err) = tx_to_subscriber_clone.send(response) {} + } + + continue; } - } else { - // Send through normal channel - let signal = MessengerUpdate::Other { response }; - if let Err(_err) = tx_to_subscriber_clone.send(signal) {} } + + // Send through normal channel + let signal = MessengerUpdate::Other { response }; + if let Err(_err) = tx_to_subscriber_clone.send(signal) {} } }; @@ -213,7 +216,7 @@ impl Messenger { largest_peer_cid = message.peer_cid; } - let response = self.send_and_wait_for_response(request_id, command).await?; + let response = self.send_and_wait_for_response(command).await?; if let InternalServiceResponse::LocalDBSetKVSuccess(_) = response { // Continue } else { @@ -234,7 +237,7 @@ impl Messenger { }; if let InternalServiceResponse::LocalDBSetKVSuccess(_) = self - .send_and_wait_for_response(request_id, request_for_largest_received) + .send_and_wait_for_response(request_for_largest_received) .await? { // Continue @@ -248,11 +251,18 @@ impl Messenger { pub async fn send_and_wait_for_response( &mut self, - request_id: Uuid, internal_service_request: InternalServiceRequest, ) -> std::io::Result { + let request_id = internal_service_request + .request_id() + .copied() + .ok_or_else(|| { + generic_std_error("Failed to get request id for internal service request") + })?; + self.register_listener_internal(request_id).await; self.connection.send(internal_service_request).await?; + let rx = self .internal_service_listeners .lock() @@ -260,11 +270,21 @@ impl Messenger { .remove(&request_id) .ok_or_else(|| generic_std_error("Failed to get listener for request"))?; - rx.await.map_err(|_| { + let recv = rx.await.map_err(|_| { generic_std_error(format!( "Failed to get response for request: {request_id:?}" )) - }) + }); + + match recv.map_err(|err| generic_std_error(format!("Failed to get response: {err}")))? { + update @ MessengerUpdate::Message { .. } => { + self.tx_to_subscriber + .send(update) + .map_err(|err| generic_std_error(format!("Failed to send message: {err}")))?; + } + + MessengerUpdate::Other { response } => Ok(response), + } } async fn register_listener_internal(&self, uuid: Uuid) { @@ -289,9 +309,7 @@ impl Messenger { key: generate_highest_message_id_key_for_cid_received(cid, peer_cid), }; - let response = self - .send_and_wait_for_response(request_id, request_for_largest) - .await?; + let response = self.send_and_wait_for_response(request_for_largest).await?; if let InternalServiceResponse::LocalDBGetKVSuccess(value) = response { let highest_value = be_vec_to_u64(&value.value) .ok_or_else(|| generic_std_error("Invalid highest CID encoding"))?; @@ -318,7 +336,7 @@ impl Messenger { key: key.clone(), }; - let response = self.send_and_wait_for_response(request_id, request).await?; + let response = self.send_and_wait_for_response(request).await?; let mut db = if let InternalServiceResponse::LocalDBGetKVSuccess(value) = response { let db = bincode2::deserialize(&value.value).map_err(|err| { @@ -349,7 +367,7 @@ impl Messenger { value: serialized, }; - let response = self.send_and_wait_for_response(request_id, request).await?; + let response = self.send_and_wait_for_response(request).await?; if let InternalServiceResponse::LocalDBSetKVSuccess(_) = response { Ok(())