From 7e3091b4727de0a493b9d007429c39006518e0d1 Mon Sep 17 00:00:00 2001 From: Grady Ward Date: Tue, 14 Nov 2023 13:25:52 -0700 Subject: [PATCH] WIP --- cmd/server/BUILD.bazel | 2 +- cmd/server/main.go | 59 +--------------- cmd/server/pactasrv/BUILD.bazel | 1 + cmd/server/pactasrv/initiative_invitation.go | 7 +- cmd/server/pactasrv/pactasrv.go | 6 +- cmd/server/pactasrv/user.go | 48 +++++++++++++ db/sqldb/user.go | 4 +- db/sqldb/user_test.go | 4 +- frontend/components/LinkButton.vue | 7 +- frontend/components/initiative/Toolbar.vue | 2 +- frontend/composables/useMSAL.ts | 30 ++++---- frontend/openapi/generated/pacta/index.ts | 4 +- ...folioRequest.ts => ProcessPortfolioReq.ts} | 2 +- ...lioResponse.ts => ProcessPortfolioResp.ts} | 2 +- .../pacta/services/DefaultService.ts | 30 ++++++-- openapi/pacta.yaml | 17 +++++ session/BUILD.bazel | 15 ++++ session/session.go | 70 +++++++++++++++++++ 18 files changed, 213 insertions(+), 97 deletions(-) rename frontend/openapi/generated/pacta/models/{ProcessPortfolioRequest.ts => ProcessPortfolioReq.ts} (80%) rename frontend/openapi/generated/pacta/models/{ProcessPortfolioResponse.ts => ProcessPortfolioResp.ts} (85%) create mode 100644 session/BUILD.bazel create mode 100644 session/session.go diff --git a/cmd/server/BUILD.bazel b/cmd/server/BUILD.bazel index 14c6f2d..8bc7f24 100644 --- a/cmd/server/BUILD.bazel +++ b/cmd/server/BUILD.bazel @@ -17,8 +17,8 @@ go_library( "//dockertask", "//oapierr", "//openapi:pacta_generated", - "//pacta", "//secrets", + "//session", "//task", "@com_github_azure_azure_sdk_for_go_sdk_azcore//:azcore", "@com_github_azure_azure_sdk_for_go_sdk_azidentity//:azidentity", diff --git a/cmd/server/main.go b/cmd/server/main.go index b008152..e690acd 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -23,8 +23,8 @@ import ( "github.com/RMI/pacta/dockertask" "github.com/RMI/pacta/oapierr" oapipacta "github.com/RMI/pacta/openapi/pacta" - "github.com/RMI/pacta/pacta" "github.com/RMI/pacta/secrets" + "github.com/RMI/pacta/session" "github.com/RMI/pacta/task" "github.com/Silicon-Ally/cryptorand" "github.com/Silicon-Ally/zaphttplog" @@ -318,7 +318,7 @@ func run(args []string) error { jwtauth.Verifier(jwtauth.New("EdDSA", nil, jwKey)), jwtauth.Authenticator, - addUserIdentityToContextIfLoggedIn(logger, db), + session.WithAuthn(logger, db), oapimiddleware.OapiRequestValidator(pactaSwagger), @@ -386,61 +386,6 @@ func rateLimitMiddleware(maxReq int, windowLength time.Duration) func(http.Handl })) } -func addUserIdentityToContextIfLoggedIn(logger *zap.Logger, db *sqldb.DB) func(http.Handler) http.Handler { - fn := func(c context.Context) (context.Context, error) { - token, _, err := jwtauth.FromContext(c) - if err != nil { - return nil, fmt.Errorf("error getting authorization token: %w", err) - } - if token == nil { - return nil, fmt.Errorf("nil authorization token") - } - emailsClaim, ok := token.PrivateClaims()["emails"] - if !ok { - return nil, fmt.Errorf("no email claim in token") - } - emails, ok := emailsClaim.([]interface{}) - if !ok || len(emails) == 0 { - return nil, fmt.Errorf("couldn't find email claim in token: %T", emailsClaim) - } - // TODO(#18) Handle Multiple Emails in the Token Claims gracefully - if len(emails) > 1 { - return nil, fmt.Errorf("multiple emails in token: %+v", emails) - } - email, ok := emails[0].(string) - if !ok { - return nil, fmt.Errorf("wrong type for email claim: %T", emails[0]) - } - canonical, err := pacta.CanonicalizeEmail(email) - if err != nil { - return nil, fmt.Errorf("invalid email on token: %q", email) - } - authnID := token.Subject() - if authnID == "" { - return nil, fmt.Errorf("couldn't find authn id in jwt") - } - user, err := db.GetOrCreateUserByAuthn(db.NoTxn(c), pacta.AuthnMechanism_EmailAndPass, authnID, email, canonical) - if err != nil { - return nil, fmt.Errorf("failed to get user by authn: %w", err) - } - return jwtauth.WithUserId(c, string(user.ID)), nil - } - - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, err := fn(r.Context()) - if err != nil { - // Optionally log errors here when debugging authentication access. - // logger.Warn("couldn't authenticate", zap.Error(err)) - // http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - next.ServeHTTP(w, r) - return - } - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - func findFirstInClaims(claims map[string]any, keys ...string) (string, error) { for _, k := range keys { v, ok := claims[k] diff --git a/cmd/server/pactasrv/BUILD.bazel b/cmd/server/pactasrv/BUILD.bazel index 086b5ef..d2e8c1e 100644 --- a/cmd/server/pactasrv/BUILD.bazel +++ b/cmd/server/pactasrv/BUILD.bazel @@ -20,6 +20,7 @@ go_library( "//oapierr", "//openapi:pacta_generated", "//pacta", + "//session", "//task", "@com_github_go_chi_jwtauth_v5//:jwtauth", "@com_github_google_uuid//:uuid", diff --git a/cmd/server/pactasrv/initiative_invitation.go b/cmd/server/pactasrv/initiative_invitation.go index 6d05b95..5b16434 100644 --- a/cmd/server/pactasrv/initiative_invitation.go +++ b/cmd/server/pactasrv/initiative_invitation.go @@ -62,12 +62,13 @@ func (s *Server) GetInitiativeInvitation(ctx context.Context, request api.GetIni } // Claims this initiative invitation, if it exists -// (POST /initiative-invitation/{id}) +// (POST /initiative-invitation/{id}:claim) func (s *Server) ClaimInitiativeInvitation(ctx context.Context, request api.ClaimInitiativeInvitationRequestObject) (api.ClaimInitiativeInvitationResponseObject, error) { userID, err := getUserID(ctx) if err != nil { return nil, err } + var customErr api.ClaimInitiativeInvitationResponseObject err = s.DB.Transactional(ctx, func(tx db.Tx) error { ii, err := s.DB.InitiativeInvitation(tx, pacta.InitiativeInvitationID(request.Id)) if err != nil { @@ -80,6 +81,7 @@ func (s *Server) ClaimInitiativeInvitation(ctx context.Context, request api.Clai // We may want to log this, though. return nil } else { + customErr = api.ClaimInitiativeInvitation409Response{} return fmt.Errorf("initiative is already used: %+v", ii) } } @@ -97,6 +99,9 @@ func (s *Server) ClaimInitiativeInvitation(ctx context.Context, request api.Clai return nil }) if err != nil { + if customErr != nil { + return customErr, nil + } return nil, oapierr.Internal("failed to claim initiative invitation", zap.Error(err)) } return api.ClaimInitiativeInvitation204Response{}, nil diff --git a/cmd/server/pactasrv/pactasrv.go b/cmd/server/pactasrv/pactasrv.go index 4311b36..06bfd13 100644 --- a/cmd/server/pactasrv/pactasrv.go +++ b/cmd/server/pactasrv/pactasrv.go @@ -8,8 +8,8 @@ import ( "github.com/RMI/pacta/db" "github.com/RMI/pacta/oapierr" "github.com/RMI/pacta/pacta" + "github.com/RMI/pacta/session" "github.com/RMI/pacta/task" - "github.com/go-chi/jwtauth/v5" "go.uber.org/zap" ) @@ -67,7 +67,7 @@ type DB interface { CreatePortfolioInitiativeMembership(tx db.Tx, pim *pacta.PortfolioInitiativeMembership) error DeletePortfolioInitiativeMembership(tx db.Tx, pid pacta.PortfolioID, iid pacta.InitiativeID) error - GetOrCreateUserByAuthn(tx db.Tx, authnMechanism pacta.AuthnMechanism, authnID, enteredEmail, canonicalEmail string) (*pacta.User, error) + GetOrCreateUserByAuthn(tx db.Tx, mech pacta.AuthnMechanism, authnID, email, canonicalEmail string) (*pacta.User, error) User(tx db.Tx, id pacta.UserID) (*pacta.User, error) Users(tx db.Tx, ids []pacta.UserID) (map[pacta.UserID]*pacta.User, error) UpdateUser(tx db.Tx, id pacta.UserID, mutations ...db.UpdateUserFn) error @@ -118,7 +118,7 @@ func dereference[T any](ts []*T, e error) ([]T, error) { } func getUserID(ctx context.Context) (pacta.UserID, error) { - userID, err := jwtauth.UserIDFromContext(ctx) + userID, err := session.UserIDFromContext(ctx) if err != nil { return "", oapierr.Unauthorized("error getting authorization token", zap.Error(err)) } diff --git a/cmd/server/pactasrv/user.go b/cmd/server/pactasrv/user.go index 4e8e204..a18da18 100644 --- a/cmd/server/pactasrv/user.go +++ b/cmd/server/pactasrv/user.go @@ -2,12 +2,14 @@ package pactasrv import ( "context" + "fmt" "github.com/RMI/pacta/cmd/server/pactasrv/conv" "github.com/RMI/pacta/db" "github.com/RMI/pacta/oapierr" api "github.com/RMI/pacta/openapi/pacta" "github.com/RMI/pacta/pacta" + "github.com/go-chi/jwtauth/v5" "go.uber.org/zap" ) @@ -84,3 +86,49 @@ func (s *Server) FindUserByMe(ctx context.Context, request api.FindUserByMeReque } return api.FindUserByMe200JSONResponse(*result), nil } + +// a callback after login to create or return the user +// (POST /user/authentication-followup) +func (s *Server) UserAuthenticationFollowup(ctx context.Context, _request api.UserAuthenticationFollowupRequestObject) (api.UserAuthenticationFollowupResponseObject, error) { + token, _, err := jwtauth.FromContext(ctx) + if err != nil { + return nil, oapierr.BadRequest("error getting authorization token", zap.Error(err)) + } + if token == nil { + return nil, oapierr.BadRequest("nil authorization token") + } + emailsClaim, ok := token.PrivateClaims()["emails"] + if !ok { + return nil, oapierr.BadRequest("no email claim in token") + } + emails, ok := emailsClaim.([]interface{}) + if !ok || len(emails) == 0 { + return nil, oapierr.BadRequest("couldn't find email claim in token", zap.String("emails_claim_type", fmt.Sprintf("%T", emailsClaim))) + } + // TODO(#18) Handle Multiple Emails in the Token Claims gracefully + if len(emails) > 1 { + return nil, oapierr.BadRequest(fmt.Sprintf("multiple emails in token: %+v", emails)) + } + email, ok := emails[0].(string) + if !ok { + return nil, oapierr.BadRequest("wrong type for email claim", zap.String("email_claim_type", fmt.Sprintf("%T", emails[0]))) + } + canonical, err := pacta.CanonicalizeEmail(email) + if err != nil { + return nil, oapierr.BadRequest(fmt.Sprintf("invalid email: %q", email), zap.String("email", email), zap.Error(err)) + } + authnID := token.Subject() + if authnID == "" { + return nil, oapierr.BadRequest("couldn't find authn id in jwt") + } + user, err := s.DB.GetOrCreateUserByAuthn(s.DB.NoTxn(ctx), pacta.AuthnMechanism_EmailAndPass, authnID, email, canonical) + if err != nil { + return nil, fmt.Errorf("failed to GetOrCreateUser by authn: %w", err) + } + result, err := conv.UserToOAPI(user) + if err != nil { + return nil, err + } + return api.UserAuthenticationFollowup200JSONResponse(*result), nil + +} diff --git a/db/sqldb/user.go b/db/sqldb/user.go index fb6cdca..c0412e0 100644 --- a/db/sqldb/user.go +++ b/db/sqldb/user.go @@ -37,7 +37,7 @@ func (d *DB) User(tx db.Tx, id pacta.UserID) (*pacta.User, error) { return exactlyOne("user", id, us) } -func (d *DB) userByAuthn(tx db.Tx, authnMechanism pacta.AuthnMechanism, authnID string) (*pacta.User, error) { +func (d *DB) UserByAuthn(tx db.Tx, authnMechanism pacta.AuthnMechanism, authnID string) (*pacta.User, error) { rows, err := d.query(tx, ` SELECT `+userSelectColumns+` FROM pacta_user @@ -55,7 +55,7 @@ func (d *DB) userByAuthn(tx db.Tx, authnMechanism pacta.AuthnMechanism, authnID func (d *DB) GetOrCreateUserByAuthn(tx db.Tx, authnMechanism pacta.AuthnMechanism, authnID, enteredEmail, canonicalEmail string) (*pacta.User, error) { var user *pacta.User err := d.RunOrContinueTransaction(tx, func(tx db.Tx) error { - u, err := d.userByAuthn(tx, authnMechanism, authnID) + u, err := d.UserByAuthn(tx, authnMechanism, authnID) if err == nil { user = u return nil diff --git a/db/sqldb/user_test.go b/db/sqldb/user_test.go index fc8e3e9..dc9752e 100644 --- a/db/sqldb/user_test.go +++ b/db/sqldb/user_test.go @@ -40,7 +40,7 @@ func TestcreateUser(t *testing.T) { } // Read by Authn - actual, err = tdb.userByAuthn(tx, u.AuthnMechanism, u.AuthnID) + actual, err = tdb.UserByAuthn(tx, u.AuthnMechanism, u.AuthnID) if err != nil { t.Fatalf("getting user by authn: %w", err) } @@ -230,7 +230,7 @@ func TestDeleteUser(t *testing.T) { } // Read by Authn - _, err = tdb.userByAuthn(tx, u.AuthnMechanism, u.AuthnID) + _, err = tdb.UserByAuthn(tx, u.AuthnMechanism, u.AuthnID) if err == nil { t.Fatalf("expected error, got nil") } diff --git a/frontend/components/LinkButton.vue b/frontend/components/LinkButton.vue index 8a3fb8e..e6ba870 100644 --- a/frontend/components/LinkButton.vue +++ b/frontend/components/LinkButton.vue @@ -112,11 +112,8 @@ const buttonClass = computed(() => { 'p-button-loading-label-only': props.loading !== undefined && props.icon === undefined && props.label !== undefined, 'no-underline': true, 'click-does-nothing': disabled.value, - } - if (isActive.value) { - result[props.activeClass] = true - } else { - result[props.inactiveClass] = true + [props.activeClass]: isActive.value, + [props.inactiveClass]: !isActive.value, } return result }) diff --git a/frontend/components/initiative/Toolbar.vue b/frontend/components/initiative/Toolbar.vue index 15ec69a..575bad9 100644 --- a/frontend/components/initiative/Toolbar.vue +++ b/frontend/components/initiative/Toolbar.vue @@ -17,7 +17,7 @@ const props = defineProps() const isManager = computed(() => { const mm = maybeMe.value - return (!!mm && props.initiativeUserRelationships.some(r => r.manager && r.userId === mm.id)) || true + return (!!mm && props.initiativeUserRelationships.some(r => r.manager && r.userId === mm.id)) }) const isMember = computed(() => { const mm = maybeMe.value diff --git a/frontend/composables/useMSAL.ts b/frontend/composables/useMSAL.ts index 91f50e1..bac7bfa 100644 --- a/frontend/composables/useMSAL.ts +++ b/frontend/composables/useMSAL.ts @@ -35,7 +35,7 @@ export const useMSAL = async () => { } const router = useRouter() - const { userClientWithAuth } = useAPI() + const { userClientWithAuth, pactaClientWithAuth } = useAPI() const localePath = useLocalePath() const { $msal: { msalConfig, b2cPolicies } } = useNuxtApp() @@ -210,19 +210,6 @@ export const useMSAL = async () => { return filteredAccounts[0] }) - const signIn = () => { - if (!instance.value) { - return Promise.reject(new Error('MSAL instance was not yet initialized')) - } - - const req = { scopes } - return instance.value.loginPopup(req) - .then(handleResponse) - .catch((err) => { - console.log('useMSAL.loginPopup', err) - }) - } - const getToken = () => { if (!instance.value) { return Promise.reject(new Error('MSAL instance was not yet initialized')) @@ -252,6 +239,21 @@ export const useMSAL = async () => { .then(handleResponse) } + const signIn = () => { + if (!instance.value) { + return Promise.reject(new Error('MSAL instance was not yet initialized')) + } + + const req = { scopes } + return instance.value.loginPopup(req) + .then(handleResponse) + .then(getToken) + .then(token => pactaClientWithAuth(token.idToken).userAuthenticationFollowup()) + .catch((err) => { + console.log('useMSAL.loginPopup', err) + }) + } + const createAPIKey = (): Promise => { return getToken() .then((response) => { diff --git a/frontend/openapi/generated/pacta/index.ts b/frontend/openapi/generated/pacta/index.ts index d8a6515..01cd498 100644 --- a/frontend/openapi/generated/pacta/index.ts +++ b/frontend/openapi/generated/pacta/index.ts @@ -23,8 +23,8 @@ export type { NewPortfolioAsset } from './models/NewPortfolioAsset'; export type { PactaVersion } from './models/PactaVersion'; export type { PactaVersionChanges } from './models/PactaVersionChanges'; export type { PactaVersionCreate } from './models/PactaVersionCreate'; -export type { ProcessPortfolioRequest } from './models/ProcessPortfolioRequest'; -export type { ProcessPortfolioResponse } from './models/ProcessPortfolioResponse'; +export type { ProcessPortfolioReq } from './models/ProcessPortfolioReq'; +export type { ProcessPortfolioResp } from './models/ProcessPortfolioResp'; export { User } from './models/User'; export { UserChanges } from './models/UserChanges'; diff --git a/frontend/openapi/generated/pacta/models/ProcessPortfolioRequest.ts b/frontend/openapi/generated/pacta/models/ProcessPortfolioReq.ts similarity index 80% rename from frontend/openapi/generated/pacta/models/ProcessPortfolioRequest.ts rename to frontend/openapi/generated/pacta/models/ProcessPortfolioReq.ts index c19be5f..7ec3df0 100644 --- a/frontend/openapi/generated/pacta/models/ProcessPortfolioRequest.ts +++ b/frontend/openapi/generated/pacta/models/ProcessPortfolioReq.ts @@ -3,7 +3,7 @@ /* tslint:disable */ /* eslint-disable */ -export type ProcessPortfolioRequest = { +export type ProcessPortfolioReq = { asset_ids: Array; }; diff --git a/frontend/openapi/generated/pacta/models/ProcessPortfolioResponse.ts b/frontend/openapi/generated/pacta/models/ProcessPortfolioResp.ts similarity index 85% rename from frontend/openapi/generated/pacta/models/ProcessPortfolioResponse.ts rename to frontend/openapi/generated/pacta/models/ProcessPortfolioResp.ts index 9ad4bb3..7cb606e 100644 --- a/frontend/openapi/generated/pacta/models/ProcessPortfolioResponse.ts +++ b/frontend/openapi/generated/pacta/models/ProcessPortfolioResp.ts @@ -3,7 +3,7 @@ /* tslint:disable */ /* eslint-disable */ -export type ProcessPortfolioResponse = { +export type ProcessPortfolioResp = { /** * The ID of the async task for processing the portfoio */ diff --git a/frontend/openapi/generated/pacta/services/DefaultService.ts b/frontend/openapi/generated/pacta/services/DefaultService.ts index dc3f93f..8ceb3dd 100644 --- a/frontend/openapi/generated/pacta/services/DefaultService.ts +++ b/frontend/openapi/generated/pacta/services/DefaultService.ts @@ -13,8 +13,8 @@ import type { NewPortfolioAsset } from '../models/NewPortfolioAsset'; import type { PactaVersion } from '../models/PactaVersion'; import type { PactaVersionChanges } from '../models/PactaVersionChanges'; import type { PactaVersionCreate } from '../models/PactaVersionCreate'; -import type { ProcessPortfolioRequest } from '../models/ProcessPortfolioRequest'; -import type { ProcessPortfolioResponse } from '../models/ProcessPortfolioResponse'; +import type { ProcessPortfolioReq } from '../models/ProcessPortfolioReq'; +import type { ProcessPortfolioResp } from '../models/ProcessPortfolioResp'; import type { User } from '../models/User'; import type { UserChanges } from '../models/UserChanges'; @@ -324,10 +324,13 @@ export class DefaultService { ): CancelablePromise { return this.httpRequest.request({ method: 'POST', - url: '/initiative-invitation/{id}', + url: '/initiative-invitation/{id}:claim', path: { 'id': id, }, + errors: { + 409: `initiative invitation already claimed`, + }, }); } @@ -343,7 +346,7 @@ export class DefaultService { ): CancelablePromise { return this.httpRequest.request({ method: 'DELETE', - url: '/initiative-invitation/{id}', + url: '/initiative-invitation/{id}:claim', path: { 'id': id, }, @@ -410,6 +413,19 @@ export class DefaultService { }); } + /** + * a callback after login to create or return the user + * Creates a user in the database, if the user does not yet exist, or returns the existing user. + * @returns User user response + * @throws ApiError + */ + public userAuthenticationFollowup(): CancelablePromise { + return this.httpRequest.request({ + method: 'POST', + url: '/user/authentication-followup', + }); + } + /** * Returns a user by ID * Returns a user based on a single ID @@ -492,12 +508,12 @@ export class DefaultService { * Starts processing raw uploaded files * * @param requestBody The raw portfolio files to process - * @returns ProcessPortfolioResponse The task has been started successfully + * @returns ProcessPortfolioResp The task has been started successfully * @throws ApiError */ public processPortfolio( - requestBody: ProcessPortfolioRequest, - ): CancelablePromise { + requestBody: ProcessPortfolioReq, + ): CancelablePromise { return this.httpRequest.request({ method: 'POST', url: '/test:processPortfolio', diff --git a/openapi/pacta.yaml b/openapi/pacta.yaml index 35f0680..8746c17 100644 --- a/openapi/pacta.yaml +++ b/openapi/pacta.yaml @@ -27,6 +27,7 @@ definitions: - fr - es - de +basePath: /v1 paths: /pacta-version/{id}: get: @@ -300,6 +301,7 @@ paths: application/json: schema: $ref: '#/components/schemas/InitiativeInvitation' + /initiative-invitation/{id}:claim: post: summary: Claims this initiative invitation, if it exists operationId: claimInitiativeInvitation @@ -313,6 +315,8 @@ paths: responses: '204': description: initiative invitation claimed successfully + '409': + description: initiative invitation already claimed delete: summary: Deletes an initiative invitation by id @@ -392,6 +396,19 @@ paths: application/json: schema: $ref: '#/components/schemas/User' + + /user/authentication-followup: + post: + description: Creates a user in the database, if the user does not yet exist, or returns the existing user. + summary: a callback after login to create or return the user + operationId: userAuthenticationFollowup + responses: + '200': + description: user response + content: + application/json: + schema: + $ref: '#/components/schemas/User' /user/{id}: get: diff --git a/session/BUILD.bazel b/session/BUILD.bazel new file mode 100644 index 0000000..b9eadc6 --- /dev/null +++ b/session/BUILD.bazel @@ -0,0 +1,15 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "session", + srcs = ["session.go"], + importpath = "github.com/RMI/pacta/session", + visibility = ["//visibility:public"], + deps = [ + "//db", + "//oapierr", + "//pacta", + "@com_github_go_chi_jwtauth_v5//:jwtauth", + "@org_uber_go_zap//:zap", + ], +) diff --git a/session/session.go b/session/session.go new file mode 100644 index 0000000..4c22233 --- /dev/null +++ b/session/session.go @@ -0,0 +1,70 @@ +package session + +import ( + "context" + "fmt" + "net/http" + + "github.com/RMI/pacta/db" + "github.com/RMI/pacta/oapierr" + "github.com/RMI/pacta/pacta" + "github.com/go-chi/jwtauth/v5" + "go.uber.org/zap" +) + +var userIDKey = struct{}{} + +func WithUserID(c context.Context, id pacta.UserID) context.Context { + return context.WithValue(c, userIDKey, id) +} + +type DB interface { + NoTxn(context.Context) db.Tx + UserByAuthn(tx db.Tx, mech pacta.AuthnMechanism, authnID string) (*pacta.User, error) +} + +func UserIDFromContext(ctx context.Context) (pacta.UserID, error) { + userID, ok := ctx.Value(userIDKey).(pacta.UserID) + if !ok { + return "", oapierr.Unauthorized("no user id in context") + } + if userID == "" { + return "", oapierr.Unauthorized("empty user id in context") + } + return userID, nil +} + +func WithAuthn(logger *zap.Logger, d DB) func(http.Handler) http.Handler { + fn := func(c context.Context) (context.Context, error) { + token, _, err := jwtauth.FromContext(c) + if err != nil { + return nil, fmt.Errorf("error getting authorization token: %w", err) + } + if token == nil { + return nil, fmt.Errorf("nil authorization token") + } + authnID := token.Subject() + if authnID == "" { + return nil, fmt.Errorf("couldn't find authn id in jwt") + } + user, err := d.UserByAuthn(d.NoTxn(c), pacta.AuthnMechanism_EmailAndPass, authnID) + if err != nil { + return nil, fmt.Errorf("failed to get user by authn: %w", err) + } + return WithUserID(c, user.ID), nil + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, err := fn(r.Context()) + if err != nil { + // Optionally log errors here when debugging authentication access. + // logger.Warn("couldn't authenticate", zap.Error(err)) + // http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + next.ServeHTTP(w, r) + return + } + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +}