diff --git a/CHANGELOG.md b/CHANGELOG.md index e003991..99fc0f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Add transition callback. A function which is called for every transition. It has default empty implementation. +- Add support for implicit and wildcard internal transitions + +### Changed + +- [breaking] Renamed custom_guard_error flag to custom_error as it is not guard specific anymore +- [breaking] Re-ordered on_exit/on_entry hooks calls ## [0.7.0] - 2024-07-03 diff --git a/README.md b/README.md index 464f538..d190c24 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,57 @@ statemachine!{ ``` See example `examples/input_state_pattern_match.rs` for a usage example. +#### Internal transitions + +The DSL supports internal transitions. +Internal transition allow to accept an event and process an action, +and then stay in the current state. +Internal transitions can be specified explicitly, e.g. +```plantuml +State2 + Event2 / event2_action = State2, +``` +or +```plantuml +State2 + Event2 / event2_action = _, +``` +or implicitly, by omitting the target state including '='. +```plantuml +State2 + Event2 / event2_action, +``` +It is also possible to define wildcard implicit (or explicit using '_') internal transitions. + +```rust +statemachine! { + transitions: { + *State1 + Event2 = State2, + State1 + Event3 = State3, + State1 + Event4 = State4, + + _ + Event2 / event2_action, + }, +} +``` +The example above demonstrates how you could make Event2 acceptable for any state, +not covered by any of the previous transitions, and to do an action to process it. + +It is equivalent to: + +```rust +statemachine! { + transitions: { + *State1 + Event2 = State2, + State1 + Event3 = State3, + State1 + Event4 = State4, + + State2 + Event2 / event2_action = State2, + State3 + Event2 / event2_action = State3, + State4 + Event2 / event2_action = State4, + }, +} +``` + +See also tests: `test_internal_transition_with_data()` or `test_wildcard_states_and_internal_transitions()` for a usage example. + #### Guard expressions Guard expression in square brackets [] allows to define a boolean expressions of multiple guard functions. diff --git a/docs/dsl.md b/docs/dsl.md index 83aac51..09c6ef0 100644 --- a/docs/dsl.md +++ b/docs/dsl.md @@ -21,7 +21,7 @@ statemachine!{ // [Optional] Can be optionally specified to add a new `type Error` to the // generated `StateMachineContext` trait to allow guards to return a custom // error type instead of `()`. - custom_guard_error: false, + custom_error: false, // [Optional] A list of derive names for the generated `States` and `Events` // enumerations respectively. For example, to `#[derive(Debug)]`, these diff --git a/examples/guard_custom_error.rs b/examples/guard_custom_error.rs index b8e85a4..aac7a2a 100644 --- a/examples/guard_custom_error.rs +++ b/examples/guard_custom_error.rs @@ -27,7 +27,7 @@ statemachine! { State2(MyStateData) + Event2 [guard2] / action2 = State3, // ... }, - custom_guard_error: true, + custom_error: true, } /// Context diff --git a/macros/src/codegen.rs b/macros/src/codegen.rs index c30984f..85b0ec1 100644 --- a/macros/src/codegen.rs +++ b/macros/src/codegen.rs @@ -203,7 +203,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { }) .collect(); - let guard_error = if sm.custom_guard_error { + let custom_error = if sm.custom_error { quote! { Self::Error } } else { quote! { () } @@ -269,11 +269,13 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { let state_name = format!("[{}::{}]", states_type_name, state); entries_exits.extend(quote! { #[doc = concat!("Called on entry to ", #state_name)] + #[inline(always)] fn #entry_ident(&mut self) {} }); let exit_ident = format_ident!("on_exit_{}", string_morph::to_snake_case(state)); entries_exits.extend(quote! { #[doc = concat!("Called on exit from ", #state_name)] + #[inline(always)] fn #exit_ident(&mut self) {} }); @@ -327,7 +329,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { guard_list.extend(quote! { #[allow(missing_docs)] #[allow(clippy::result_unit_err)] - #is_async fn #guard <#all_lifetimes> (&self, #temporary_context #state_data #event_data) -> Result; + #is_async fn #guard <#all_lifetimes> (&self, #temporary_context #state_data #event_data) -> Result; }); }; Ok(()) @@ -350,10 +352,10 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { .data_types .get(&transition.out_state.to_string()) { - quote! { Result<#output_data,#guard_error> } + quote! { Result<#output_data,#custom_error> } } else { // Empty return type - quote! { Result<(),#guard_error> } + quote! { Result<(),#custom_error> } }; let event_data = match sm.event_data.data_types.get(event) { @@ -412,21 +414,35 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { guard.iter() .zip(action.iter().zip(out_state)).map(|(guard, (action,out_state))| { let binding = out_state.to_string(); - let out_state_string = &binding.split('(').next().unwrap(); + let out_state_string = binding.split('(').next().unwrap().trim(); let binding = in_state.to_string(); - let in_state_string = &binding.split('(').next().unwrap(); + let in_state_string = binding.split('(').next().unwrap().trim(); let entry_ident = format_ident!("on_entry_{}", string_morph::to_snake_case(out_state_string)); let exit_ident = format_ident!("on_exit_{}", string_morph::to_snake_case(in_state_string)); - let entry_exit_states = - quote! { - self.context.#exit_ident(); - self.context.#entry_ident(); - }; let (is_async_action, action_code) = generate_action(action, &temporary_context_call, action_params, &error_type_name); is_async_state_machine |= is_async_action; + let transition = if in_state_string == out_state_string { + // Stay in the same state => no need to call on_entry/on_exit and log_state_change + quote!{ + #action_code + self.state = #states_type_name::#out_state; + return Ok(&self.state); + } + } else { + quote!{ + self.context.#exit_ident(); + #action_code + let out_state = #states_type_name::#out_state; + self.context.log_state_change(&out_state); + self.context().transition_callback(&self.state, &out_state); + self.state = out_state; + self.context.#entry_ident(); + return Ok(&self.state); + } + }; if let Some(expr) = guard { // Guarded transition let guard_expression= expr.to_token_stream(&mut |async_ident: &AsyncIdent| { let guard_ident = &async_ident.ident; @@ -452,24 +468,12 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { // Otherwise, there may be a later transition that passes, // so we'll defer to that. if guard_passed { - #action_code - let out_state = #states_type_name::#out_state; - self.context.log_state_change(&out_state); - #entry_exit_states - self.context().transition_callback(&self.state, &out_state); - self.state = out_state; - return Ok(&self.state); + #transition } } } else { // Unguarded transition quote!{ - #action_code - let out_state = #states_type_name::#out_state; - self.context.log_state_change(&out_state); - #entry_exit_states - self.context().transition_callback(&self.state, &out_state); - self.state = out_state; - return Ok(&self.state); + #transition } } } @@ -513,7 +517,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { // lifetimes that exists in #events_type_name but not in #states_type_name let event_unique_lifetimes = event_lifetimes - state_lifetimes; - let guard_error = if sm.custom_guard_error { + let custom_error = if sm.custom_error { quote! { /// The error type returned by guard or action functions. type Error: core::fmt::Debug; @@ -528,7 +532,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { quote! {} }; - let error_type = if sm.custom_guard_error { + let error_type = if sm.custom_error { quote! { #error_type_name<::Error> } @@ -543,7 +547,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { /// This trait outlines the guards and actions that need to be implemented for the state /// machine. pub trait #state_machine_context_type_name { - #guard_error + #custom_error #guard_list #action_list #entries_exits diff --git a/macros/src/parser/data.rs b/macros/src/parser/data.rs index 894139a..170ba2c 100644 --- a/macros/src/parser/data.rs +++ b/macros/src/parser/data.rs @@ -47,13 +47,19 @@ impl DataDefinitions { if prev != &data_type.clone() { return Err(parse::Error::new( data_type.span(), - "This event's type does not match its previous definition.", + format!( + "This event's type {} does not match its previous definition", + key + ), )); } } else { return Err(parse::Error::new( data_type.span(), - "This event's type does not match its previous definition.", + format!( + "This event's type {} does not match its previous definition", + key + ), )); } } diff --git a/macros/src/parser/mod.rs b/macros/src/parser/mod.rs index ea65567..faed8ee 100644 --- a/macros/src/parser/mod.rs +++ b/macros/src/parser/mod.rs @@ -49,7 +49,7 @@ pub struct ParsedStateMachine { pub derive_states: Vec, pub derive_events: Vec, pub temporary_context_type: Option, - pub custom_guard_error: bool, + pub custom_error: bool, pub states: HashMap, pub starting_state: Ident, pub state_data: DataDefinitions, @@ -110,7 +110,19 @@ fn add_transition( } impl ParsedStateMachine { - pub fn new(sm: StateMachine) -> parse::Result { + pub fn new(mut sm: StateMachine) -> parse::Result { + // Derive out_state for internal non-wildcard transitions + for transition in sm.transitions.iter_mut() { + if transition.out_state.internal_transition && !transition.in_state.wildcard { + transition.out_state.ident = transition.in_state.ident.clone(); + transition + .out_state + .data_type + .clone_from(&transition.in_state.data_type); + transition.out_state.internal_transition = false; + } + } + // Check the initial state definition let mut starting_transitions_iter = sm.transitions.iter().filter(|sm| sm.in_state.start); @@ -138,16 +150,18 @@ impl ParsedStateMachine { for transition in sm.transitions.iter() { // Collect states let in_state_name = transition.in_state.ident.to_string(); - let out_state_name = transition.out_state.ident.to_string(); if !transition.in_state.wildcard { states.insert(in_state_name.clone(), transition.in_state.ident.clone()); state_data.collect(in_state_name.clone(), transition.in_state.data_type.clone())?; } - states.insert(out_state_name.clone(), transition.out_state.ident.clone()); - state_data.collect( - out_state_name.clone(), - transition.out_state.data_type.clone(), - )?; + if !transition.out_state.internal_transition { + let out_state_name = transition.out_state.ident.to_string(); + states.insert(out_state_name.clone(), transition.out_state.ident.clone()); + state_data.collect( + out_state_name.clone(), + transition.out_state.data_type.clone(), + )?; + } // Collect events let event_name = transition.event.ident.to_string(); @@ -158,7 +172,10 @@ impl ParsedStateMachine { if !transition.in_state.wildcard { states_events_mapping.insert(transition.in_state.ident.to_string(), HashMap::new()); } - states_events_mapping.insert(transition.out_state.ident.to_string(), HashMap::new()); + if !transition.out_state.internal_transition { + states_events_mapping + .insert(transition.out_state.ident.to_string(), HashMap::new()); + } } for transition in sm.transitions.iter() { @@ -185,12 +202,17 @@ impl ParsedStateMachine { }; // create the transition + let mut out_state = transition.out_state.clone(); + if out_state.internal_transition { + out_state.ident = in_state.ident.clone(); + out_state.data_type.clone_from(&in_state.data_type); + } let wildcard_transition = StateTransition { in_state, event: transition.event.clone(), guard: transition.guard.clone(), action: transition.action.clone(), - out_state: transition.out_state.clone(), + out_state, }; // add the wildcard transition to the transition map @@ -223,7 +245,7 @@ impl ParsedStateMachine { derive_states: sm.derive_states, derive_events: sm.derive_events, temporary_context_type: sm.temporary_context_type, - custom_guard_error: sm.custom_guard_error, + custom_error: sm.custom_error, states, starting_state, state_data, diff --git a/macros/src/parser/output_state.rs b/macros/src/parser/output_state.rs index 29642aa..f883d48 100644 --- a/macros/src/parser/output_state.rs +++ b/macros/src/parser/output_state.rs @@ -1,43 +1,64 @@ +use proc_macro2::Span; use syn::{parenthesized, parse, spanned::Spanned, token, Ident, Token, Type}; #[derive(Debug, Clone)] pub struct OutputState { pub ident: Ident, + pub internal_transition: bool, pub data_type: Option, } impl parse::Parse for OutputState { fn parse(input: parse::ParseStream) -> syn::Result { - input.parse::()?; - let ident: Ident = input.parse()?; + if input.peek(Token![=]) { + input.parse::()?; + let (internal_transition, ident) = if input.peek(Token![_]) { + // Underscore ident here is used to represent an internal transition + let underscore = input.parse::()?; + (true, underscore.into()) + } else { + (false, input.parse()?) + }; - // Possible type on the output state - let data_type = if input.peek(token::Paren) { - let content; - parenthesized!(content in input); - let input: Type = content.parse()?; + // Possible type on the output state + let data_type = if !internal_transition && input.peek(token::Paren) { + let content; + parenthesized!(content in input); + let input: Type = content.parse()?; - // Check so the type is supported - match &input { - Type::Array(_) - | Type::Path(_) - | Type::Ptr(_) - | Type::Reference(_) - | Type::Slice(_) - | Type::Tuple(_) => (), - _ => { - return Err(parse::Error::new( - input.span(), - "This is an unsupported type for states.", - )) + // Check so the type is supported + match &input { + Type::Array(_) + | Type::Path(_) + | Type::Ptr(_) + | Type::Reference(_) + | Type::Slice(_) + | Type::Tuple(_) => (), + _ => { + return Err(parse::Error::new( + input.span(), + "This is an unsupported type for states.", + )) + } } - } - Some(input) - } else { - None - }; + Some(input) + } else { + None + }; - Ok(Self { ident, data_type }) + Ok(Self { + ident, + internal_transition, + data_type, + }) + } else { + // Internal transition + Ok(Self { + ident: Ident::new("_", Span::call_site()), + internal_transition: true, + data_type: None, + }) + } } } diff --git a/macros/src/parser/state_machine.rs b/macros/src/parser/state_machine.rs index 85f4402..2b011d0 100644 --- a/macros/src/parser/state_machine.rs +++ b/macros/src/parser/state_machine.rs @@ -4,7 +4,7 @@ use syn::{braced, bracketed, parse, spanned::Spanned, token, Ident, Token, Type} #[derive(Debug)] pub struct StateMachine { pub temporary_context_type: Option, - pub custom_guard_error: bool, + pub custom_error: bool, pub transitions: Vec, pub name: Option, pub derive_states: Vec, @@ -15,7 +15,7 @@ impl StateMachine { pub fn new() -> Self { StateMachine { temporary_context_type: None, - custom_guard_error: false, + custom_error: false, transitions: Vec::new(), name: None, derive_states: Vec::new(), @@ -72,11 +72,11 @@ impl parse::Parse for StateMachine { } } } - "custom_guard_error" => { + "custom_error" => { input.parse::()?; - let custom_guard_error: syn::LitBool = input.parse()?; - if custom_guard_error.value { - statemachine.custom_guard_error = true + let custom_error: syn::LitBool = input.parse()?; + if custom_error.value { + statemachine.custom_error = true } } "temporary_context" => { @@ -145,7 +145,7 @@ impl parse::Parse for StateMachine { "Unknown keyword {}. Support keywords: [\"name\", \ \"transitions\", \ \"temporary_context\", \ - \"custom_guard_error\", \ + \"custom_error\", \ \"derive_states\", \ \"derive_events\" ]", diff --git a/macros/src/parser/transition.rs b/macros/src/parser/transition.rs index e7031ef..e24cd32 100644 --- a/macros/src/parser/transition.rs +++ b/macros/src/parser/transition.rs @@ -7,7 +7,7 @@ use quote::quote; use std::fmt; use syn::{bracketed, parse, token, Ident, Token}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct StateTransition { pub in_state: InputState, pub event: Event, diff --git a/tests/test.rs b/tests/test.rs index 9f38926..4025e3e 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -4,6 +4,40 @@ use derive_more::Display; use smlang::statemachine; +mod internal_macros { + #[macro_export] + macro_rules! assert_transition { + ($sm:expr, $event:expr, $expected_state:expr, $expected_count:expr) => {{ + let prev_state = $sm.state; + $sm.process_event($event).unwrap(); + println!("{:?} -> {:?} : {:?}", prev_state, $sm.state, $sm.context()); + assert_eq!($expected_state, $sm.state); + assert_eq!($expected_count, $sm.context().count); + }}; + } + #[macro_export] + macro_rules! assert_transition_ok { + ($sm:expr, $event:expr, $expected_action:expr, $expected_result:pat) => {{ + let prev_state = $sm.state; + if let Ok(result) = $sm.process_event($event) { + let result = result.clone(); + println!("{:?} -> {:?} : {:?}", prev_state, result, $sm.context()); + match result { + $expected_result => {} + _ => panic!( + "Assertion failed:\n expected: {:?},\n actual: {:?}", + stringify!($expected_result), + result + ), + } + assert_eq!($expected_action, $sm.context().action); + } else { + panic!("assert_transition_ok failed") + } + }}; + } +} + #[test] fn compile_fail_tests() { let t = trybuild::TestCases::new(); @@ -325,3 +359,131 @@ fn guard_errors() { sm.process_event(Events::Event1).unwrap(); assert!(matches!(sm.state(), &States::Done)); } +#[test] +fn test_internal_transition_with_data() { + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + pub struct State1Data(pub ActionId); + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + pub struct State3Data(pub ActionId); + + statemachine! { + transitions: { + *State1(State1Data) + Event2 / action12 = State2, + State1(State1Data) + Event3 / action13 = State3(State3Data), + State1(State1Data) + Event4 / action14 = State4(State3Data), + + State2 + Event3 / action23 = State3(State3Data), + State4(State3Data) + Event1 / action44 = _, // Same as State4(State3Data) + Event1 / action44 + + // TRANSITION : _ + Event3 / increment_count = _, IS EQUIVALENT TO THE FOLLOWING TWO: + // State3(State3Data) + Event3 / action_3 = State3(State3Data), + // State4(State3Data) + Event3 / action_3 = State4(State3Data), + _ + Event3 / action_3 = _, + }, + derive_states: [Debug, Clone, Copy, Eq ] + } + /// Context + #[derive(Default, Debug, PartialEq, Eq)] + pub struct Context { + action: ActionId, + } + #[derive(Debug, PartialEq, Eq, Clone, Copy, Default)] + enum ActionId { + #[default] + Init, + Action3, + Action12, + Action13, + Action14, + Action23, + Action44, + } + impl StateMachineContext for Context { + fn action_3(&mut self, d: &State3Data) -> Result { + self.action = ActionId::Action3; + Ok(*d) + } + + fn action44(&mut self, _d: &State3Data) -> Result { + self.action = ActionId::Action44; + Ok(State3Data(ActionId::Action44)) + } + fn action14(&mut self, _d: &State1Data) -> Result { + self.action = ActionId::Action14; + Ok(State3Data(ActionId::Action14)) + } + fn action12(&mut self, _d: &State1Data) -> Result<(), ()> { + self.action = ActionId::Action12; + Ok(()) + } + fn action13(&mut self, _d: &State1Data) -> Result { + self.action = ActionId::Action13; + Ok(State3Data(ActionId::Action13)) + } + fn action23(&mut self) -> Result { + self.action = ActionId::Action23; + Ok(State3Data(ActionId::Action23)) + } + } + + { + let mut sm = StateMachine::new(Context::default(), State1Data(ActionId::Init)); + matches!(States::State2, States::State2); + assert_transition_ok!(sm, Events::Event2, ActionId::Action12, States::State2); + assert!(sm.process_event(Events::Event1).is_err()); + assert!(sm.process_event(Events::Event2).is_err()); + assert!(sm.process_event(Events::Event4).is_err()); + assert_transition_ok!(sm, Events::Event3, ActionId::Action23, States::State3(_)); + assert_transition_ok!(sm, Events::Event3, ActionId::Action3, States::State3(_)); + assert_transition_ok!(sm, Events::Event3, ActionId::Action3, States::State3(_)); + assert!(sm.process_event(Events::Event1).is_err()); + assert!(sm.process_event(Events::Event2).is_err()); + assert!(sm.process_event(Events::Event4).is_err()); + } + { + let mut sm = StateMachine::new(Context::default(), State1Data(ActionId::Init)); + assert_transition_ok!(sm, Events::Event3, ActionId::Action13, States::State3(_)); + assert!(sm.process_event(Events::Event1).is_err()); + assert!(sm.process_event(Events::Event2).is_err()); + assert!(sm.process_event(Events::Event4).is_err()); + } + { + let mut sm = StateMachine::new(Context::default(), State1Data(ActionId::Init)); + assert_transition_ok!(sm, Events::Event4, ActionId::Action14, States::State4(_)); + assert_transition_ok!(sm, Events::Event1, ActionId::Action44, States::State4(_)); + assert_transition_ok!(sm, Events::Event3, ActionId::Action3, States::State4(_)); + } +} +#[test] +fn test_wildcard_states_and_internal_transitions() { + statemachine! { + transitions: { + *State1 + Event2 = State2, + State2 + Event3 = State3, + _ + Event1 / increment_count, // Internal transition (implicit: omitting target state) + _ + Event3 / increment_count = _ , // Internal transition (explicit: using _ as target state) + }, + derive_states: [Debug, Clone, Copy] + } + #[derive(Debug)] + pub struct Context { + count: u32, + } + impl StateMachineContext for Context { + fn increment_count(&mut self) -> Result<(), ()> { + self.count += 1; + Ok(()) + } + } + + let mut sm = StateMachine::new(Context { count: 0 }); + + assert_transition!(sm, Events::Event1, States::State1, 1); + assert_transition!(sm, Events::Event2, States::State2, 1); + assert_transition!(sm, Events::Event3, States::State3, 1); + assert_transition!(sm, Events::Event1, States::State3, 2); + assert_transition!(sm, Events::Event3, States::State3, 3); + + assert!(sm.process_event(Events::Event2).is_err()); // InvalidEvent + assert_eq!(States::State3, sm.state); +}