-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a couple of the required handlers.
- Loading branch information
Showing
4 changed files
with
420 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package auth | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"strings" | ||
|
||
"github.com/xmidt-org/wrp-go/v3" | ||
"github.com/xmidt-org/xmidt-agent/internal/wrpkit" | ||
) | ||
|
||
var ( | ||
ErrInvalidInput = fmt.Errorf("invalid input") | ||
ErrUnauthorized = fmt.Errorf("unauthorized") | ||
) | ||
|
||
const ( | ||
// statusCode is the status code to return when a message is not authorized. | ||
statusCode = 403 | ||
|
||
// wildcard is the wildcard partner id that matches all partner ids. | ||
wildcard = "*" | ||
) | ||
|
||
// Handler sends a response when a message is required to have a response, but | ||
// was not handled by the next handler in the chain. | ||
type Handler struct { | ||
next wrpkit.Handler | ||
egress wrpkit.Handler | ||
source string | ||
partners []string | ||
} | ||
|
||
// New creates a new instance of the Handler struct. The parameter next is the | ||
// handler that will be called and monitored for errors. The parameter egress is | ||
// the handler that will be called to send the response if/when the next handler | ||
// fails to handle the message. The parameter source is the source to use in | ||
// the response message. The list of partners is the list of allowed partners. | ||
func New(next, egress wrpkit.Handler, source string, partners ...string) (*Handler, error) { | ||
h := Handler{ | ||
next: next, | ||
egress: egress, | ||
source: source, | ||
partners: make([]string, 0, len(partners)), | ||
} | ||
|
||
for _, partner := range partners { | ||
partner = strings.TrimSpace(partner) | ||
if partner != "" { | ||
h.partners = append(h.partners, partner) | ||
} | ||
} | ||
|
||
if h.next == nil || h.egress == nil || h.source == "" || len(h.partners) == 0 { | ||
return nil, ErrInvalidInput | ||
} | ||
|
||
return &h, nil | ||
} | ||
|
||
// HandleWrp is called to process a message. If the message is not from an allowed | ||
// partner, a response is sent to the source of the message if applicable. | ||
func (h Handler) HandleWrp(msg wrp.Message) error { | ||
for _, allowed := range h.partners { | ||
for _, got := range msg.PartnerIDs { | ||
got = strings.TrimSpace(got) | ||
if allowed == got || allowed == wildcard { | ||
// We found a match, so continue processing the message. | ||
return h.next.HandleWrp(msg) | ||
} | ||
} | ||
} | ||
|
||
// At this point, the message is not from an allowed partner, so send a | ||
// response if needed. Otherwise, return an error. | ||
|
||
if !msg.Type.RequiresTransaction() { | ||
return ErrUnauthorized | ||
} | ||
|
||
got := strings.Join(msg.PartnerIDs, "','") | ||
want := strings.Join(h.partners, "','") | ||
|
||
response := msg | ||
response.Destination = msg.Source | ||
response.Source = h.source | ||
response.ContentType = "text/plain" | ||
response.Payload = []byte(fmt.Sprintf("Partner(s) '%s' not allowed. Allowed: '%s'", got, want)) | ||
|
||
code := int64(statusCode) | ||
response.Status = &code | ||
|
||
sendErr := h.egress.HandleWrp(response) | ||
|
||
return errors.Join(ErrUnauthorized, sendErr) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package auth_test | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
"github.com/xmidt-org/wrp-go/v3" | ||
"github.com/xmidt-org/xmidt-agent/internal/wrphandlers/auth" | ||
"github.com/xmidt-org/xmidt-agent/internal/wrpkit" | ||
) | ||
|
||
func TestHandler_HandleWrp(t *testing.T) { | ||
//randomErr := errors.New("random error") | ||
|
||
tests := []struct { | ||
description string | ||
nextResult error | ||
nextCallCount int | ||
egressResult error | ||
egressCallCount int | ||
partner string | ||
partners []string | ||
msg wrp.Message | ||
expectedErr error | ||
validate func(wrp.Message) error | ||
}{ | ||
{ | ||
description: "normal message, good auth", | ||
nextCallCount: 1, | ||
msg: wrp.Message{ | ||
Type: wrp.SimpleEventMessageType, | ||
Source: "dns:tr1d1um.example.com/service/ignored", | ||
Destination: "event:event_1/ignored", | ||
PartnerIDs: []string{"example-partner"}, | ||
}, | ||
partner: "example-partner", | ||
}, { | ||
description: "normal message, wildcard auth", | ||
nextCallCount: 1, | ||
msg: wrp.Message{ | ||
Type: wrp.SimpleEventMessageType, | ||
Source: "dns:tr1d1um.example.com/service/ignored", | ||
Destination: "event:event_1/ignored", | ||
PartnerIDs: []string{"example-partner"}, | ||
}, | ||
partner: "*", | ||
}, { | ||
description: "partner not allowed, no response needed", | ||
msg: wrp.Message{ | ||
Type: wrp.SimpleEventMessageType, | ||
Source: "dns:tr1d1um.example.com/service/ignored", | ||
Destination: "event:event_1/ignored", | ||
PartnerIDs: []string{"example-partner"}, | ||
}, | ||
partner: "some-other-partner", | ||
expectedErr: auth.ErrUnauthorized, | ||
}, { | ||
egressCallCount: 1, | ||
description: "partner not allowed, response needed", | ||
msg: wrp.Message{ | ||
Type: wrp.SimpleRequestResponseMessageType, | ||
Source: "dns:tr1d1um.example.com/service/ignored", | ||
Destination: "mac:112233445566/service", | ||
TransactionUUID: "1234", | ||
PartnerIDs: []string{"example-partner"}, | ||
}, | ||
partner: "some-other-partner", | ||
expectedErr: auth.ErrUnauthorized, | ||
}, { | ||
egressCallCount: 1, | ||
description: "no partner provided, response needed", | ||
msg: wrp.Message{ | ||
Type: wrp.SimpleRequestResponseMessageType, | ||
Source: "dns:tr1d1um.example.com/service/ignored", | ||
Destination: "mac:112233445566/service", | ||
TransactionUUID: "1234", | ||
}, | ||
partner: "some-partner", | ||
expectedErr: auth.ErrUnauthorized, | ||
}, | ||
} | ||
for _, tc := range tests { | ||
t.Run(tc.description, func(t *testing.T) { | ||
assert := assert.New(t) | ||
require := require.New(t) | ||
|
||
nextCallCount := 0 | ||
next := wrpkit.HandlerFunc(func(wrp.Message) error { | ||
nextCallCount++ | ||
return tc.nextResult | ||
}) | ||
|
||
egressCallCount := 0 | ||
egress := wrpkit.HandlerFunc(func(wrp.Message) error { | ||
egressCallCount++ | ||
if tc.validate != nil { | ||
assert.NoError(tc.validate(tc.msg)) | ||
} | ||
return tc.egressResult | ||
}) | ||
|
||
partners := append(tc.partners, tc.partner) | ||
|
||
h, err := auth.New(next, egress, "self:/xmidt-agent/missing", partners...) | ||
require.NoError(err) | ||
|
||
err = h.HandleWrp(tc.msg) | ||
assert.ErrorIs(err, tc.expectedErr) | ||
|
||
assert.Equal(tc.nextCallCount, nextCallCount) | ||
assert.Equal(tc.egressCallCount, egressCallCount) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package missing | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
|
||
"github.com/xmidt-org/wrp-go/v3" | ||
"github.com/xmidt-org/xmidt-agent/internal/wrpkit" | ||
) | ||
|
||
var ( | ||
ErrInvalidInput = fmt.Errorf("invalid input") | ||
) | ||
|
||
const ( | ||
// statusCode is the status code to return when a message is missing a handler. | ||
statusCode = 531 | ||
) | ||
|
||
// Handler sends a response when a message is required to have a response, but | ||
// was not handled by the next handler in the chain. | ||
type Handler struct { | ||
next wrpkit.Handler | ||
egress wrpkit.Handler | ||
source string | ||
} | ||
|
||
// New creates a new instance of the Handler struct. The parameter next is the | ||
// handler that will be called and monitored for errors. The parameter egress is | ||
// the handler that will be called to send the response if/when the next handler | ||
// fails to handle the message. The parameter source is the source to use in | ||
// the response message. | ||
func New(next, egress wrpkit.Handler, source string) (*Handler, error) { | ||
if next == nil || egress == nil || source == "" { | ||
return nil, ErrInvalidInput | ||
} | ||
|
||
return &Handler{ | ||
next: next, | ||
egress: egress, | ||
source: source, | ||
}, nil | ||
} | ||
|
||
// HandleWrp is called to process a message. If the next handler fails to | ||
// process the message, a response is sent to the source of the message. | ||
func (h Handler) HandleWrp(msg wrp.Message) error { | ||
err := h.next.HandleWrp(msg) | ||
if err == nil { | ||
return nil | ||
} | ||
|
||
if !msg.Type.RequiresTransaction() { | ||
return err | ||
} | ||
|
||
// If the error is not ErrNotHandled, return the error. | ||
if !errors.Is(err, wrpkit.ErrNotHandled) { | ||
return err | ||
} | ||
|
||
// Consume the error since we are handling it here. | ||
err = nil | ||
|
||
// At this point, we know that a response is required, but the next handler | ||
// failed to process the message, or didn't have a handler for it. | ||
response := msg | ||
response.Destination = msg.Source | ||
response.Source = h.source | ||
response.Payload = nil | ||
|
||
code := int64(statusCode) | ||
response.Status = &code | ||
|
||
sendErr := h.egress.HandleWrp(response) | ||
|
||
return errors.Join(err, sendErr) | ||
} |
Oops, something went wrong.