From 275653297b6854a3a1a920348c24b776009c2fab Mon Sep 17 00:00:00 2001 From: dragonn Date: Thu, 18 Apr 2024 15:53:54 +0200 Subject: [PATCH] add after_dispatch hook --- examples/macro/async_blinky/src/main.rs | 8 +++++++- macro/src/analyze.rs | 13 +++++++++++++ macro/src/codegen.rs | 7 +++++++ macro/src/lower.rs | 6 ++++++ statig/src/awaitable/state.rs | 12 +++++++++++- statig/src/awaitable/superstate.rs | 10 +++++++++- statig/src/blocking/state.rs | 12 +++++++++++- statig/src/blocking/superstate.rs | 10 +++++++++- statig/src/into_state_machine.rs | 5 +++++ statig/src/lib.rs | 5 +++++ 10 files changed, 83 insertions(+), 5 deletions(-) diff --git a/examples/macro/async_blinky/src/main.rs b/examples/macro/async_blinky/src/main.rs index 4d0eb9a..66290d3 100644 --- a/examples/macro/async_blinky/src/main.rs +++ b/examples/macro/async_blinky/src/main.rs @@ -30,7 +30,9 @@ pub enum Event { // Set the `on_transition` callback. on_transition = "Self::on_transition", // Set the `on_dispatch` callback. - on_dispatch = "Self::on_dispatch" + on_dispatch = "Self::on_dispatch", + // Set the `on_dispatch` callback. + after_dispatch = "Self::after_dispatch" )] impl Blinky { #[action] @@ -86,6 +88,10 @@ impl Blinky { fn on_dispatch(&mut self, state: StateOrSuperstate, event: &Event) { println!("dispatching `{event:?}` to `{state:?}`"); } + + fn after_dispatch(&mut self, state: StateOrSuperstate, event: &Event) { + println!("dispatched `{event:?}` to `{state:?}`"); + } } #[tokio::main] diff --git a/macro/src/analyze.rs b/macro/src/analyze.rs index 1e1f90d..9ba798a 100644 --- a/macro/src/analyze.rs +++ b/macro/src/analyze.rs @@ -52,6 +52,8 @@ pub struct StateMachine { pub on_transition: Option, /// Optional `on_dispatch` callback. pub on_dispatch: Option, + /// Optional `after_dispatch` callback. + pub after_dispatch: Option, } /// Information regarding a state. @@ -181,6 +183,7 @@ pub fn analyze_state_machine(attribute_args: &AttributeArgs, item_impl: &ItemImp let mut on_transition = None; let mut on_dispatch = None; + let mut after_dispatch = None; let mut visibility = parse_quote!(pub); let mut event_ident = parse_quote!(event); @@ -232,6 +235,14 @@ pub fn analyze_state_machine(attribute_args: &AttributeArgs, item_impl: &ItemImp _ => abort!(name_value, "must be a string literal"), } } + NestedMeta::Meta(Meta::NameValue(name_value)) + if name_value.path.is_ident("after_dispatch") => + { + after_dispatch = match &name_value.lit { + Lit::Str(input_pat) => Some(input_pat.parse().unwrap()), + _ => abort!(name_value, "must be a string literal"), + } + } NestedMeta::Meta(Meta::NameValue(name_value)) if name_value.path.is_ident("visibility") => { @@ -340,6 +351,7 @@ pub fn analyze_state_machine(attribute_args: &AttributeArgs, item_impl: &ItemImp superstate_ident, superstate_derives, on_dispatch, + after_dispatch, on_transition, event_ident, context_ident, @@ -661,6 +673,7 @@ fn valid_state_analyze() { superstate_derives, on_transition, on_dispatch, + after_dispatch, event_ident, context_ident, visibility, diff --git a/macro/src/codegen.rs b/macro/src/codegen.rs index 42d1a1b..7be562f 100644 --- a/macro/src/codegen.rs +++ b/macro/src/codegen.rs @@ -72,6 +72,12 @@ fn codegen_state_machine_impl(ir: &Ir) -> ItemImpl { const ON_DISPATCH: fn(&mut Self, StateOrSuperstate<'_, '_, Self>, &Self::Event<'_>) = #on_dispatch; ), }; + let after_dispatch = match &ir.state_machine.after_dispatch { + None => quote!(), + Some(after_dispatch) => quote!( + const AFTER_DISPATCH: fn(&mut Self, StateOrSuperstate<'_, '_, Self>, &Self::Event<'_>) = #after_dispatch; + ), + }; parse_quote!( impl #impl_generics statig::#mode::IntoStateMachine for #shared_storage_type #where_clause @@ -85,6 +91,7 @@ fn codegen_state_machine_impl(ir: &Ir) -> ItemImpl { #on_transition #on_dispatch + #after_dispatch } ) } diff --git a/macro/src/lower.rs b/macro/src/lower.rs index d32a1da..15006a3 100644 --- a/macro/src/lower.rs +++ b/macro/src/lower.rs @@ -60,6 +60,8 @@ pub struct StateMachine { pub on_transition: Option, /// The path of the `on_dispatch` callback. pub on_dispatch: Option, + /// The path of the `after_dispatch` callback. + pub after_dispatch: Option, /// The visibility for the derived types, pub visibility: Visibility, /// The external input pattern. @@ -138,6 +140,7 @@ pub fn lower(model: &Model) -> Ir { let superstate_ident = model.state_machine.superstate_ident.clone(); let on_transition = model.state_machine.on_transition.clone(); let on_dispatch = model.state_machine.on_dispatch.clone(); + let after_dispatch = model.state_machine.after_dispatch.clone(); let event_ident = model.state_machine.event_ident.clone(); let context_ident = model.state_machine.context_ident.clone(); let shared_storage_type = model.state_machine.shared_storage_type.clone(); @@ -422,6 +425,7 @@ pub fn lower(model: &Model) -> Ir { superstate_generics, on_transition, on_dispatch, + after_dispatch, visibility, event_ident, context_ident, @@ -709,6 +713,7 @@ fn create_analyze_state_machine() -> analyze::StateMachine { superstate_derives: vec![parse_quote!(Copy), parse_quote!(Clone)], on_transition: None, on_dispatch: None, + after_dispatch: None, visibility: parse_quote!(pub), event_ident: parse_quote!(input), context_ident: parse_quote!(context), @@ -734,6 +739,7 @@ fn create_lower_state_machine() -> StateMachine { superstate_generics, on_transition: None, on_dispatch: None, + after_dispatch: None, visibility: parse_quote!(pub), event_ident: parse_quote!(input), context_ident: parse_quote!(context), diff --git a/statig/src/awaitable/state.rs b/statig/src/awaitable/state.rs index f3efe41..f178f7c 100644 --- a/statig/src/awaitable/state.rs +++ b/statig/src/awaitable/state.rs @@ -115,6 +115,8 @@ where let response = self.call_handler(shared_storage, event, context).await; + M::AFTER_DISPATCH(shared_storage, StateOrSuperstate::State(self), event); + match response { Response::Handled => Response::Handled, Response::Super => match self.superstate() { @@ -125,7 +127,15 @@ where event, ); - superstate.handle(shared_storage, event, context).await + let response = superstate.handle(shared_storage, event, context).await; + + M::AFTER_DISPATCH( + shared_storage, + StateOrSuperstate::Superstate(&superstate), + event, + ); + + response } None => Response::Super, }, diff --git a/statig/src/awaitable/superstate.rs b/statig/src/awaitable/superstate.rs index 509fb79..d29b9ed 100644 --- a/statig/src/awaitable/superstate.rs +++ b/statig/src/awaitable/superstate.rs @@ -123,7 +123,15 @@ where event, ); - superstate.handle(shared_storage, event, context).await + let response = superstate.handle(shared_storage, event, context).await; + + M::AFTER_DISPATCH( + shared_storage, + StateOrSuperstate::Superstate(&superstate), + event, + ); + + response } None => Response::Super, }, diff --git a/statig/src/blocking/state.rs b/statig/src/blocking/state.rs index f3bf7ce..7139143 100644 --- a/statig/src/blocking/state.rs +++ b/statig/src/blocking/state.rs @@ -100,6 +100,8 @@ where let response = self.call_handler(shared_storage, event, context); + M::AFTER_DISPATCH(shared_storage, StateOrSuperstate::State(self), event); + match response { Response::Handled => Response::Handled, Response::Super => match self.superstate() { @@ -110,7 +112,15 @@ where event, ); - superstate.handle(shared_storage, event, context) + let response = superstate.handle(shared_storage, event, context); + + M::AFTER_DISPATCH( + shared_storage, + StateOrSuperstate::Superstate(&superstate), + event, + ); + + response } None => Response::Super, }, diff --git a/statig/src/blocking/superstate.rs b/statig/src/blocking/superstate.rs index 296e85e..32d8582 100644 --- a/statig/src/blocking/superstate.rs +++ b/statig/src/blocking/superstate.rs @@ -109,7 +109,15 @@ where event, ); - superstate.handle(shared_storage, event, context) + let response = superstate.handle(shared_storage, event, context); + + M::AFTER_DISPATCH( + shared_storage, + StateOrSuperstate::Superstate(&superstate), + event, + ); + + response } None => Response::Super, }, diff --git a/statig/src/into_state_machine.rs b/statig/src/into_state_machine.rs index bffea31..932aefa 100644 --- a/statig/src/into_state_machine.rs +++ b/statig/src/into_state_machine.rs @@ -27,6 +27,11 @@ where const ON_DISPATCH: fn(&mut Self, StateOrSuperstate<'_, '_, Self>, &Self::Event<'_>) = |_, _, _| {}; + /// Method that is called *after* an event is dispatched to a state or + /// superstate handler. + const AFTER_DISPATCH: fn(&mut Self, StateOrSuperstate<'_, '_, Self>, &Self::Event<'_>) = + |_, _, _| {}; + /// Method that is called *after* every transition. const ON_TRANSITION: fn(&mut Self, &Self::State, &Self::State) = |_, _, _| {}; } diff --git a/statig/src/lib.rs b/statig/src/lib.rs index 3461338..d7d70ee 100644 --- a/statig/src/lib.rs +++ b/statig/src/lib.rs @@ -381,6 +381,7 @@ //! points during state machine execution. //! //! - `on_dispatch` is called before an event is dispatched to a specific state or superstate. +//! - `after_dispatch` is called after an event is dispatched to a specific state or superstate. //! - `on_transition` is called after a transition has occurred. //! //! ``` @@ -395,6 +396,7 @@ //! #[state_machine( //! initial = "State::on()", //! on_dispatch = "Self::on_dispatch", +//! after_dispatch = "Self::after_dispatch", //! on_transition = "Self::on_transition", //! state(derive(Debug)), //! superstate(derive(Debug)) @@ -412,6 +414,9 @@ //! fn on_dispatch(&mut self, state: StateOrSuperstate, event: &Event) { //! println!("dispatched `{:?}` to `{:?}`", event, state); //! } +//! fn after_dispatch(&mut self, state: StateOrSuperstate, event: &Event) { +//! println!("dispatched `{:?}` to `{:?}`", event, state); +//! } //! } //! ``` //!