diff --git a/charts/ephemeral/templates/ephemeral.yaml b/charts/ephemeral/templates/ephemeral.yaml index 6393cd87..de134269 100644 --- a/charts/ephemeral/templates/ephemeral.yaml +++ b/charts/ephemeral/templates/ephemeral.yaml @@ -77,6 +77,7 @@ metadata: data: config.json: |- { + "authUserIdField": "{{ .Values.ephemeral.authUserIdField }}", "retrySleep": "50ms", "networkEstablishTimeout": "{{ .Values.ephemeral.networkEstablishTimeout }}", "prime": "{{ .Values.ephemeral.spdz.prime }}", diff --git a/charts/ephemeral/values.yaml b/charts/ephemeral/values.yaml index 252a96f8..e5f4a285 100644 --- a/charts/ephemeral/values.yaml +++ b/charts/ephemeral/values.yaml @@ -47,6 +47,7 @@ ephemeral: cpu: minScale: 1 programIdentifier: "ephemeral-generic" + authUserIdField: "sub" opa: endpoint: "http://opa.default.svc.cluster.local:8081/" policyPackage: "carbynestack.def" diff --git a/cmd/ephemeral/main.go b/cmd/ephemeral/main.go index 2a4dc3af..af3883f3 100644 --- a/cmd/ephemeral/main.go +++ b/cmd/ephemeral/main.go @@ -64,14 +64,14 @@ func GetHandlerChain(conf *SPDZEngineConfig, logger *zap.SugaredLogger) (http.Ha if err != nil { return nil, err } - server := NewServer(spdzClient.Compile, spdzClient.Activate, logger, typedConfig) + server := NewServer(conf.AuthUserIdField, spdzClient.Compile, spdzClient.Activate, logger, typedConfig) activationHandler := http.HandlerFunc(server.ActivationHandler) // Apply in Order: // 1) MethodFilter: Check that only POST Requests can go through - // 2) BodyFilter: Check that Request Body is set properly and Sets the CtxConfig to the request + // 2) RequestFilter: Check that Request Body is set properly and Sets the CtxConfig to the request // 3) CompilationHandler: Compiles the script if ?compile=true // 4) ActivationHandler: Runs the script - filterChain := server.MethodFilter(server.BodyFilter(server.CompilationHandler(activationHandler))) + filterChain := server.MethodFilter(server.RequestFilter(server.CompilationHandler(activationHandler))) return filterChain, nil } diff --git a/pkg/ephemeral/io/feeder.go b/pkg/ephemeral/io/feeder.go index 0afcd7e1..07d7d1c2 100644 --- a/pkg/ephemeral/io/feeder.go +++ b/pkg/ephemeral/io/feeder.go @@ -73,16 +73,17 @@ func (f *AmphoraFeeder) LoadFromSecretStoreAndFeed(act *Activation, feedPort str data = append(data, osh.Data) } t := time.Now() - canExecute, err := f.conf.OpaClient.CanExecute( - map[string]interface{}{ - "subject": ctx.Spdz.ProgramIdentifier, - "inputs": inputs, - "time": map[string]interface{}{ - "formatted": t.String(), - "nano": t.UnixNano(), - }, - "playerCount": ctx.Spdz.PlayerCount, - }) + opaInput := map[string]interface{}{ + "subject": ctx.Spdz.ProgramIdentifier, + "executor": ctx.AuthorizedUser, + "inputs": inputs, + "time": map[string]interface{}{ + "formatted": t.String(), + "nano": t.UnixNano(), + }, + "playerCount": ctx.Spdz.PlayerCount, + } + canExecute, err := f.conf.OpaClient.CanExecute(opaInput) if err != nil { return nil, fmt.Errorf("failed to check if program can be executed: %w", err) } @@ -95,7 +96,7 @@ func (f *AmphoraFeeder) LoadFromSecretStoreAndFeed(act *Activation, feedPort str } // Write to amphora if required and return amphora secret ids. if act.Output.Type == AmphoraSecret { - ids, err := f.writeToAmphora(act, inputs, *resp) + ids, err := f.writeToAmphora(act, opaInput, *resp) if err != nil { return nil, err } @@ -114,7 +115,7 @@ func (f *AmphoraFeeder) LoadFromRequestAndFeed(act *Activation, feedPort string, } // Write to amphora if required and return amphora secret ids. if act.Output.Type == AmphoraSecret { - ids, err := f.writeToAmphora(act, []ActivationInput{}, *resp) + ids, err := f.writeToAmphora(act, map[string]interface{}{}, *resp) if err != nil { return nil, err } @@ -172,9 +173,9 @@ func (f *AmphoraFeeder) feedAndRead(params []string, feedPort string, ctx *CtxCo return f.carrier.Read(conv, isBulk) } -func (f *AmphoraFeeder) writeToAmphora(act *Activation, inputs []ActivationInput, resp Result) ([]string, error) { +func (f *AmphoraFeeder) writeToAmphora(act *Activation, opaInput map[string]interface{}, resp Result) ([]string, error) { client := f.conf.AmphoraClient - generatedTags, err := f.conf.OpaClient.GenerateTags(map[string]interface{}{"inputs": inputs}) + generatedTags, err := f.conf.OpaClient.GenerateTags(opaInput) if err != nil { return nil, fmt.Errorf("failed to generate tags for program output: %w", err) } diff --git a/pkg/ephemeral/player.go b/pkg/ephemeral/player.go index 54b96c0a..de7156b9 100644 --- a/pkg/ephemeral/player.go +++ b/pkg/ephemeral/player.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2024 - for information on the respective copyright owner +// Copyright (c) 2021-2023 - for information on the respective copyright owner // see the NOTICE file and/or the repository https://github.com/carbynestack/ephemeral. // // SPDX-License-Identifier: Apache-2.0 @@ -49,7 +49,6 @@ func NewPlayer(ctx context.Context, bus mb.MessageBus, stateTimeout time.Duratio fsm.WhenIn(Init).GotEvent(Register).GoTo(Registering), fsm.WhenIn(Registering).GotEvent(PlayersReady).GoTo(Playing).WithTimeout(computationTimeout), fsm.WhenIn(Playing).GotEvent(PlayerFinishedWithSuccess).GoTo(PlayerFinishedWithSuccess), - fsm.WhenIn(Playing).GotEvent(PlayerFinishedWithError).GoTo(PlayerFinishedWithError), fsm.WhenIn(Playing).GotEvent(PlayingError).GoTo(PlayerFinishedWithError), fsm.WhenInAnyState().GotEvent(GameError).GoTo(PlayerFinishedWithError), fsm.WhenInAnyState().GotEvent(PlayerDone).GoTo(PlayerDone), diff --git a/pkg/ephemeral/server.go b/pkg/ephemeral/server.go index 2f7bf0a8..acce692a 100644 --- a/pkg/ephemeral/server.go +++ b/pkg/ephemeral/server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 - for information on the respective copyright owner +// Copyright (c) 2021-2024 - for information on the respective copyright owner // see the NOTICE file and/or the repository https://github.com/carbynestack/ephemeral. // // SPDX-License-Identifier: Apache-2.0 @@ -42,28 +42,32 @@ var ( ) // NewServer returns a new server. -func NewServer(compile func(*CtxConfig) error, activate func(*CtxConfig) ([]byte, error), logger *zap.SugaredLogger, config *SPDZEngineTypedConfig) *Server { +func NewServer(authUserIdField string, + compile func(*CtxConfig) error, + activate func(*CtxConfig) ([]byte, error), logger *zap.SugaredLogger, config *SPDZEngineTypedConfig) *Server { return &Server{ - player: &PlayerWithIO{}, - compile: compile, - activate: activate, - logger: logger, - config: config, - executor: NewCommander(), + authUserIdField: authUserIdField, + player: &PlayerWithIO{}, + compile: compile, + activate: activate, + logger: logger, + config: config, + executor: NewCommander(), } } // Server is a HTTP server which wraps the handling of incoming requests that trigger the MPC computation. type Server struct { - player AbstractPlayerWithIO - compile func(*CtxConfig) error - activate func(*CtxConfig) ([]byte, error) - logger *zap.SugaredLogger - config *SPDZEngineTypedConfig - respCh chan []byte - errCh chan error - execErrCh chan error - executor Executor + authUserIdField string + player AbstractPlayerWithIO + compile func(*CtxConfig) error + activate func(*CtxConfig) ([]byte, error) + logger *zap.SugaredLogger + config *SPDZEngineTypedConfig + respCh chan []byte + errCh chan error + execErrCh chan error + executor Executor } // MethodFilter assures that only HTTP POST requests are able to get through. @@ -89,11 +93,19 @@ func (s *Server) MethodFilter(next http.Handler) http.Handler { }) } -// BodyFilter verifies all necessary parameters are set in the request body. +// RequestFilter verifies all necessary headers and parameters in the request body are set. // Also sets the CtxConfig to the request -func (s *Server) BodyFilter(next http.Handler) http.Handler { +func (s *Server) RequestFilter(next http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) { var act Activation + authorizedUser, err := GetUserFromAuthHeader(req.Header.Get("Authorization"), s.authUserIdField) + if err != nil { + msg := "unauthorized request" + writer.WriteHeader(http.StatusUnauthorized) + writer.Write([]byte(msg)) + s.logger.Errorw(msg, "Error", err) + return + } if req.Body == nil { msg := "request body is nil" writer.WriteHeader(http.StatusBadRequest) @@ -104,7 +116,7 @@ func (s *Server) BodyFilter(next http.Handler) http.Handler { bodyBytes, _ := ioutil.ReadAll(req.Body) req.Body.Close() req.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) - err := json.Unmarshal(bodyBytes, &act) + err = json.Unmarshal(bodyBytes, &act) if err != nil { msg := "error decoding the request body" writer.WriteHeader(http.StatusBadRequest) @@ -147,8 +159,9 @@ func (s *Server) BodyFilter(next http.Handler) http.Handler { } con := context.Background() ctx := &CtxConfig{ - Act: &act, - Spdz: s.config, + AuthorizedUser: authorizedUser, + Act: &act, + Spdz: s.config, } con = context.WithValue(con, ctxConf, ctx) r := req.Clone(con) @@ -157,6 +170,43 @@ func (s *Server) BodyFilter(next http.Handler) http.Handler { }) } +func GetUserFromAuthHeader(header string, idField string) (string, error) { + token := strings.TrimPrefix(header, "Bearer ") + if token == "" { + return "", fmt.Errorf("no token provided") + } + return GetUserIDFromToken(token, idField) +} + +func GetUserIDFromToken(token string, field string) (string, error) { + jwtParts := strings.Split(token, ".") + if len(jwtParts) != 3 { + return "", fmt.Errorf("invalid JWT format") + } + jwt, err := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(jwtParts[1]) + if err != nil { + return "", fmt.Errorf("error decoding JWT claims: %w", err) + } + var claimsMap map[string]interface{} + err = json.Unmarshal(jwt, &claimsMap) + if err != nil { + return "", fmt.Errorf("error unmarshalling JWT claims: %w", err) + } + var ok bool + path := strings.Split(field, ".") + for _, part := range path[:len(path)-1] { + claimsMap, ok = claimsMap[part].(map[string]interface{}) + if !ok { + return "", fmt.Errorf("field %s not found in JWT claims or invalid", part) + } + } + id, ok := claimsMap[path[len(path)-1]].(string) + if !ok { + return "", fmt.Errorf("field %s is not a string", field) + } + return id, nil +} + // CompilationHandler parses the JSON payload and adds it to the request context. func (s *Server) CompilationHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) { diff --git a/pkg/ephemeral/server_test.go b/pkg/ephemeral/server_test.go index 578ca6b6..e6c32c1c 100644 --- a/pkg/ephemeral/server_test.go +++ b/pkg/ephemeral/server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 - for information on the respective copyright owner +// Copyright (c) 2021-2024 - for information on the respective copyright owner // see the NOTICE file and/or the repository https://github.com/carbynestack/ephemeral. // // SPDX-License-Identifier: Apache-2.0 @@ -7,8 +7,10 @@ package ephemeral import ( "bytes" "context" + "encoding/base64" "encoding/json" "errors" + "fmt" "github.com/carbynestack/ephemeral/pkg/discovery/fsm" "time" @@ -33,6 +35,7 @@ var _ = Describe("Server", func() { ) const gameID = "71b2a100-f3f6-11e9-81b4-2a2ae2dbcce4" + authHeader := fmt.Sprintf("Bearer header.%s.signature", base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte(`{"sub":"someID"}`))) Context("when sending http requests", func() { BeforeEach(func() { act = &Activation{ @@ -49,7 +52,7 @@ var _ = Describe("Server", func() { StateTimeout: 10 * time.Second, NetworkEstablishTimeout: 10 * time.Second, } - s = NewServer(func(*CtxConfig) error { return nil }, func(*CtxConfig) ([]byte, error) { return nil, nil }, l, config) + s = NewServer("sub", func(*CtxConfig) error { return nil }, func(*CtxConfig) ([]byte, error) { return nil, nil }, l, config) }) Context("when going through body filter", func() { @@ -63,14 +66,16 @@ var _ = Describe("Server", func() { }) body, _ := json.Marshal(&act) req, _ := http.NewRequest("POST", "/", bytes.NewReader(body)) - s.BodyFilter(handler200).ServeHTTP(rr, req) + req.Header.Add("Authorization", authHeader) + s.RequestFilter(handler200).ServeHTTP(rr, req) }) Context("when the game id is not a valid UUID", func() { It("responds with 400 http code", func() { act.GameID = "123" body, _ := json.Marshal(act) req, _ := http.NewRequest("POST", "/", bytes.NewReader(body)) - s.BodyFilter(handler200).ServeHTTP(rr, req) + req.Header.Add("Authorization", authHeader) + s.RequestFilter(handler200).ServeHTTP(rr, req) respCode := rr.Code respBody := rr.Body.String() Expect(respCode).To(Equal(http.StatusBadRequest)) @@ -82,7 +87,8 @@ var _ = Describe("Server", func() { act.GameID = gameID body, _ := json.Marshal(&act) req, _ := http.NewRequest("POST", "/", bytes.NewReader(body)) - s.BodyFilter(handler200).ServeHTTP(rr, req) + req.Header.Add("Authorization", authHeader) + s.RequestFilter(handler200).ServeHTTP(rr, req) respCode := rr.Code Expect(respCode).To(Equal(http.StatusOK)) }) @@ -90,7 +96,8 @@ var _ = Describe("Server", func() { Context("when the body is empty", func() { It("returns a 400 response code", func() { req, _ := http.NewRequest("POST", "/", nil) - s.BodyFilter(handler200).ServeHTTP(rr, req) + req.Header.Add("Authorization", authHeader) + s.RequestFilter(handler200).ServeHTTP(rr, req) respCode := rr.Code respBody := rr.Body.String() Expect(respCode).To(Equal(http.StatusBadRequest)) @@ -102,7 +109,8 @@ var _ = Describe("Server", func() { body := []byte("a") checker := http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {}) req, _ := http.NewRequest("POST", "/", bytes.NewReader(body)) - s.BodyFilter(checker).ServeHTTP(rr, req) + req.Header.Add("Authorization", authHeader) + s.RequestFilter(checker).ServeHTTP(rr, req) respCode := rr.Code respBody := rr.Body.String() Expect(respCode).To(Equal(http.StatusBadRequest)) @@ -115,6 +123,7 @@ var _ = Describe("Server", func() { Context("when a get request is being sent", func() { It("returns a 405 response code", func() { req, _ := http.NewRequest("GET", "/", nil) + req.Header.Add("Authorization", authHeader) s.MethodFilter(handler200).ServeHTTP(rr, req) respCode := rr.Code respBody := rr.Body.String() @@ -425,6 +434,44 @@ var _ = Describe("Server", func() { }) }) }) + + Context("when extracting authorization data form request", func() { + It("fails when bearer token is not provided", func() { + _, err := GetUserFromAuthHeader("", "sub") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("no token provided")) + }) + It("fails when bearer token is not valid", func() { + _, err := GetUserFromAuthHeader("Bearer invalid.token", "sub") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("invalid JWT format")) + }) + It("fails when jwt data field is invalid", func() { + _, err := GetUserFromAuthHeader( + fmt.Sprintf( + "Bearer header.%s.signature", + base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte("{"))), "sub") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(HavePrefix("error unmarshalling JWT claims")) + }) + It("returns the user id when the token is valid", func() { + id, err := GetUserFromAuthHeader( + fmt.Sprintf( + "Bearer header.%s.signature", + base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte(`{"sub":"someID"}`))), "sub") + Expect(err).NotTo(HaveOccurred()) + Expect(id).To(Equal("someID")) + }) + It("returns the user id when field is nested", func() { + id, err := GetUserFromAuthHeader( + fmt.Sprintf( + "Bearer header.%s.signature", + base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte(`{"traits": {"email": "someMail"}}`))), "traits.email") + Expect(err).NotTo(HaveOccurred()) + Expect(id).To(Equal("someMail")) + }) + + }) }) var _ = Describe("PlayerWithIO", func() { diff --git a/pkg/types/types.go b/pkg/types/types.go index 207c4f7d..428cf776 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -84,16 +84,18 @@ type ProxyConfig struct { // CtxConfig contains both execution and platform specific parameters. type CtxConfig struct { - Act *Activation - Spdz *SPDZEngineTypedConfig - ProxyEntries []*ProxyConfig - ErrCh chan error - Context context.Context + AuthorizedUser string + Act *Activation + Spdz *SPDZEngineTypedConfig + ProxyEntries []*ProxyConfig + ErrCh chan error + Context context.Context } // SPDZEngineConfig is the VPC specific configuration. type SPDZEngineConfig struct { ProgramIdentifier string `json:"programIdentifier"` + AuthUserIdField string `json:"authUserIdField"` RetrySleep string `json:"retrySleep"` NetworkEstablishTimeout string `json:"networkEstablishTimeout"` Prime string `json:"prime"` @@ -160,6 +162,7 @@ type OutputConfig struct { // We need this type, since the default json decoder doesn't know how to deserialize big.Int. type SPDZEngineTypedConfig struct { ProgramIdentifier string + AuthUserIdField string RetrySleep time.Duration NetworkEstablishTimeout time.Duration Prime big.Int