diff --git a/internal/wrphandlers/auth/handler.go b/internal/wrphandlers/auth/handler.go new file mode 100644 index 0000000..331ed0d --- /dev/null +++ b/internal/wrphandlers/auth/handler.go @@ -0,0 +1,99 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "errors" + "fmt" + "strings" + + "github.com/xmidt-org/wrp-go/v3" + "github.com/xmidt-org/xmidt-agent/internal/wrpkit" +) + +var ( + ErrInvalidInput = fmt.Errorf("invalid input") + ErrUnauthorized = fmt.Errorf("unauthorized") +) + +const ( + // statusCode is the status code to return when a message is not authorized. + statusCode = 403 + + // wildcard is the wildcard partner id that matches all partner ids. + wildcard = "*" +) + +// Handler sends a response when a message is required to have a response, but +// was not handled by the next handler in the chain. +type Handler struct { + next wrpkit.Handler + egress wrpkit.Handler + source string + partners []string +} + +// New creates a new instance of the Handler struct. The parameter next is the +// handler that will be called and monitored for errors. The parameter egress is +// the handler that will be called to send the response if/when the next handler +// fails to handle the message. The parameter source is the source to use in +// the response message. The list of partners is the list of allowed partners. +func New(next, egress wrpkit.Handler, source string, partners ...string) (*Handler, error) { + h := Handler{ + next: next, + egress: egress, + source: source, + partners: make([]string, 0, len(partners)), + } + + for _, partner := range partners { + partner = strings.TrimSpace(partner) + if partner != "" { + h.partners = append(h.partners, partner) + } + } + + if h.next == nil || h.egress == nil || h.source == "" || len(h.partners) == 0 { + return nil, ErrInvalidInput + } + + return &h, nil +} + +// HandleWrp is called to process a message. If the message is not from an allowed +// partner, a response is sent to the source of the message if applicable. +func (h Handler) HandleWrp(msg wrp.Message) error { + for _, allowed := range h.partners { + for _, got := range msg.PartnerIDs { + got = strings.TrimSpace(got) + if allowed == got || allowed == wildcard { + // We found a match, so continue processing the message. + return h.next.HandleWrp(msg) + } + } + } + + // At this point, the message is not from an allowed partner, so send a + // response if needed. Otherwise, return an error. + + if !msg.Type.RequiresTransaction() { + return ErrUnauthorized + } + + got := strings.Join(msg.PartnerIDs, "','") + want := strings.Join(h.partners, "','") + + response := msg + response.Destination = msg.Source + response.Source = h.source + response.ContentType = "text/plain" + response.Payload = []byte(fmt.Sprintf("Partner(s) '%s' not allowed. Allowed: '%s'", got, want)) + + code := int64(statusCode) + response.Status = &code + + sendErr := h.egress.HandleWrp(response) + + return errors.Join(ErrUnauthorized, sendErr) +} diff --git a/internal/wrphandlers/auth/handler_test.go b/internal/wrphandlers/auth/handler_test.go new file mode 100644 index 0000000..1e9830d --- /dev/null +++ b/internal/wrphandlers/auth/handler_test.go @@ -0,0 +1,118 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package auth_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/xmidt-org/wrp-go/v3" + "github.com/xmidt-org/xmidt-agent/internal/wrphandlers/auth" + "github.com/xmidt-org/xmidt-agent/internal/wrpkit" +) + +func TestHandler_HandleWrp(t *testing.T) { + //randomErr := errors.New("random error") + + tests := []struct { + description string + nextResult error + nextCallCount int + egressResult error + egressCallCount int + partner string + partners []string + msg wrp.Message + expectedErr error + validate func(wrp.Message) error + }{ + { + description: "normal message, good auth", + nextCallCount: 1, + msg: wrp.Message{ + Type: wrp.SimpleEventMessageType, + Source: "dns:tr1d1um.example.com/service/ignored", + Destination: "event:event_1/ignored", + PartnerIDs: []string{"example-partner"}, + }, + partner: "example-partner", + }, { + description: "normal message, wildcard auth", + nextCallCount: 1, + msg: wrp.Message{ + Type: wrp.SimpleEventMessageType, + Source: "dns:tr1d1um.example.com/service/ignored", + Destination: "event:event_1/ignored", + PartnerIDs: []string{"example-partner"}, + }, + partner: "*", + }, { + description: "partner not allowed, no response needed", + msg: wrp.Message{ + Type: wrp.SimpleEventMessageType, + Source: "dns:tr1d1um.example.com/service/ignored", + Destination: "event:event_1/ignored", + PartnerIDs: []string{"example-partner"}, + }, + partner: "some-other-partner", + expectedErr: auth.ErrUnauthorized, + }, { + egressCallCount: 1, + description: "partner not allowed, response needed", + msg: wrp.Message{ + Type: wrp.SimpleRequestResponseMessageType, + Source: "dns:tr1d1um.example.com/service/ignored", + Destination: "mac:112233445566/service", + TransactionUUID: "1234", + PartnerIDs: []string{"example-partner"}, + }, + partner: "some-other-partner", + expectedErr: auth.ErrUnauthorized, + }, { + egressCallCount: 1, + description: "no partner provided, response needed", + msg: wrp.Message{ + Type: wrp.SimpleRequestResponseMessageType, + Source: "dns:tr1d1um.example.com/service/ignored", + Destination: "mac:112233445566/service", + TransactionUUID: "1234", + }, + partner: "some-partner", + expectedErr: auth.ErrUnauthorized, + }, + } + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + nextCallCount := 0 + next := wrpkit.HandlerFunc(func(wrp.Message) error { + nextCallCount++ + return tc.nextResult + }) + + egressCallCount := 0 + egress := wrpkit.HandlerFunc(func(wrp.Message) error { + egressCallCount++ + if tc.validate != nil { + assert.NoError(tc.validate(tc.msg)) + } + return tc.egressResult + }) + + partners := append(tc.partners, tc.partner) + + h, err := auth.New(next, egress, "self:/xmidt-agent/missing", partners...) + require.NoError(err) + + err = h.HandleWrp(tc.msg) + assert.ErrorIs(err, tc.expectedErr) + + assert.Equal(tc.nextCallCount, nextCallCount) + assert.Equal(tc.egressCallCount, egressCallCount) + }) + } +} diff --git a/internal/wrphandlers/missing/handler.go b/internal/wrphandlers/missing/handler.go new file mode 100644 index 0000000..48d6d9e --- /dev/null +++ b/internal/wrphandlers/missing/handler.go @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package missing + +import ( + "errors" + "fmt" + + "github.com/xmidt-org/wrp-go/v3" + "github.com/xmidt-org/xmidt-agent/internal/wrpkit" +) + +var ( + ErrInvalidInput = fmt.Errorf("invalid input") +) + +const ( + // statusCode is the status code to return when a message is missing a handler. + statusCode = 531 +) + +// Handler sends a response when a message is required to have a response, but +// was not handled by the next handler in the chain. +type Handler struct { + next wrpkit.Handler + egress wrpkit.Handler + source string +} + +// New creates a new instance of the Handler struct. The parameter next is the +// handler that will be called and monitored for errors. The parameter egress is +// the handler that will be called to send the response if/when the next handler +// fails to handle the message. The parameter source is the source to use in +// the response message. +func New(next, egress wrpkit.Handler, source string) (*Handler, error) { + if next == nil || egress == nil || source == "" { + return nil, ErrInvalidInput + } + + return &Handler{ + next: next, + egress: egress, + source: source, + }, nil +} + +// HandleWrp is called to process a message. If the next handler fails to +// process the message, a response is sent to the source of the message. +func (h Handler) HandleWrp(msg wrp.Message) error { + err := h.next.HandleWrp(msg) + if err == nil { + return nil + } + + if !msg.Type.RequiresTransaction() { + return err + } + + // If the error is not ErrNotHandled, return the error. + if !errors.Is(err, wrpkit.ErrNotHandled) { + return err + } + + // Consume the error since we are handling it here. + err = nil + + // At this point, we know that a response is required, but the next handler + // failed to process the message, or didn't have a handler for it. + response := msg + response.Destination = msg.Source + response.Source = h.source + response.Payload = nil + + code := int64(statusCode) + response.Status = &code + + sendErr := h.egress.HandleWrp(response) + + return errors.Join(err, sendErr) +} diff --git a/internal/wrphandlers/missing/handler_test.go b/internal/wrphandlers/missing/handler_test.go new file mode 100644 index 0000000..5fb2a78 --- /dev/null +++ b/internal/wrphandlers/missing/handler_test.go @@ -0,0 +1,122 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package missing_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/xmidt-org/wrp-go/v3" + "github.com/xmidt-org/xmidt-agent/internal/wrphandlers/missing" + "github.com/xmidt-org/xmidt-agent/internal/wrpkit" +) + +func TestHandler_HandleWrp(t *testing.T) { + randomErr := errors.New("random error") + + tests := []struct { + description string + nextResult error + nextCallCount int + egressResult error + egressCallCount int + msg wrp.Message + expectedErr error + validate func(wrp.Message) error + }{ + { + description: "normal message", + nextCallCount: 1, + msg: wrp.Message{ + Type: wrp.SimpleEventMessageType, + Source: "dns:tr1d1um.example.com/service/ignored", + Destination: "event:event_1/ignored", + }, + }, { + description: "error with an msg that doesn't require a response (random error)", + nextCallCount: 1, + nextResult: randomErr, + msg: wrp.Message{ + Type: wrp.SimpleEventMessageType, + Source: "dns:tr1d1um.example.com/service/ignored", + Destination: "event:event_1/ignored", + }, + expectedErr: randomErr, + }, { + description: "error with an msg that doesn't require a response (no handler)", + nextCallCount: 1, + nextResult: wrpkit.ErrNotHandled, + msg: wrp.Message{ + Type: wrp.SimpleEventMessageType, + Source: "dns:tr1d1um.example.com/service/ignored", + Destination: "event:event_1/ignored", + }, + expectedErr: wrpkit.ErrNotHandled, + }, { + description: "error with an msg that requires a response, but was handled", + nextCallCount: 1, + nextResult: randomErr, + msg: wrp.Message{ + Type: wrp.SimpleRequestResponseMessageType, + Source: "dns:tr1d1um.example.com/service/ignored", + Destination: "mac:112233445566/some-service", + }, + expectedErr: randomErr, + }, { + description: "unhandled message, but requires a response", + nextCallCount: 1, + nextResult: wrpkit.ErrNotHandled, + egressCallCount: 1, + msg: wrp.Message{ + Type: wrp.SimpleRequestResponseMessageType, + Source: "dns:tr1d1um.example.com/service/ignored", + Destination: "mac:112233445566/some-service", + }, + }, { + description: "unhandled message, requires a response, error sending response", + nextCallCount: 1, + nextResult: wrpkit.ErrNotHandled, + egressCallCount: 1, + egressResult: randomErr, + msg: wrp.Message{ + Type: wrp.SimpleRequestResponseMessageType, + Source: "dns:tr1d1um.example.com/service/ignored", + Destination: "mac:112233445566/some-service", + }, + expectedErr: randomErr, + }, + } + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + nextCallCount := 0 + next := wrpkit.HandlerFunc(func(wrp.Message) error { + nextCallCount++ + return tc.nextResult + }) + + egressCallCount := 0 + egress := wrpkit.HandlerFunc(func(wrp.Message) error { + egressCallCount++ + if tc.validate != nil { + assert.NoError(tc.validate(tc.msg)) + } + return tc.egressResult + }) + + h, err := missing.New(next, egress, "self:/xmidt-agent/missing") + require.NoError(err) + + err = h.HandleWrp(tc.msg) + assert.ErrorIs(err, tc.expectedErr) + + assert.Equal(tc.nextCallCount, nextCallCount) + assert.Equal(tc.egressCallCount, egressCallCount) + }) + } +}