Skip to content

Commit

Permalink
feat(service): add user info to opa evaluation
Browse files Browse the repository at this point in the history
Signed-off-by: Sebastian Becker <[email protected]>
  • Loading branch information
sbckr committed Dec 17, 2024
1 parent f948d37 commit 1604c39
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 53 deletions.
1 change: 1 addition & 0 deletions charts/ephemeral/templates/ephemeral.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ metadata:
data:
config.json: |-
{
"authUserIdField": "{{ .Values.ephemeral.authUserIdField }}",
"retrySleep": "50ms",
"networkEstablishTimeout": "{{ .Values.ephemeral.networkEstablishTimeout }}",
"prime": "{{ .Values.ephemeral.spdz.prime }}",
Expand Down
1 change: 1 addition & 0 deletions charts/ephemeral/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ ephemeral:
cpu:
minScale: 1
programIdentifier: "ephemeral-generic"
authUserIdField: "sub"
opa:
endpoint: "http://opa.default.svc.cluster.local:8081/"
policyPackage: "carbynestack.def"
Expand Down
6 changes: 3 additions & 3 deletions cmd/ephemeral/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ func GetHandlerChain(conf *SPDZEngineConfig, logger *zap.SugaredLogger) (http.Ha
if err != nil {
return nil, err
}
server := NewServer(spdzClient.Compile, spdzClient.Activate, logger, typedConfig)
server := NewServer(conf.AuthUserIdField, spdzClient.Compile, spdzClient.Activate, logger, typedConfig)
activationHandler := http.HandlerFunc(server.ActivationHandler)
// Apply in Order:
// 1) MethodFilter: Check that only POST Requests can go through
// 2) BodyFilter: Check that Request Body is set properly and Sets the CtxConfig to the request
// 2) RequestFilter: Check that Request Body is set properly and Sets the CtxConfig to the request
// 3) CompilationHandler: Compiles the script if ?compile=true
// 4) ActivationHandler: Runs the script
filterChain := server.MethodFilter(server.BodyFilter(server.CompilationHandler(activationHandler)))
filterChain := server.MethodFilter(server.RequestFilter(server.CompilationHandler(activationHandler)))
return filterChain, nil
}

Expand Down
29 changes: 15 additions & 14 deletions pkg/ephemeral/io/feeder.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,17 @@ func (f *AmphoraFeeder) LoadFromSecretStoreAndFeed(act *Activation, feedPort str
data = append(data, osh.Data)
}
t := time.Now()
canExecute, err := f.conf.OpaClient.CanExecute(
map[string]interface{}{
"subject": ctx.Spdz.ProgramIdentifier,
"inputs": inputs,
"time": map[string]interface{}{
"formatted": t.String(),
"nano": t.UnixNano(),
},
"playerCount": ctx.Spdz.PlayerCount,
})
opaInput := map[string]interface{}{
"subject": ctx.Spdz.ProgramIdentifier,
"executor": ctx.AuthorizedUser,
"inputs": inputs,
"time": map[string]interface{}{
"formatted": t.String(),
"nano": t.UnixNano(),
},
"playerCount": ctx.Spdz.PlayerCount,
}
canExecute, err := f.conf.OpaClient.CanExecute(opaInput)
if err != nil {
return nil, fmt.Errorf("failed to check if program can be executed: %w", err)
}
Expand All @@ -95,7 +96,7 @@ func (f *AmphoraFeeder) LoadFromSecretStoreAndFeed(act *Activation, feedPort str
}
// Write to amphora if required and return amphora secret ids.
if act.Output.Type == AmphoraSecret {
ids, err := f.writeToAmphora(act, inputs, *resp)
ids, err := f.writeToAmphora(act, opaInput, *resp)
if err != nil {
return nil, err
}
Expand All @@ -114,7 +115,7 @@ func (f *AmphoraFeeder) LoadFromRequestAndFeed(act *Activation, feedPort string,
}
// Write to amphora if required and return amphora secret ids.
if act.Output.Type == AmphoraSecret {
ids, err := f.writeToAmphora(act, []ActivationInput{}, *resp)
ids, err := f.writeToAmphora(act, map[string]interface{}{}, *resp)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -172,9 +173,9 @@ func (f *AmphoraFeeder) feedAndRead(params []string, feedPort string, ctx *CtxCo
return f.carrier.Read(conv, isBulk)
}

func (f *AmphoraFeeder) writeToAmphora(act *Activation, inputs []ActivationInput, resp Result) ([]string, error) {
func (f *AmphoraFeeder) writeToAmphora(act *Activation, opaInput map[string]interface{}, resp Result) ([]string, error) {
client := f.conf.AmphoraClient
generatedTags, err := f.conf.OpaClient.GenerateTags(map[string]interface{}{"inputs": inputs})
generatedTags, err := f.conf.OpaClient.GenerateTags(opaInput)
if err != nil {
return nil, fmt.Errorf("failed to generate tags for program output: %w", err)
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/ephemeral/player.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2024 - for information on the respective copyright owner
// Copyright (c) 2021-2023 - for information on the respective copyright owner
// see the NOTICE file and/or the repository https://github.com/carbynestack/ephemeral.
//
// SPDX-License-Identifier: Apache-2.0
Expand Down Expand Up @@ -49,7 +49,6 @@ func NewPlayer(ctx context.Context, bus mb.MessageBus, stateTimeout time.Duratio
fsm.WhenIn(Init).GotEvent(Register).GoTo(Registering),
fsm.WhenIn(Registering).GotEvent(PlayersReady).GoTo(Playing).WithTimeout(computationTimeout),
fsm.WhenIn(Playing).GotEvent(PlayerFinishedWithSuccess).GoTo(PlayerFinishedWithSuccess),
fsm.WhenIn(Playing).GotEvent(PlayerFinishedWithError).GoTo(PlayerFinishedWithError),
fsm.WhenIn(Playing).GotEvent(PlayingError).GoTo(PlayerFinishedWithError),
fsm.WhenInAnyState().GotEvent(GameError).GoTo(PlayerFinishedWithError),
fsm.WhenInAnyState().GotEvent(PlayerDone).GoTo(PlayerDone),
Expand Down
94 changes: 72 additions & 22 deletions pkg/ephemeral/server.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2023 - for information on the respective copyright owner
// Copyright (c) 2021-2024 - for information on the respective copyright owner
// see the NOTICE file and/or the repository https://github.com/carbynestack/ephemeral.
//
// SPDX-License-Identifier: Apache-2.0
Expand Down Expand Up @@ -42,28 +42,32 @@ var (
)

// NewServer returns a new server.
func NewServer(compile func(*CtxConfig) error, activate func(*CtxConfig) ([]byte, error), logger *zap.SugaredLogger, config *SPDZEngineTypedConfig) *Server {
func NewServer(authUserIdField string,
compile func(*CtxConfig) error,
activate func(*CtxConfig) ([]byte, error), logger *zap.SugaredLogger, config *SPDZEngineTypedConfig) *Server {
return &Server{
player: &PlayerWithIO{},
compile: compile,
activate: activate,
logger: logger,
config: config,
executor: NewCommander(),
authUserIdField: authUserIdField,
player: &PlayerWithIO{},
compile: compile,
activate: activate,
logger: logger,
config: config,
executor: NewCommander(),
}
}

// Server is a HTTP server which wraps the handling of incoming requests that trigger the MPC computation.
type Server struct {
player AbstractPlayerWithIO
compile func(*CtxConfig) error
activate func(*CtxConfig) ([]byte, error)
logger *zap.SugaredLogger
config *SPDZEngineTypedConfig
respCh chan []byte
errCh chan error
execErrCh chan error
executor Executor
authUserIdField string
player AbstractPlayerWithIO
compile func(*CtxConfig) error
activate func(*CtxConfig) ([]byte, error)
logger *zap.SugaredLogger
config *SPDZEngineTypedConfig
respCh chan []byte
errCh chan error
execErrCh chan error
executor Executor
}

// MethodFilter assures that only HTTP POST requests are able to get through.
Expand All @@ -89,11 +93,19 @@ func (s *Server) MethodFilter(next http.Handler) http.Handler {
})
}

// BodyFilter verifies all necessary parameters are set in the request body.
// RequestFilter verifies all necessary headers and parameters in the request body are set.
// Also sets the CtxConfig to the request
func (s *Server) BodyFilter(next http.Handler) http.Handler {
func (s *Server) RequestFilter(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
var act Activation
authorizedUser, err := GetUserFromAuthHeader(req.Header.Get("Authorization"), s.authUserIdField)
if err != nil {
msg := "unauthorized request"
writer.WriteHeader(http.StatusUnauthorized)
writer.Write([]byte(msg))
s.logger.Errorw(msg, "Error", err)
return
}
if req.Body == nil {
msg := "request body is nil"
writer.WriteHeader(http.StatusBadRequest)
Expand All @@ -104,7 +116,7 @@ func (s *Server) BodyFilter(next http.Handler) http.Handler {
bodyBytes, _ := ioutil.ReadAll(req.Body)
req.Body.Close()
req.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
err := json.Unmarshal(bodyBytes, &act)
err = json.Unmarshal(bodyBytes, &act)
if err != nil {
msg := "error decoding the request body"
writer.WriteHeader(http.StatusBadRequest)
Expand Down Expand Up @@ -147,8 +159,9 @@ func (s *Server) BodyFilter(next http.Handler) http.Handler {
}
con := context.Background()
ctx := &CtxConfig{
Act: &act,
Spdz: s.config,
AuthorizedUser: authorizedUser,
Act: &act,
Spdz: s.config,
}
con = context.WithValue(con, ctxConf, ctx)
r := req.Clone(con)
Expand All @@ -157,6 +170,43 @@ func (s *Server) BodyFilter(next http.Handler) http.Handler {
})
}

func GetUserFromAuthHeader(header string, idField string) (string, error) {
token := strings.TrimPrefix(header, "Bearer ")
if token == "" {
return "", fmt.Errorf("no token provided")
}
return GetUserIDFromToken(token, idField)
}

func GetUserIDFromToken(token string, field string) (string, error) {
jwtParts := strings.Split(token, ".")
if len(jwtParts) != 3 {
return "", fmt.Errorf("invalid JWT format")
}
jwt, err := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(jwtParts[1])
if err != nil {
return "", fmt.Errorf("error decoding JWT claims: %w", err)
}
var claimsMap map[string]interface{}
err = json.Unmarshal(jwt, &claimsMap)
if err != nil {
return "", fmt.Errorf("error unmarshalling JWT claims: %w", err)
}
var ok bool
path := strings.Split(field, ".")
for _, part := range path[:len(path)-1] {
claimsMap, ok = claimsMap[part].(map[string]interface{})
if !ok {
return "", fmt.Errorf("field %s not found in JWT claims or invalid", part)
}
}
id, ok := claimsMap[path[len(path)-1]].(string)
if !ok {
return "", fmt.Errorf("field %s is not a string", field)
}
return id, nil
}

// CompilationHandler parses the JSON payload and adds it to the request context.
func (s *Server) CompilationHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
Expand Down
61 changes: 54 additions & 7 deletions pkg/ephemeral/server_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2023 - for information on the respective copyright owner
// Copyright (c) 2021-2024 - for information on the respective copyright owner
// see the NOTICE file and/or the repository https://github.com/carbynestack/ephemeral.
//
// SPDX-License-Identifier: Apache-2.0
Expand All @@ -7,8 +7,10 @@ package ephemeral
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/carbynestack/ephemeral/pkg/discovery/fsm"
"time"

Expand All @@ -33,6 +35,7 @@ var _ = Describe("Server", func() {
)

const gameID = "71b2a100-f3f6-11e9-81b4-2a2ae2dbcce4"
authHeader := fmt.Sprintf("Bearer header.%s.signature", base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte(`{"sub":"someID"}`)))
Context("when sending http requests", func() {
BeforeEach(func() {
act = &Activation{
Expand All @@ -49,7 +52,7 @@ var _ = Describe("Server", func() {
StateTimeout: 10 * time.Second,
NetworkEstablishTimeout: 10 * time.Second,
}
s = NewServer(func(*CtxConfig) error { return nil }, func(*CtxConfig) ([]byte, error) { return nil, nil }, l, config)
s = NewServer("sub", func(*CtxConfig) error { return nil }, func(*CtxConfig) ([]byte, error) { return nil, nil }, l, config)
})

Context("when going through body filter", func() {
Expand All @@ -63,14 +66,16 @@ var _ = Describe("Server", func() {
})
body, _ := json.Marshal(&act)
req, _ := http.NewRequest("POST", "/", bytes.NewReader(body))
s.BodyFilter(handler200).ServeHTTP(rr, req)
req.Header.Add("Authorization", authHeader)
s.RequestFilter(handler200).ServeHTTP(rr, req)
})
Context("when the game id is not a valid UUID", func() {
It("responds with 400 http code", func() {
act.GameID = "123"
body, _ := json.Marshal(act)
req, _ := http.NewRequest("POST", "/", bytes.NewReader(body))
s.BodyFilter(handler200).ServeHTTP(rr, req)
req.Header.Add("Authorization", authHeader)
s.RequestFilter(handler200).ServeHTTP(rr, req)
respCode := rr.Code
respBody := rr.Body.String()
Expect(respCode).To(Equal(http.StatusBadRequest))
Expand All @@ -82,15 +87,17 @@ var _ = Describe("Server", func() {
act.GameID = gameID
body, _ := json.Marshal(&act)
req, _ := http.NewRequest("POST", "/", bytes.NewReader(body))
s.BodyFilter(handler200).ServeHTTP(rr, req)
req.Header.Add("Authorization", authHeader)
s.RequestFilter(handler200).ServeHTTP(rr, req)
respCode := rr.Code
Expect(respCode).To(Equal(http.StatusOK))
})
})
Context("when the body is empty", func() {
It("returns a 400 response code", func() {
req, _ := http.NewRequest("POST", "/", nil)
s.BodyFilter(handler200).ServeHTTP(rr, req)
req.Header.Add("Authorization", authHeader)
s.RequestFilter(handler200).ServeHTTP(rr, req)
respCode := rr.Code
respBody := rr.Body.String()
Expect(respCode).To(Equal(http.StatusBadRequest))
Expand All @@ -102,7 +109,8 @@ var _ = Describe("Server", func() {
body := []byte("a")
checker := http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {})
req, _ := http.NewRequest("POST", "/", bytes.NewReader(body))
s.BodyFilter(checker).ServeHTTP(rr, req)
req.Header.Add("Authorization", authHeader)
s.RequestFilter(checker).ServeHTTP(rr, req)
respCode := rr.Code
respBody := rr.Body.String()
Expect(respCode).To(Equal(http.StatusBadRequest))
Expand All @@ -115,6 +123,7 @@ var _ = Describe("Server", func() {
Context("when a get request is being sent", func() {
It("returns a 405 response code", func() {
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Add("Authorization", authHeader)
s.MethodFilter(handler200).ServeHTTP(rr, req)
respCode := rr.Code
respBody := rr.Body.String()
Expand Down Expand Up @@ -425,6 +434,44 @@ var _ = Describe("Server", func() {
})
})
})

Context("when extracting authorization data form request", func() {
It("fails when bearer token is not provided", func() {
_, err := GetUserFromAuthHeader("", "sub")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("no token provided"))
})
It("fails when bearer token is not valid", func() {
_, err := GetUserFromAuthHeader("Bearer invalid.token", "sub")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("invalid JWT format"))
})
It("fails when jwt data field is invalid", func() {
_, err := GetUserFromAuthHeader(
fmt.Sprintf(
"Bearer header.%s.signature",
base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte("{"))), "sub")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(HavePrefix("error unmarshalling JWT claims"))
})
It("returns the user id when the token is valid", func() {
id, err := GetUserFromAuthHeader(
fmt.Sprintf(
"Bearer header.%s.signature",
base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte(`{"sub":"someID"}`))), "sub")
Expect(err).NotTo(HaveOccurred())
Expect(id).To(Equal("someID"))
})
It("returns the user id when field is nested", func() {
id, err := GetUserFromAuthHeader(
fmt.Sprintf(
"Bearer header.%s.signature",
base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte(`{"traits": {"email": "someMail"}}`))), "traits.email")
Expect(err).NotTo(HaveOccurred())
Expect(id).To(Equal("someMail"))
})

})
})

var _ = Describe("PlayerWithIO", func() {
Expand Down
Loading

0 comments on commit 1604c39

Please sign in to comment.