Skip to content

Commit

Permalink
handler proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
gergogera committed Feb 15, 2023
1 parent eb67928 commit e4b37b1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 30 deletions.
9 changes: 7 additions & 2 deletions diam/sm/cer.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ func errorCEA(sm *StateMachine, c diam.Conn, m *diam.Message, cer *smparser.CER,
a.NewAVP(avp.FirmwareRevision, 0, 0, sm.cfg.FirmwareRevision)
}

err = sm.cfg.CeaWriterProxy(c, a)
sm.cfg.AnswerHandlerProxy(func(conn diam.Conn, answer *diam.Message) {
_, err = answer.WriteTo(conn)
})(c, a)
if err != nil {
err = fmt.Errorf("Error CEA '%s' send failure: %v", errMessage, err)
}
Expand Down Expand Up @@ -166,5 +168,8 @@ func successCEA(sm *StateMachine, c diam.Conn, m *diam.Message, cer *smparser.CE
a.NewAVP(avp.FirmwareRevision, 0, 0, sm.cfg.FirmwareRevision)
}

return sm.cfg.CeaWriterProxy(c, a)
sm.cfg.AnswerHandlerProxy(func(conn diam.Conn, answer *diam.Message) {
_, err = answer.WriteTo(conn)
})(c, a)
return err
}
25 changes: 13 additions & 12 deletions diam/sm/cer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,19 @@ func testHandleCER_HandshakeMetadata(t *testing.T, network string) {
func TestHandleCER_CerHook(t *testing.T) {
// Force clone hence the dereference
modServerSettings := *serverSettings
var handlerCalled = false
var writerCalled = false
modServerSettings.CerHandlerProxy = func(f diam.HandlerFunc) diam.HandlerFunc {
var requestProxyCalled = false
var writerProxyCalled = false
modServerSettings.RequestHandlerProxy = func(f diam.HandlerFunc) diam.HandlerFunc {
return func(c diam.Conn, m *diam.Message) {
f(c, m)
handlerCalled = true
requestProxyCalled = true
}
}
modServerSettings.CeaWriterProxy = func(c diam.Conn, m *diam.Message) error {
_, err := m.WriteTo(c)
writerCalled = true
return err
modServerSettings.AnswerHandlerProxy = func(f diam.HandlerFunc) diam.HandlerFunc {
return func(c diam.Conn, m *diam.Message) {
f(c, m)
writerProxyCalled = true
}
}

sm := New(&modServerSettings)
Expand Down Expand Up @@ -135,11 +136,11 @@ func TestHandleCER_CerHook(t *testing.T) {
t.Fatalf("Unexpected OriginRealm. Want %q, have %q",
clientSettings.OriginRealm, meta.OriginRealm)
}
if !handlerCalled {
t.Fatalf("message handler proxy not called")
if !requestProxyCalled {
t.Fatalf("request handler proxy not called")
}
if !writerCalled {
t.Fatalf("message writer proxy not called")
if !writerProxyCalled {
t.Fatalf("answer handler proxy not called")
}
}

Expand Down
26 changes: 10 additions & 16 deletions diam/sm/sm.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,13 @@ func PrepareSupportedApps(d *dict.Parser) []*SupportedApp {
}

type MessageHandlerProxy = func(diam.HandlerFunc) diam.HandlerFunc
type MessageWriterProxy = func(diam.Conn, *diam.Message) error

var defaultProxy = func(f diam.HandlerFunc) diam.HandlerFunc {
var defaultHandlerProxy = func(f diam.HandlerFunc) diam.HandlerFunc {
return func(c diam.Conn, m *diam.Message) {
f(c, m)
}
}

var defaultWriterProxy = func(c diam.Conn, m *diam.Message) error {
_, err := m.WriteTo(c)
return err
}

// Settings used to configure the state machine with AVPs to be added
// to CER on clients or CEA on servers.
type Settings struct {
Expand Down Expand Up @@ -81,9 +75,9 @@ type Settings struct {
HostIPAddresses []datatype.Address
//
// Deprecated: HostIPAddress is depreciated, use HostIPAddresses instead
HostIPAddress datatype.Address
CerHandlerProxy MessageHandlerProxy
CeaWriterProxy MessageWriterProxy
HostIPAddress datatype.Address
RequestHandlerProxy MessageHandlerProxy
AnswerHandlerProxy MessageHandlerProxy
}

var (
Expand All @@ -109,11 +103,11 @@ func New(settings *Settings) *StateMachine {
if len(settings.HostIPAddresses) == 0 && len(settings.HostIPAddress) > 0 {
settings.HostIPAddresses = []datatype.Address{settings.HostIPAddress}
}
if settings.CerHandlerProxy == nil {
settings.CerHandlerProxy = defaultProxy
if settings.RequestHandlerProxy == nil {
settings.RequestHandlerProxy = defaultHandlerProxy
}
if settings.CeaWriterProxy == nil {
settings.CeaWriterProxy = defaultWriterProxy
if settings.AnswerHandlerProxy == nil {
settings.AnswerHandlerProxy = defaultHandlerProxy
}

sm := &StateMachine{
Expand All @@ -122,9 +116,9 @@ func New(settings *Settings) *StateMachine {
hsNotifyc: make(chan diam.Conn),
supportedApps: PrepareSupportedApps(dict.Default),
}
sm.mux.Handle("CER", settings.CerHandlerProxy(handleCER(sm)))
sm.mux.Handle("CER", settings.RequestHandlerProxy(handleCER(sm)))
sm.mux.Handle("DWR", handshakeOK(handleDWR(sm)))
sm.mux.HandleIdx(baseCERIdx, settings.CerHandlerProxy(handleCER(sm)))
sm.mux.HandleIdx(baseCERIdx, settings.RequestHandlerProxy(handleCER(sm)))
sm.mux.HandleIdx(baseDWRIdx, handleDWR(sm))
return sm
}
Expand Down

0 comments on commit e4b37b1

Please sign in to comment.