diff --git a/diam/sm/dwr.go b/diam/sm/dwr.go index a435b64..d0b2a41 100644 --- a/diam/sm/dwr.go +++ b/diam/sm/dwr.go @@ -39,13 +39,17 @@ func handleDWR(sm *StateMachine) diam.HandlerFunc { stateid := datatype.Unsigned32(sm.cfg.OriginStateID) m.NewAVP(avp.OriginStateID, avp.Mbit, 0, stateid) } - _, err = a.WriteTo(c) - if err != nil { - sm.Error(&diam.ErrorReport{ - Conn: c, - Message: m, - Error: err, - }) - } + + sm.cfg.AnswerHandlerProxy(func(conn diam.Conn, answer *diam.Message) error { + _, err = answer.WriteTo(conn) + if err != nil { + sm.Error(&diam.ErrorReport{ + Conn: conn, + Message: answer, + Error: err, + }) + } + return err + })(c, a) } } diff --git a/diam/sm/dwr_test.go b/diam/sm/dwr_test.go index e58659d..bfe163e 100644 --- a/diam/sm/dwr_test.go +++ b/diam/sm/dwr_test.go @@ -86,6 +86,7 @@ func TestHandleDWRWithMessageHandlers(t *testing.T) { var requestProxyCalledForCER = false var requestProxyCalledForCEA = false var requestProxyCalledForDWR = false + var requestProxyCalledForDWA = false modServerSettings.RequestHandlerProxy = func(f HandlerFuncProxy) HandlerFuncProxy { return func(c diam.Conn, m *diam.Message) error { if m.Header.CommandCode == diam.CapabilitiesExchange { @@ -102,6 +103,9 @@ func TestHandleDWRWithMessageHandlers(t *testing.T) { if m.Header.CommandCode == diam.CapabilitiesExchange { requestProxyCalledForCEA = true } + if m.Header.CommandCode == diam.DeviceWatchdog { + requestProxyCalledForDWA = true + } return f(c, m) } } @@ -167,7 +171,10 @@ func TestHandleDWRWithMessageHandlers(t *testing.T) { t.Fatalf("Unexpected result code for DWA.\n%s", resp) } if !requestProxyCalledForDWR { - t.Fatalf("answer handler proxy not called for DWR") + t.Fatalf("request handler proxy not called for DWR") + } + if !requestProxyCalledForDWA { + t.Fatalf("answer handler proxy not called for DWA") } case err := <-mux.ErrorReports(): t.Fatal(err)