Skip to content

Commit

Permalink
trans_builder accepts a mutable system
Browse files Browse the repository at this point in the history
  • Loading branch information
Seldom-SE committed Dec 26, 2024
1 parent 14ff732 commit 5957ce8
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 109 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "seldom_state"
version = "0.12.0"
version = "0.13.0-dev"
edition = "2021"
categories = ["game-development"]
description = "Component-based state machine plugin for Bevy. Useful for AI, player state, and other entities that occupy various states."
Expand Down
8 changes: 3 additions & 5 deletions examples/done.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@ fn init(mut commands: Commands, asset_server: Res<AssetServer>) {
Idle,
StateMachine::default()
// When the player clicks, go there
.trans_builder(click, |_: &AnyState, pos| {
Some(GoToSelection {
speed: 200.,
target: pos,
})
.trans_builder(click, |trans: Trans<AnyState, _>| GoToSelection {
speed: 200.,
target: trans.out,
})
// `done` triggers when the `Done` component is added to the entity. When they're done
// going to the selection, idle.
Expand Down
19 changes: 12 additions & 7 deletions examples/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,18 @@ fn init(mut commands: Commands, asset_server: Res<AssetServer>) {
// When the player hits the ground, idle
.trans::<Falling, _>(grounded, Grounded::Idle)
// When the player is grounded, set their movement direction
.trans_builder(value_unbounded(Action::Move), |_: &Grounded, value| {
Some(match value {
value if value > 0.5 => Grounded::Right,
value if value < -0.5 => Grounded::Left,
_ => Grounded::Idle,
})
}),
.trans_builder(
value_unbounded(Action::Move),
|trans: Trans<Grounded, _>| {
let value = trans.out;

match value {
value if value > 0.5 => Grounded::Right,
value if value < -0.5 => Grounded::Left,
_ => Grounded::Idle,
}
},
),
Sprite::from_image(asset_server.load("player.png")),
Transform::from_xyz(500., 0., 0.),
));
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#![warn(missing_docs)]
#![allow(clippy::type_complexity)]

mod machine;
pub mod machine;
pub mod set;
mod state;
pub mod trigger;
Expand Down Expand Up @@ -38,7 +38,7 @@ pub mod prelude {
value_unbounded,
};
pub use crate::{
machine::StateMachine,
machine::{StateMachine, Trans},
state::{AnyState, EntityState},
trigger::{always, done, on_event, Done, EntityTrigger, IntoTrigger, Never},
StateMachinePlugin,
Expand Down
161 changes: 96 additions & 65 deletions src/machine.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
//! Module for the [`StateMachine`] component
use std::{
any::{type_name, Any, TypeId},
fmt::Debug,
marker::PhantomData,
};

use bevy::{
ecs::{
system::{EntityCommands, SystemState},
world::Command,
},
tasks::{ComputeTaskPool, ParallelSliceMut},
utils::HashMap,
ecs::{system::EntityCommands, world::Command},
utils::TypeIdMap,
};

use crate::{
prelude::*,
set::StateSet,
state::{Insert, OnEvent},
state::OnEvent,
trigger::{IntoTrigger, TriggerOut},
};

Expand All @@ -30,7 +28,11 @@ trait Transition: Debug + Send + Sync + 'static {
fn init(&mut self, world: &mut World);
/// Checks whether the transition should be taken. `entity` is the entity that contains the
/// state machine.
fn check(&mut self, world: &World, entity: Entity) -> Option<(Box<dyn Insert>, TypeId)>;
fn check<'a>(
&'a mut self,
world: &World,
entity: Entity,
) -> Option<(Box<dyn 'a + FnOnce(&mut World, TypeId)>, TypeId)>;
}

/// An edge in the state machine. The type parameters are the [`EntityTrigger`] that causes this
Expand All @@ -40,10 +42,7 @@ struct TransitionImpl<Trig, Prev, Build, Next>
where
Trig: EntityTrigger,
Prev: EntityState,
Build: 'static
+ Fn(&Prev, <<Trig as EntityTrigger>::Out as TriggerOut>::Ok) -> Option<Next>
+ Send
+ Sync,
Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
Next: Component + EntityState,
{
pub trigger: Trig,
Expand All @@ -55,8 +54,7 @@ impl<Trig, Prev, Build, Next> Debug for TransitionImpl<Trig, Prev, Build, Next>
where
Trig: EntityTrigger,
Prev: EntityState,
Build:
Fn(&Prev, <<Trig as EntityTrigger>::Out as TriggerOut>::Ok) -> Option<Next> + Send + Sync,
Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
Next: Component + EntityState,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand All @@ -72,30 +70,41 @@ impl<Trig, Prev, Build, Next> Transition for TransitionImpl<Trig, Prev, Build, N
where
Trig: EntityTrigger,
Prev: EntityState,
Build:
Fn(&Prev, <<Trig as EntityTrigger>::Out as TriggerOut>::Ok) -> Option<Next> + Send + Sync,
Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
Next: Component + EntityState,
{
fn init(&mut self, world: &mut World) {
self.trigger.init(world);
self.builder.initialize(world);
}

fn check(&mut self, world: &World, entity: Entity) -> Option<(Box<dyn Insert>, TypeId)> {
let Ok(res) = self.trigger.check(entity, world).into_result() else {
return None;
};

(self.builder)(Prev::from_entity(entity, world), res)
.map(|state| (Box::new(state) as Box<dyn Insert>, TypeId::of::<Next>()))
fn check<'a>(
&'a mut self,
world: &World,
entity: Entity,
) -> Option<(Box<dyn 'a + FnOnce(&mut World, TypeId)>, TypeId)> {
self.trigger
.check(entity, world)
.into_result()
.map(|out| {
(
Box::new(move |world: &mut World, curr: TypeId| {
let prev = Prev::remove(entity, world, curr);
let next = self.builder.run(TransCtx { prev, out, entity }, world);
world.entity_mut(entity).insert(next);
}) as Box<dyn 'a + FnOnce(&mut World, TypeId)>,
TypeId::of::<Next>(),
)
})
.ok()
}
}

impl<Trig, Prev, Build, Next> TransitionImpl<Trig, Prev, Build, Next>
where
Trig: EntityTrigger,
Prev: EntityState,
Build:
Fn(&Prev, <<Trig as EntityTrigger>::Out as TriggerOut>::Ok) -> Option<Next> + Send + Sync,
Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
Next: Component + EntityState,
{
pub fn new(trigger: Trig, builder: Build) -> Self {
Expand All @@ -107,6 +116,19 @@ where
}
}

/// Context for a transition
pub struct TransCtx<Prev, Out> {
/// Previous state
pub prev: Prev,
/// Output from the trigger
pub out: Out,
/// The entity with this state machine
pub entity: Entity,
}

/// Context for a transition, usable as a `SystemInput`
pub type Trans<Prev, Out> = In<TransCtx<Prev, Out>>;

/// Information about a state
#[derive(Debug)]
struct StateMetadata {
Expand All @@ -119,11 +141,9 @@ struct StateMetadata {
impl StateMetadata {
fn new<S: EntityState>() -> Self {
Self {
name: type_name::<S>().to_owned(),
on_enter: default(),
on_exit: vec![OnEvent::Entity(Box::new(|entity: &mut EntityCommands| {
S::remove(entity);
}))],
name: type_name::<S>().to_string(),
on_enter: Vec::new(),
on_exit: Vec::new(),
}
}
}
Expand All @@ -133,7 +153,7 @@ impl StateMetadata {
/// `StateMachine::trans`, and other methods.
#[derive(Component)]
pub struct StateMachine {
states: HashMap<TypeId, StateMetadata>,
states: TypeIdMap<StateMetadata>,
/// Each transition and the state it should apply in (or [`AnyState`]). We store the transitions
/// in a flat list so that we ensure we always check them in the right order; storing them in
/// each StateMetadata would mean that e.g. we'd have to check every AnyState trigger before any
Expand All @@ -147,16 +167,19 @@ pub struct StateMachine {

impl Default for StateMachine {
fn default() -> Self {
let mut states = TypeIdMap::default();
states.insert(
TypeId::of::<AnyState>(),
StateMetadata {
name: "AnyState".to_owned(),
on_enter: vec![],
on_exit: vec![],
},
);

Self {
states: HashMap::from([(
TypeId::of::<AnyState>(),
StateMetadata {
name: "AnyState".to_owned(),
on_enter: vec![],
on_exit: vec![],
},
)]),
transitions: vec![],
states,
transitions: Vec::new(),
init_transitions: true,
log_transitions: false,
}
Expand All @@ -179,7 +202,7 @@ impl StateMachine {
trigger: impl IntoTrigger<Marker>,
state: impl Clone + Component,
) -> Self {
self.trans_builder(trigger, move |_: &S, _| Some(state.clone()))
self.trans_builder(trigger, move |_: Trans<S, _>| state.clone())
}

/// Get the metadata for the given state, creating it if necessary.
Expand All @@ -194,21 +217,25 @@ impl StateMachine {
/// `Some(Next)`, the machine will transition to that `Next` state.
pub fn trans_builder<
Prev: EntityState,
Trig: IntoTrigger<Marker>,
Trig: IntoTrigger<TrigMarker>,
Next: Clone + Component,
Marker,
TrigMarker,
BuildMarker,
>(
mut self,
trigger: Trig,
builder: impl 'static
+ Clone
+ Fn(&Prev, <<Trig::Trigger as EntityTrigger>::Out as TriggerOut>::Ok) -> Option<Next>
+ Send
+ Sync,
builder: impl IntoSystem<
Trans<Prev, <<Trig::Trigger as EntityTrigger>::Out as TriggerOut>::Ok>,
Next,
BuildMarker,
>,
) -> Self {
self.metadata_mut::<Prev>();
self.metadata_mut::<Next>();
let transition = TransitionImpl::<_, Prev, _, _>::new(trigger.into_trigger(), builder);
let transition = TransitionImpl::<_, Prev, _, _>::new(
trigger.into_trigger(),
IntoSystem::into_system(builder),
);
self.transitions.push((
TypeId::of::<Prev>(),
Box::new(transition) as Box<dyn Transition>,
Expand Down Expand Up @@ -288,22 +315,25 @@ impl StateMachine {

/// Runs all transitions until one is actually taken. If one is taken, logs the transition and
/// runs `on_enter/on_exit` triggers.
fn run(&mut self, world: &World, entity: Entity, commands: &mut Commands) {
// TODO Defer the actual transition so this can be parallelized, and see if that improves perf
fn run(&mut self, world: &mut World, entity: Entity) {
let mut states = self.states.keys();
let current = states.find(|&&state| world.entity(entity).contains_type_id(state));

let Some(&current) = current else {
panic!("Entity {entity:?} is in no state");
error!("Entity {entity:?} is in no state");
return;
};

let from = &self.states[&current];
if let Some(&other) = states.find(|&&state| world.entity(entity).contains_type_id(state)) {
let state = &from.name;
let other = &self.states[&other].name;
panic!("{entity:?} is in multiple states: {state} and {other}");
error!("{entity:?} is in multiple states: {state} and {other}");
return;
}

let Some((insert, next_state)) = self
let Some((trans, next_state)) = self
.transitions
.iter_mut()
.filter(|(type_id, _)| *type_id == current || *type_id == TypeId::of::<AnyState>())
Expand All @@ -314,12 +344,13 @@ impl StateMachine {
let to = &self.states[&next_state];

for event in from.on_exit.iter() {
event.trigger(entity, commands);
event.trigger(entity, &mut world.commands());
}

insert.insert(&mut commands.entity(entity));
trans(world, current);

for event in to.on_enter.iter() {
event.trigger(entity, commands);
event.trigger(entity, &mut world.commands());
}

if self.log_transitions {
Expand All @@ -342,9 +373,10 @@ impl StateMachine {
}

/// Runs all transitions on all entities.
// There are comments here about parallelization, but this is not parallelized anymore. Leaving them
// here in case it gets parallelized again.
pub(crate) fn transition(
world: &mut World,
system_state: &mut SystemState<ParallelCommands>,
machine_query: &mut QueryState<(Entity, &mut StateMachine)>,
) {
// Pull the machines out of the world so we can invoke mutable methods on them. The alternative
Expand All @@ -366,22 +398,21 @@ pub(crate) fn transition(

// `world` is not mutated here; the state machines are not in the world, and the Commands don't
// mutate until application
let par_commands = system_state.get(world);
let task_pool = ComputeTaskPool::get();
// let par_commands = system_state.get(world);
// let task_pool = ComputeTaskPool::get();

// chunk size of None means to automatically pick
borrowed_machines.par_splat_map_mut(task_pool, None, |_, chunk| {
for (entity, machine) in chunk {
par_commands.command_scope(|mut commands| machine.run(world, *entity, &mut commands));
}
});
for &mut (entity, ref mut machine) in &mut borrowed_machines {
machine.run(world, entity);
}

// put the borrowed machines back
for (entity, machine) in borrowed_machines {
*machine_query.get_mut(world, entity).unwrap().1 = machine;
}

// necessary to actually *apply* the commands we've enqueued
system_state.apply(world);
// system_state.apply(world);
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit 5957ce8

Please sign in to comment.