Skip to content

Commit

Permalink
Merge pull request #80 from dkumsh/feat/wildcard-internal-transitions
Browse files Browse the repository at this point in the history
Support for wildcard internal transitions
  • Loading branch information
ryan-summers authored Jul 22, 2024
2 parents 5980311 + ae8574c commit 9da1c1c
Show file tree
Hide file tree
Showing 11 changed files with 349 additions and 77 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

- Add transition callback. A function which is called for every transition. It has default empty
implementation.
- Add support for implicit and wildcard internal transitions

### Changed

- [breaking] Renamed custom_guard_error flag to custom_error as it is not guard specific anymore
- [breaking] Re-ordered on_exit/on_entry hooks calls

## [0.7.0] - 2024-07-03

Expand Down
51 changes: 51 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,57 @@ statemachine!{
```
See example `examples/input_state_pattern_match.rs` for a usage example.

#### Internal transitions

The DSL supports internal transitions.
Internal transition allow to accept an event and process an action,
and then stay in the current state.
Internal transitions can be specified explicitly, e.g.
```plantuml
State2 + Event2 / event2_action = State2,
```
or
```plantuml
State2 + Event2 / event2_action = _,
```
or implicitly, by omitting the target state including '='.
```plantuml
State2 + Event2 / event2_action,
```
It is also possible to define wildcard implicit (or explicit using '_') internal transitions.

```rust
statemachine! {
transitions: {
*State1 + Event2 = State2,
State1 + Event3 = State3,
State1 + Event4 = State4,

_ + Event2 / event2_action,
},
}
```
The example above demonstrates how you could make Event2 acceptable for any state,
not covered by any of the previous transitions, and to do an action to process it.

It is equivalent to:

```rust
statemachine! {
transitions: {
*State1 + Event2 = State2,
State1 + Event3 = State3,
State1 + Event4 = State4,

State2 + Event2 / event2_action = State2,
State3 + Event2 / event2_action = State3,
State4 + Event2 / event2_action = State4,
},
}
```

See also tests: `test_internal_transition_with_data()` or `test_wildcard_states_and_internal_transitions()` for a usage example.

#### Guard expressions

Guard expression in square brackets [] allows to define a boolean expressions of multiple guard functions.
Expand Down
2 changes: 1 addition & 1 deletion docs/dsl.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ statemachine!{
// [Optional] Can be optionally specified to add a new `type Error` to the
// generated `StateMachineContext` trait to allow guards to return a custom
// error type instead of `()`.
custom_guard_error: false,
custom_error: false,

// [Optional] A list of derive names for the generated `States` and `Events`
// enumerations respectively. For example, to `#[derive(Debug)]`, these
Expand Down
2 changes: 1 addition & 1 deletion examples/guard_custom_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ statemachine! {
State2(MyStateData) + Event2 [guard2] / action2 = State3,
// ...
},
custom_guard_error: true,
custom_error: true,
}

/// Context
Expand Down
60 changes: 32 additions & 28 deletions macros/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
})
.collect();

let guard_error = if sm.custom_guard_error {
let custom_error = if sm.custom_error {
quote! { Self::Error }
} else {
quote! { () }
Expand Down Expand Up @@ -269,11 +269,13 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
let state_name = format!("[{}::{}]", states_type_name, state);
entries_exits.extend(quote! {
#[doc = concat!("Called on entry to ", #state_name)]
#[inline(always)]
fn #entry_ident(&mut self) {}
});
let exit_ident = format_ident!("on_exit_{}", string_morph::to_snake_case(state));
entries_exits.extend(quote! {
#[doc = concat!("Called on exit from ", #state_name)]
#[inline(always)]
fn #exit_ident(&mut self) {}
});

Expand Down Expand Up @@ -327,7 +329,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
guard_list.extend(quote! {
#[allow(missing_docs)]
#[allow(clippy::result_unit_err)]
#is_async fn #guard <#all_lifetimes> (&self, #temporary_context #state_data #event_data) -> Result<bool,#guard_error>;
#is_async fn #guard <#all_lifetimes> (&self, #temporary_context #state_data #event_data) -> Result<bool,#custom_error>;
});
};
Ok(())
Expand All @@ -350,10 +352,10 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
.data_types
.get(&transition.out_state.to_string())
{
quote! { Result<#output_data,#guard_error> }
quote! { Result<#output_data,#custom_error> }
} else {
// Empty return type
quote! { Result<(),#guard_error> }
quote! { Result<(),#custom_error> }
};

let event_data = match sm.event_data.data_types.get(event) {
Expand Down Expand Up @@ -412,21 +414,35 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
guard.iter()
.zip(action.iter().zip(out_state)).map(|(guard, (action,out_state))| {
let binding = out_state.to_string();
let out_state_string = &binding.split('(').next().unwrap();
let out_state_string = binding.split('(').next().unwrap().trim();
let binding = in_state.to_string();
let in_state_string = &binding.split('(').next().unwrap();
let in_state_string = binding.split('(').next().unwrap().trim();

let entry_ident = format_ident!("on_entry_{}", string_morph::to_snake_case(out_state_string));
let exit_ident = format_ident!("on_exit_{}", string_morph::to_snake_case(in_state_string));

let entry_exit_states =
quote! {
self.context.#exit_ident();
self.context.#entry_ident();
};
let (is_async_action, action_code) = generate_action(action, &temporary_context_call, action_params, &error_type_name);
is_async_state_machine |= is_async_action;

let transition = if in_state_string == out_state_string {
// Stay in the same state => no need to call on_entry/on_exit and log_state_change
quote!{
#action_code
self.state = #states_type_name::#out_state;
return Ok(&self.state);
}
} else {
quote!{
self.context.#exit_ident();
#action_code
let out_state = #states_type_name::#out_state;
self.context.log_state_change(&out_state);
self.context().transition_callback(&self.state, &out_state);
self.state = out_state;
self.context.#entry_ident();
return Ok(&self.state);
}
};
if let Some(expr) = guard { // Guarded transition
let guard_expression= expr.to_token_stream(&mut |async_ident: &AsyncIdent| {
let guard_ident = &async_ident.ident;
Expand All @@ -452,24 +468,12 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
// Otherwise, there may be a later transition that passes,
// so we'll defer to that.
if guard_passed {
#action_code
let out_state = #states_type_name::#out_state;
self.context.log_state_change(&out_state);
#entry_exit_states
self.context().transition_callback(&self.state, &out_state);
self.state = out_state;
return Ok(&self.state);
#transition
}
}
} else { // Unguarded transition
quote!{
#action_code
let out_state = #states_type_name::#out_state;
self.context.log_state_change(&out_state);
#entry_exit_states
self.context().transition_callback(&self.state, &out_state);
self.state = out_state;
return Ok(&self.state);
#transition
}
}
}
Expand Down Expand Up @@ -513,7 +517,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
// lifetimes that exists in #events_type_name but not in #states_type_name
let event_unique_lifetimes = event_lifetimes - state_lifetimes;

let guard_error = if sm.custom_guard_error {
let custom_error = if sm.custom_error {
quote! {
/// The error type returned by guard or action functions.
type Error: core::fmt::Debug;
Expand All @@ -528,7 +532,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
quote! {}
};

let error_type = if sm.custom_guard_error {
let error_type = if sm.custom_error {
quote! {
#error_type_name<<T as #state_machine_context_type_name>::Error>
}
Expand All @@ -543,7 +547,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
/// This trait outlines the guards and actions that need to be implemented for the state
/// machine.
pub trait #state_machine_context_type_name {
#guard_error
#custom_error
#guard_list
#action_list
#entries_exits
Expand Down
10 changes: 8 additions & 2 deletions macros/src/parser/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,19 @@ impl DataDefinitions {
if prev != &data_type.clone() {
return Err(parse::Error::new(
data_type.span(),
"This event's type does not match its previous definition.",
format!(
"This event's type {} does not match its previous definition",
key
),
));
}
} else {
return Err(parse::Error::new(
data_type.span(),
"This event's type does not match its previous definition.",
format!(
"This event's type {} does not match its previous definition",
key
),
));
}
}
Expand Down
44 changes: 33 additions & 11 deletions macros/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub struct ParsedStateMachine {
pub derive_states: Vec<Ident>,
pub derive_events: Vec<Ident>,
pub temporary_context_type: Option<Type>,
pub custom_guard_error: bool,
pub custom_error: bool,
pub states: HashMap<String, Ident>,
pub starting_state: Ident,
pub state_data: DataDefinitions,
Expand Down Expand Up @@ -110,7 +110,19 @@ fn add_transition(
}

impl ParsedStateMachine {
pub fn new(sm: StateMachine) -> parse::Result<Self> {
pub fn new(mut sm: StateMachine) -> parse::Result<Self> {
// Derive out_state for internal non-wildcard transitions
for transition in sm.transitions.iter_mut() {
if transition.out_state.internal_transition && !transition.in_state.wildcard {
transition.out_state.ident = transition.in_state.ident.clone();
transition
.out_state
.data_type
.clone_from(&transition.in_state.data_type);
transition.out_state.internal_transition = false;
}
}

// Check the initial state definition
let mut starting_transitions_iter = sm.transitions.iter().filter(|sm| sm.in_state.start);

Expand Down Expand Up @@ -138,16 +150,18 @@ impl ParsedStateMachine {
for transition in sm.transitions.iter() {
// Collect states
let in_state_name = transition.in_state.ident.to_string();
let out_state_name = transition.out_state.ident.to_string();
if !transition.in_state.wildcard {
states.insert(in_state_name.clone(), transition.in_state.ident.clone());
state_data.collect(in_state_name.clone(), transition.in_state.data_type.clone())?;
}
states.insert(out_state_name.clone(), transition.out_state.ident.clone());
state_data.collect(
out_state_name.clone(),
transition.out_state.data_type.clone(),
)?;
if !transition.out_state.internal_transition {
let out_state_name = transition.out_state.ident.to_string();
states.insert(out_state_name.clone(), transition.out_state.ident.clone());
state_data.collect(
out_state_name.clone(),
transition.out_state.data_type.clone(),
)?;
}

// Collect events
let event_name = transition.event.ident.to_string();
Expand All @@ -158,7 +172,10 @@ impl ParsedStateMachine {
if !transition.in_state.wildcard {
states_events_mapping.insert(transition.in_state.ident.to_string(), HashMap::new());
}
states_events_mapping.insert(transition.out_state.ident.to_string(), HashMap::new());
if !transition.out_state.internal_transition {
states_events_mapping
.insert(transition.out_state.ident.to_string(), HashMap::new());
}
}

for transition in sm.transitions.iter() {
Expand All @@ -185,12 +202,17 @@ impl ParsedStateMachine {
};

// create the transition
let mut out_state = transition.out_state.clone();
if out_state.internal_transition {
out_state.ident = in_state.ident.clone();
out_state.data_type.clone_from(&in_state.data_type);
}
let wildcard_transition = StateTransition {
in_state,
event: transition.event.clone(),
guard: transition.guard.clone(),
action: transition.action.clone(),
out_state: transition.out_state.clone(),
out_state,
};

// add the wildcard transition to the transition map
Expand Down Expand Up @@ -223,7 +245,7 @@ impl ParsedStateMachine {
derive_states: sm.derive_states,
derive_events: sm.derive_events,
temporary_context_type: sm.temporary_context_type,
custom_guard_error: sm.custom_guard_error,
custom_error: sm.custom_error,
states,
starting_state,
state_data,
Expand Down
Loading

0 comments on commit 9da1c1c

Please sign in to comment.