Skip to content

Commit

Permalink
Add more implementation details (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
tbraun96 committed Aug 3, 2024
1 parent e2fd429 commit f6cc31b
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 26 deletions.
98 changes: 98 additions & 0 deletions citadel-internal-service-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::spanned::Spanned;
use syn::{parse_macro_input, Data, DataEnum, DeriveInput, Ident};

#[proc_macro_derive(IsError)]
Expand All @@ -12,6 +13,103 @@ pub fn is_notification_derive(input: TokenStream) -> TokenStream {
generate_function(input, "Notification", "is_notification")
}

// Create a proc macro that generates a function that goes through each enum variant, looks at the first and only item in the variant, and creates a function called request_id(&self) -> Option<&Uuid>, that looks at the field "request_id" in the variant and returns a reference to it if it exists.
#[proc_macro_derive(RequestId)]
pub fn request_id_derive(input: TokenStream) -> TokenStream {
generate_field_function(input, "request_id", "request_id")
}

fn generate_field_function(
input: TokenStream,
field_name: &str,
function_name: &str,
) -> TokenStream {
// Parse the input tokens into a syntax tree
let input = parse_macro_input!(input as DeriveInput);

// Extract the identifier and data from the input
let name = &input.ident;
let data = if let Data::Enum(data) = input.data {
data
} else {
// This macro only supports enums
panic!("{function_name} can only be derived for enums");
};

// Convert the function name to a tokenstream
let function_name = Ident::new(function_name, name.span());

// Generate match arms for each enum variant
let match_arms = generate_field_match_arms(name, &data, field_name);

// Generate the implementation of the `is_error` method
let expanded = quote! {
impl #name {
pub fn #function_name(&self) -> Option<&Uuid> {
match self {
#(#match_arms)*
}
}
}
};

// Convert into a TokenStream and return it
TokenStream::from(expanded)
}

fn generate_field_match_arms(
name: &Ident,
data_enum: &DataEnum,
field_name: &str,
) -> Vec<proc_macro2::TokenStream> {
data_enum
.variants
.iter()
.map(|variant| {
let variant_ident = &variant.ident;
let field = variant.fields.iter().next().unwrap();
let field_name = field_name.to_string();
let field_name = Ident::new(&field_name, field.span());

// Determine if the enum is of the form Enum::Variant(inner) or Enum::Variant { inner, .. }
let is_tuple_variant = variant
.fields
.iter()
.next()
.map_or(false, |field| field.ident.is_none());
if is_tuple_variant {
// See if the type is a Uuid or an Option<Uuid>
if let syn::Type::Path(type_path) = &field.ty {
if type_path.path.segments.len() == 1 {
if type_path.path.segments[0].ident == "Uuid" {
return quote! {
#name::#variant_ident(inner) => Some(&inner.#field_name),
};
}
}
}

// See if "inner" has a field called "request_id"
if field.ident.is_none() {
return quote! {
#name::#variant_ident(_) => None,
};
}

// Match against each variant, ignoring any inner data
quote! {
#name::#variant_ident(inner, ..) => inner.#field_name.as_ref(),
}
} else {
// Match against each variant, ignoring any inner data
quote! {
#name::#variant_ident { #field_name, .. } => Some(#field_name.as_ref()),
}
}
})
.collect()
}

fn generate_function(input: TokenStream, contains: &str, function_name: &str) -> TokenStream {
// Parse the input tokens into a syntax tree
let input = parse_macro_input!(input as DeriveInput);
Expand Down
6 changes: 3 additions & 3 deletions citadel-internal-service-types/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use bytes::BytesMut;
use citadel_internal_service_macros::{IsError, IsNotification};
use citadel_internal_service_macros::{IsError, IsNotification, RequestId};
pub use citadel_types::prelude::{
ConnectMode, MemberState, MessageGroupKey, ObjectTransferStatus, SecBuffer, SecurityLevel,
SessionSecuritySettings, TransferType, UdpMode, UserIdentifier, VirtualObjectMetadata,
Expand Down Expand Up @@ -566,7 +566,7 @@ pub struct FileTransferTickNotification {
pub status: ObjectTransferStatus,
}

#[derive(Serialize, Deserialize, Debug, Clone, IsError, IsNotification)]
#[derive(Serialize, Deserialize, Debug, Clone, IsError, IsNotification, RequestId)]
pub enum InternalServiceResponse {
ConnectSuccess(ConnectSuccess),
ConnectFailure(ConnectFailure),
Expand Down Expand Up @@ -647,7 +647,7 @@ pub enum InternalServiceResponse {
ListRegisteredPeersFailure(ListRegisteredPeersFailure),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, RequestId)]
pub enum InternalServiceRequest {
Connect {
// A user-provided unique ID that will be returned in the response
Expand Down
64 changes: 41 additions & 23 deletions citadel-messaging/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ use citadel_internal_service_connector::io_interface::IOInterface;
use citadel_internal_service_types::messaging_layer::{
CWMessage, MessengerUpdate, OutgoingCWMessage,
};
use citadel_internal_service_types::{
InternalServicePayload, InternalServiceRequest, InternalServiceResponse,
};
use citadel_internal_service_types::{InternalServiceRequest, InternalServiceResponse};
use futures::sink::SinkExt;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
Expand All @@ -21,9 +19,9 @@ pub struct Messenger<T: IOInterface> {
tx_to_subscriber: UnboundedSender<MessengerUpdate>,
rx_from_messenger: Option<UnboundedReceiver<MessengerUpdate>>,
internal_service_senders:
Arc<Mutex<HashMap<Uuid, tokio::sync::oneshot::Sender<InternalServiceResponse>>>>,
Arc<Mutex<HashMap<Uuid, tokio::sync::oneshot::Sender<MessengerUpdate>>>>,
internal_service_listeners:
Arc<Mutex<HashMap<Uuid, tokio::sync::oneshot::Receiver<InternalServiceResponse>>>>,
Arc<Mutex<HashMap<Uuid, tokio::sync::oneshot::Receiver<MessengerUpdate>>>>,
is_running: Arc<AtomicBool>,
}

Expand Down Expand Up @@ -68,16 +66,21 @@ impl<T: IOInterface> Messenger<T> {

while let Some(response) = stream.next().await {
let mut lock = interal_service_senders_clone.lock().await;
if let Some(tx) = lock.remove(response.request_id()) {
if let Err(response) = tx.send(response) {
// Send through subscriber as backup
if let Err(_err) = tx_to_subscriber_clone.send(response) {}

if let Some(uuid) = response.request_id() {
if let Some(tx) = lock.remove(uuid) {
if let Err(response) = tx.send(response) {
// Send through subscriber as backup
if let Err(_err) = tx_to_subscriber_clone.send(response) {}
}

continue;
}
} else {
// Send through normal channel
let signal = MessengerUpdate::Other { response };
if let Err(_err) = tx_to_subscriber_clone.send(signal) {}
}

// Send through normal channel
let signal = MessengerUpdate::Other { response };
if let Err(_err) = tx_to_subscriber_clone.send(signal) {}
}
};

Expand Down Expand Up @@ -213,7 +216,7 @@ impl<T: IOInterface> Messenger<T> {
largest_peer_cid = message.peer_cid;
}

let response = self.send_and_wait_for_response(request_id, command).await?;
let response = self.send_and_wait_for_response(command).await?;
if let InternalServiceResponse::LocalDBSetKVSuccess(_) = response {
// Continue
} else {
Expand All @@ -234,7 +237,7 @@ impl<T: IOInterface> Messenger<T> {
};

if let InternalServiceResponse::LocalDBSetKVSuccess(_) = self
.send_and_wait_for_response(request_id, request_for_largest_received)
.send_and_wait_for_response(request_for_largest_received)
.await?
{
// Continue
Expand All @@ -248,23 +251,40 @@ impl<T: IOInterface> Messenger<T> {

pub async fn send_and_wait_for_response(
&mut self,
request_id: Uuid,
internal_service_request: InternalServiceRequest,
) -> std::io::Result<InternalServiceResponse> {
let request_id = internal_service_request
.request_id()
.copied()
.ok_or_else(|| {
generic_std_error("Failed to get request id for internal service request")
})?;

self.register_listener_internal(request_id).await;
self.connection.send(internal_service_request).await?;

let rx = self
.internal_service_listeners
.lock()
.await
.remove(&request_id)
.ok_or_else(|| generic_std_error("Failed to get listener for request"))?;

rx.await.map_err(|_| {
let recv = rx.await.map_err(|_| {
generic_std_error(format!(
"Failed to get response for request: {request_id:?}"
))
})
});

match recv.map_err(|err| generic_std_error(format!("Failed to get response: {err}")))? {
update @ MessengerUpdate::Message { .. } => {
self.tx_to_subscriber
.send(update)
.map_err(|err| generic_std_error(format!("Failed to send message: {err}")))?;
}

MessengerUpdate::Other { response } => Ok(response),
}
}

async fn register_listener_internal(&self, uuid: Uuid) {
Expand All @@ -289,9 +309,7 @@ impl<T: IOInterface> Messenger<T> {
key: generate_highest_message_id_key_for_cid_received(cid, peer_cid),
};

let response = self
.send_and_wait_for_response(request_id, request_for_largest)
.await?;
let response = self.send_and_wait_for_response(request_for_largest).await?;
if let InternalServiceResponse::LocalDBGetKVSuccess(value) = response {
let highest_value = be_vec_to_u64(&value.value)
.ok_or_else(|| generic_std_error("Invalid highest CID encoding"))?;
Expand All @@ -318,7 +336,7 @@ impl<T: IOInterface> Messenger<T> {
key: key.clone(),
};

let response = self.send_and_wait_for_response(request_id, request).await?;
let response = self.send_and_wait_for_response(request).await?;

let mut db = if let InternalServiceResponse::LocalDBGetKVSuccess(value) = response {
let db = bincode2::deserialize(&value.value).map_err(|err| {
Expand Down Expand Up @@ -349,7 +367,7 @@ impl<T: IOInterface> Messenger<T> {
value: serialized,
};

let response = self.send_and_wait_for_response(request_id, request).await?;
let response = self.send_and_wait_for_response(request).await?;

if let InternalServiceResponse::LocalDBSetKVSuccess(_) = response {
Ok(())
Expand Down

0 comments on commit f6cc31b

Please sign in to comment.