diff --git a/go.mod b/go.mod index 737a878..52f88d3 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/daegalus/transition -go 1.19 +go 1.20 diff --git a/transition.go b/transition.go index 147e32b..550925c 100644 --- a/transition.go +++ b/transition.go @@ -5,28 +5,28 @@ import ( ) // Transition is a struct, embed it in your struct to enable state machine for the struct -type Transition[T any] struct { +type Transition struct { State string } // SetState set state to Stater, just set, won't save it into database -func (transition *Transition[T]) SetState(name string) { +func (transition *Transition) SetState(name string) { transition.State = name } // GetState get current state from -func (transition Transition[T]) GetState() string { +func (transition Transition) GetState() string { return transition.State } // Stater is a interface including methods `GetState`, `SetState` -type Stater[T any] interface { +type Stater interface { SetState(name string) GetState() string } // New initialize a new StateMachine that hold states, events definitions -func New[T any](_ T) *StateMachine[T] { +func New[T Stater](_ T) *StateMachine[T] { return &StateMachine[T]{ states: map[string]*State[T]{}, events: map[string]*Event[T]{}, @@ -34,7 +34,7 @@ func New[T any](_ T) *StateMachine[T] { } // StateMachine a struct that hold states, events definitions -type StateMachine[T any] struct { +type StateMachine[T Stater] struct { initialState string states map[string]*State[T] events map[string]*Event[T] @@ -67,7 +67,7 @@ func (sm *StateMachine[T]) Event(name string) *Event[T] { } // Trigger trigger an event -func (sm *StateMachine[T]) Trigger(name string, value Stater[T]) error { +func (sm *StateMachine[T]) Trigger(name string, value T) error { stateWas := value.GetState() if stateWas == "" { @@ -138,26 +138,26 @@ func (sm *StateMachine[T]) Trigger(name string, value Stater[T]) error { } // State contains State information, including enter, exit hooks -type State[T any] struct { +type State[T Stater] struct { Name string - enters []func(value Stater[T]) error - exits []func(value Stater[T]) error + enters []func(value T) error + exits []func(value T) error } // Enter register an enter hook for State -func (state *State[T]) Enter(fc func(value Stater[T]) error) *State[T] { +func (state *State[T]) Enter(fc func(value T) error) *State[T] { state.enters = append(state.enters, fc) return state } // Exit register an exit hook for State -func (state *State[T]) Exit(fc func(value Stater[T]) error) *State[T] { +func (state *State[T]) Exit(fc func(value T) error) *State[T] { state.exits = append(state.exits, fc) return state } // Event contains Event information, including transition hooks -type Event[T any] struct { +type Event[T Stater] struct { Name string transitions map[string]*EventTransition[T] } @@ -177,11 +177,11 @@ func (event *Event[T]) To(name string) *EventTransition[T] { } // EventTransition hold event's to/froms states, also including befores, afters hooks -type EventTransition[T any] struct { +type EventTransition[T Stater] struct { to string froms []string - befores []func(value Stater[T]) error - afters []func(value Stater[T]) error + befores []func(value T) error + afters []func(value T) error } // From used to define from states @@ -192,13 +192,13 @@ func (transition *EventTransition[T]) From(states ...string) *EventTransition[T] } // Before register before hooks -func (transition *EventTransition[T]) Before(fc func(value Stater[T]) error) *EventTransition[T] { +func (transition *EventTransition[T]) Before(fc func(value T) error) *EventTransition[T] { transition.befores = append(transition.befores, fc) return transition } // After register after hooks -func (transition *EventTransition[T]) After(fc func(value Stater[T]) error) *EventTransition[T] { +func (transition *EventTransition[T]) After(fc func(value T) error) *EventTransition[T] { transition.afters = append(transition.afters, fc) return transition } diff --git a/transition_test.go b/transition_test.go index c0e7273..078b62e 100644 --- a/transition_test.go +++ b/transition_test.go @@ -9,7 +9,7 @@ type Order struct { Id int Address string - Transition[Order] + Transition } func getStateMachine() *StateMachine[*Order] { @@ -90,11 +90,11 @@ func TestStateCallbacks(t *testing.T) { address1 := "I'm an address should be set when enter checkout" address2 := "I'm an address should be set when exit checkout" - orderStateMachine.State("checkout").Enter(func(order Stater[*Order]) error { - order.(*Order).Address = address1 + orderStateMachine.State("checkout").Enter(func(order *Order) error { + order.Address = address1 return nil - }).Exit(func(order Stater[*Order]) error { - order.(*Order).Address = address2 + }).Exit(func(order *Order) error { + order.Address = address2 return nil }) @@ -122,11 +122,11 @@ func TestEventCallbacks(t *testing.T) { prevState, afterState string ) - orderStateMachine.Event("checkout").To("checkout").From("draft").Before(func(order Stater[*Order]) error { - prevState = order.(*Order).State + orderStateMachine.Event("checkout").To("checkout").From("draft").Before(func(order *Order) error { + prevState = order.State return nil - }).After(func(order Stater[*Order]) error { - afterState = order.(*Order).State + }).After(func(order *Order) error { + afterState = order.State return nil }) @@ -150,7 +150,7 @@ func TestTransitionOnEnterCallbackError(t *testing.T) { orderStateMachine = getStateMachine() ) - orderStateMachine.State("checkout").Enter(func(order Stater[*Order]) (err error) { + orderStateMachine.State("checkout").Enter(func(order *Order) (err error) { return errors.New("intentional error") }) @@ -169,7 +169,7 @@ func TestTransitionOnExitCallbackError(t *testing.T) { orderStateMachine = getStateMachine() ) - orderStateMachine.State("checkout").Exit(func(order Stater[*Order]) (err error) { + orderStateMachine.State("checkout").Exit(func(order *Order) (err error) { return errors.New("intentional error") }) @@ -192,7 +192,7 @@ func TestEventOnBeforeCallbackError(t *testing.T) { orderStateMachine = getStateMachine() ) - orderStateMachine.Event("checkout").To("checkout").From("draft").Before(func(order Stater[*Order]) error { + orderStateMachine.Event("checkout").To("checkout").From("draft").Before(func(order *Order) error { return errors.New("intentional error") }) @@ -211,7 +211,7 @@ func TestEventOnAfterCallbackError(t *testing.T) { orderStateMachine = getStateMachine() ) - orderStateMachine.Event("checkout").To("checkout").From("draft").After(func(order Stater[*Order]) error { + orderStateMachine.Event("checkout").To("checkout").From("draft").After(func(order *Order) error { return errors.New("intentional error") })