Skip to content

Commit

Permalink
handler proxy to return error
Browse files Browse the repository at this point in the history
  • Loading branch information
gergogera committed Feb 15, 2023
1 parent e4b37b1 commit 3d07d93
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 26 deletions.
14 changes: 7 additions & 7 deletions diam/sm/cer.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ func errorCEA(sm *StateMachine, c diam.Conn, m *diam.Message, cer *smparser.CER,
a.NewAVP(avp.FirmwareRevision, 0, 0, sm.cfg.FirmwareRevision)
}

sm.cfg.AnswerHandlerProxy(func(conn diam.Conn, answer *diam.Message) {
return sm.cfg.AnswerHandlerProxy(func(conn diam.Conn, answer *diam.Message) error {
_, err = answer.WriteTo(conn)
if err != nil {
err = fmt.Errorf("Error CEA '%s' send failure: %v", errMessage, err)
}
return err
})(c, a)
if err != nil {
err = fmt.Errorf("Error CEA '%s' send failure: %v", errMessage, err)
}
return err
}

// successCEA sends a success answer indicating that the CER was successfully
Expand Down Expand Up @@ -168,8 +168,8 @@ func successCEA(sm *StateMachine, c diam.Conn, m *diam.Message, cer *smparser.CE
a.NewAVP(avp.FirmwareRevision, 0, 0, sm.cfg.FirmwareRevision)
}

sm.cfg.AnswerHandlerProxy(func(conn diam.Conn, answer *diam.Message) {
return sm.cfg.AnswerHandlerProxy(func(conn diam.Conn, answer *diam.Message) error {
_, err = answer.WriteTo(conn)
return err
})(c, a)
return err
}
16 changes: 6 additions & 10 deletions diam/sm/cer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,13 @@ func TestHandleCER_CerHook(t *testing.T) {
modServerSettings := *serverSettings
var requestProxyCalled = false
var writerProxyCalled = false
modServerSettings.RequestHandlerProxy = func(f diam.HandlerFunc) diam.HandlerFunc {
return func(c diam.Conn, m *diam.Message) {
f(c, m)
requestProxyCalled = true
}
modServerSettings.RequestHandlerProxy = func(f HandlerFuncProxy) HandlerFuncProxy {
requestProxyCalled = true
return f
}
modServerSettings.AnswerHandlerProxy = func(f diam.HandlerFunc) diam.HandlerFunc {
return func(c diam.Conn, m *diam.Message) {
f(c, m)
writerProxyCalled = true
}
modServerSettings.AnswerHandlerProxy = func(f HandlerFuncProxy) HandlerFuncProxy {
writerProxyCalled = true
return f
}

sm := New(&modServerSettings)
Expand Down
24 changes: 15 additions & 9 deletions diam/sm/sm.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@ func PrepareSupportedApps(d *dict.Parser) []*SupportedApp {
return locallySupportedApps
}

type MessageHandlerProxy = func(diam.HandlerFunc) diam.HandlerFunc
type HandlerFuncProxy = func(diam.Conn, *diam.Message) error
type MessageHandlerProxy = func(HandlerFuncProxy) HandlerFuncProxy

var defaultHandlerProxy = func(f diam.HandlerFunc) diam.HandlerFunc {
return func(c diam.Conn, m *diam.Message) {
f(c, m)
}
var defaultHandlerFuncProxy = func(f HandlerFuncProxy) HandlerFuncProxy {
return f
}

// Settings used to configure the state machine with AVPs to be added
Expand Down Expand Up @@ -104,10 +103,10 @@ func New(settings *Settings) *StateMachine {
settings.HostIPAddresses = []datatype.Address{settings.HostIPAddress}
}
if settings.RequestHandlerProxy == nil {
settings.RequestHandlerProxy = defaultHandlerProxy
settings.RequestHandlerProxy = defaultHandlerFuncProxy
}
if settings.AnswerHandlerProxy == nil {
settings.AnswerHandlerProxy = defaultHandlerProxy
settings.AnswerHandlerProxy = defaultHandlerFuncProxy
}

sm := &StateMachine{
Expand All @@ -116,9 +115,16 @@ func New(settings *Settings) *StateMachine {
hsNotifyc: make(chan diam.Conn),
supportedApps: PrepareSupportedApps(dict.Default),
}
sm.mux.Handle("CER", settings.RequestHandlerProxy(handleCER(sm)))

var cerHandlerProxy diam.HandlerFunc = func(c diam.Conn, m *diam.Message) {
_ = settings.RequestHandlerProxy(func(c diam.Conn, m *diam.Message) error {
handleCER(sm)(c, m)
return nil
})(c, m)
}
sm.mux.Handle("CER", cerHandlerProxy)
sm.mux.Handle("DWR", handshakeOK(handleDWR(sm)))
sm.mux.HandleIdx(baseCERIdx, settings.RequestHandlerProxy(handleCER(sm)))
sm.mux.HandleIdx(baseCERIdx, cerHandlerProxy)
sm.mux.HandleIdx(baseDWRIdx, handleDWR(sm))
return sm
}
Expand Down

0 comments on commit 3d07d93

Please sign in to comment.