diff --git a/diam/sm/dwr_test.go b/diam/sm/dwr_test.go index 1926077..e58659d 100644 --- a/diam/sm/dwr_test.go +++ b/diam/sm/dwr_test.go @@ -80,6 +80,102 @@ func TestHandleDWR(t *testing.T) { } } +func TestHandleDWRWithMessageHandlers(t *testing.T) { + // Force clone hence the dereference + modServerSettings := *serverSettings + var requestProxyCalledForCER = false + var requestProxyCalledForCEA = false + var requestProxyCalledForDWR = false + modServerSettings.RequestHandlerProxy = func(f HandlerFuncProxy) HandlerFuncProxy { + return func(c diam.Conn, m *diam.Message) error { + if m.Header.CommandCode == diam.CapabilitiesExchange { + requestProxyCalledForCER = true + } + if m.Header.CommandCode == diam.DeviceWatchdog { + requestProxyCalledForDWR = true + } + return f(c, m) + } + } + modServerSettings.AnswerHandlerProxy = func(f HandlerFuncProxy) HandlerFuncProxy { + return func(c diam.Conn, m *diam.Message) error { + if m.Header.CommandCode == diam.CapabilitiesExchange { + requestProxyCalledForCEA = true + } + return f(c, m) + } + } + sm := New(&modServerSettings) + srv := diamtest.NewServer(sm, dict.Default) + defer srv.Close() + mc := make(chan *diam.Message, 1) + mux := diam.NewServeMux() + mux.HandleFunc("CEA", func(c diam.Conn, m *diam.Message) { + mc <- m + }) + mux.HandleFunc("DWA", func(c diam.Conn, m *diam.Message) { + mc <- m + }) + cli, err := diam.Dial(srv.Addr, mux, dict.Default) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + // Send CER first. + m := diam.NewRequest(diam.CapabilitiesExchange, 1001, dict.Default) + m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) + m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) + m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress) + m.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID) + m.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName) + m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1)) + m.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(1001)) + m.NewAVP(avp.FirmwareRevision, 0, 0, clientSettings.FirmwareRevision) + _, err = m.WriteTo(cli) + if err != nil { + t.Fatal(err) + } + select { + case resp := <-mc: + if !testResultCode(resp, diam.Success) { + t.Fatalf("Unexpected result code for CEA.\n%s", resp) + } + if !requestProxyCalledForCER { + t.Fatalf("request handler proxy not called for CER") + } + if !requestProxyCalledForCEA { + t.Fatalf("request handler proxy not called for CEA") + } + case err := <-mux.ErrorReports(): + t.Fatal(err) + case <-time.After(time.Second): + t.Fatal("No CEA received") + } + + // Send DWR. + m = diam.NewRequest(diam.DeviceWatchdog, 0, dict.Default) + m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) + m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) + m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1)) + _, err = m.WriteTo(cli) + if err != nil { + t.Fatal(err) + } + select { + case resp := <-mc: + if !testResultCode(resp, diam.Success) { + t.Fatalf("Unexpected result code for DWA.\n%s", resp) + } + if !requestProxyCalledForDWR { + t.Fatalf("answer handler proxy not called for DWR") + } + case err := <-mux.ErrorReports(): + t.Fatal(err) + case <-time.After(time.Second): + t.Fatal("No DWA received") + } +} + func TestHandleDWR_Fail(t *testing.T) { sm := New(serverSettings) srv := diamtest.NewServer(sm, dict.Default) diff --git a/diam/sm/sm.go b/diam/sm/sm.go index a2c8598..2c37640 100644 --- a/diam/sm/sm.go +++ b/diam/sm/sm.go @@ -122,10 +122,16 @@ func New(settings *Settings) *StateMachine { return nil })(c, m) } + var dwrHandlerProxy diam.HandlerFunc = func(c diam.Conn, m *diam.Message) { + _ = settings.RequestHandlerProxy(func(c diam.Conn, m *diam.Message) error { + handleDWR(sm)(c, m) + return nil + })(c, m) + } sm.mux.Handle("CER", cerHandlerProxy) - sm.mux.Handle("DWR", handshakeOK(handleDWR(sm))) + sm.mux.Handle("DWR", handshakeOK(dwrHandlerProxy)) sm.mux.HandleIdx(baseCERIdx, cerHandlerProxy) - sm.mux.HandleIdx(baseDWRIdx, handleDWR(sm)) + sm.mux.HandleIdx(baseDWRIdx, dwrHandlerProxy) return sm }