-
Notifications
You must be signed in to change notification settings - Fork 3
/
middleware.go
238 lines (206 loc) · 8.54 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
/*
Copyright © 2024 Acronis International GmbH.
Released under MIT license.
*/
package authkit
import (
"context"
"errors"
"net/http"
"strings"
"github.com/acronis/go-appkit/httpserver/middleware"
"github.com/acronis/go-appkit/log"
"github.com/acronis/go-appkit/restapi"
"github.com/acronis/go-authkit/idptoken"
"github.com/acronis/go-authkit/internal/idputil"
"github.com/acronis/go-authkit/jwt"
)
// HeaderAuthorization is the name of the HTTP header that carries the data (the bearer token)
// used for authentication and authorization.
const HeaderAuthorization = "Authorization"

// Error codes put into authentication and authorization error responses by JWTAuthMiddleware.
// They are declared as variables rather than constants so that services can replace them with their own codes.
var (
	// ErrCodeBearerTokenMissing is used when the request carries no bearer token (401 response).
	ErrCodeBearerTokenMissing = "bearerTokenMissing"

	// ErrCodeAuthenticationFailed is used when the bearer token is rejected (401 response).
	ErrCodeAuthenticationFailed = "authenticationFailed"

	// ErrCodeAuthorizationFailed is used when access verification denies the request (403 response).
	ErrCodeAuthorizationFailed = "authorizationFailed"
)

// Error messages put into authentication and authorization error responses by JWTAuthMiddleware.
// They are declared as variables rather than constants so that services can replace them with their own messages.
var (
	// ErrMessageBearerTokenMissing accompanies ErrCodeBearerTokenMissing.
	ErrMessageBearerTokenMissing = "Authorization bearer token is missing."

	// ErrMessageAuthenticationFailed accompanies ErrCodeAuthenticationFailed.
	ErrMessageAuthenticationFailed = "Authentication is failed."

	// ErrMessageAuthorizationFailed accompanies ErrCodeAuthorizationFailed.
	ErrMessageAuthorizationFailed = "Authorization is failed."
)

// ctxKey is an unexported type for the context keys of this package; using a package-private
// type (as the context package recommends) keeps these keys from colliding with keys of other packages.
type ctxKey int

// Context keys under which request-scoped authentication data is stored.
const (
	ctxKeyJWTClaims   ctxKey = 0 // jwt.Claims of the request (see NewContextWithJWTClaims)
	ctxKeyBearerToken ctxKey = 1 // raw bearer token string of the request (see NewContextWithBearerToken)
)
// JWTParser parses the string representation of a JWT into its claims.
type JWTParser interface {
	// Parse parses token and returns its claims.
	// JWTAuthMiddleware treats a non-nil error as an authentication failure.
	Parse(ctx context.Context, token string) (jwt.Claims, error)
}

// CachingJWTParser is a JWTParser that additionally keeps parsed JWT claims in a cache.
type CachingJWTParser interface {
	JWTParser

	// InvalidateCache invalidates the cache of parsed claims.
	InvalidateCache(ctx context.Context)
}

// TokenIntrospector introspects access tokens.
type TokenIntrospector interface {
	// IntrospectToken introspects token and returns the introspection result.
	// When used by JWTAuthMiddleware, the idptoken.ErrTokenIntrospectionNotNeeded and
	// idptoken.ErrTokenNotIntrospectable errors make it fall back to parsing the token as JWT,
	// while any other error makes it reject the request.
	IntrospectToken(ctx context.Context, token string) (idptoken.IntrospectionResult, error)
}
// jwtAuthHandler is the http.Handler that JWTAuthMiddleware wraps around the next handler.
type jwtAuthHandler struct {
	// next is called once the request has been authenticated (and authorized, if configured).
	next http.Handler
	// errorDomain is put into every error response.
	errorDomain string
	// jwtParser parses the bearer token when introspection did not provide claims.
	jwtParser JWTParser
	// verifyAccess is an optional authorization check; nil disables it.
	verifyAccess func(r *http.Request, claims jwt.Claims) bool
	// tokenIntrospector is optional; nil disables token introspection.
	tokenIntrospector TokenIntrospector
	// loggerProvider returns the logger for a request context.
	loggerProvider func(ctx context.Context) log.FieldLogger
}

// jwtAuthMiddlewareOpts collects the settings applied by JWTAuthMiddlewareOption functions.
type jwtAuthMiddlewareOpts struct {
	// verifyAccess is set by WithJWTAuthMiddlewareVerifyAccess.
	verifyAccess func(r *http.Request, claims jwt.Claims) bool
	// tokenIntrospector is set by WithJWTAuthMiddlewareTokenIntrospector.
	tokenIntrospector TokenIntrospector
	// loggerProvider is set by WithJWTAuthMiddlewareLoggerProvider.
	loggerProvider func(ctx context.Context) log.FieldLogger
}

// JWTAuthMiddlewareOption configures JWTAuthMiddleware.
// Values are created by the WithJWTAuthMiddleware* functions.
type JWTAuthMiddlewareOption func(options *jwtAuthMiddlewareOpts)
// WithJWTAuthMiddlewareVerifyAccess returns an option that sets the function JWTAuthMiddleware uses
// to authorize an authenticated request. The function receives the request and the token's claims;
// when it returns false, the middleware responds with 403 Forbidden.
// Without this option, no access verification is performed.
func WithJWTAuthMiddlewareVerifyAccess(verifyAccess func(r *http.Request, claims jwt.Claims) bool) JWTAuthMiddlewareOption {
	setVerifyAccess := func(o *jwtAuthMiddlewareOpts) {
		o.verifyAccess = verifyAccess
	}
	return setVerifyAccess
}
// WithJWTAuthMiddlewareTokenIntrospector returns an option that sets the token introspector used by JWTAuthMiddleware.
// With an introspector configured, the bearer token is introspected before it is parsed as JWT:
// the claims of an active token are taken from the introspection result, and the token is parsed
// as JWT only when introspection reports that it is not needed or not possible.
func WithJWTAuthMiddlewareTokenIntrospector(tokenIntrospector TokenIntrospector) JWTAuthMiddlewareOption {
	setTokenIntrospector := func(o *jwtAuthMiddlewareOpts) {
		o.tokenIntrospector = tokenIntrospector
	}
	return setTokenIntrospector
}
// WithJWTAuthMiddlewareLoggerProvider returns an option that sets the function JWTAuthMiddleware calls
// with the request context to obtain its logger.
// Without this option, middleware.GetLoggerFromContext is used.
func WithJWTAuthMiddlewareLoggerProvider(loggerProvider func(ctx context.Context) log.FieldLogger) JWTAuthMiddlewareOption {
	setLoggerProvider := func(o *jwtAuthMiddlewareOpts) {
		o.loggerProvider = loggerProvider
	}
	return setLoggerProvider
}
// JWTAuthMiddleware is a middleware that does authentication
// by Access Token from the "Authorization" HTTP header of incoming request.
// errorDomain is used for error responses. It is usually the name of the service that uses the middleware,
// and its goal is distinguishing errors from different services.
// It helps to understand where the error occurred and what service caused it.
// For example, if the "Authorization" HTTP header is missing, the middleware will return 401 with the following response body:
//
// {"error": {"domain": "MyService", "code": "bearerTokenMissing", "message": "Authorization bearer token is missing."}}
func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthMiddlewareOption) func(next http.Handler) http.Handler {
options := jwtAuthMiddlewareOpts{loggerProvider: middleware.GetLoggerFromContext}
for _, opt := range opts {
opt(&options)
}
return func(next http.Handler) http.Handler {
return &jwtAuthHandler{
next: next,
errorDomain: errorDomain,
jwtParser: jwtParser,
verifyAccess: options.verifyAccess,
tokenIntrospector: options.tokenIntrospector,
loggerProvider: options.loggerProvider,
}
}
}
// ServeHTTP authenticates the incoming request by its bearer token, optionally authorizes it,
// and then passes it to the next handler.
//
// The request is processed as follows:
//   - The bearer token is taken from the "Authorization" header (see GetBearerTokenFromRequest).
//     If there is none, 401 Unauthorized with ErrCodeBearerTokenMissing is returned.
//   - If a token introspector is configured, the token is introspected first. For an active token,
//     the claims from the introspection result are used. An inactive token or an unexpected
//     introspection error yields 401 with ErrCodeAuthenticationFailed.
//     idptoken.ErrTokenIntrospectionNotNeeded and idptoken.ErrTokenNotIntrospectable make the
//     handler fall back to parsing the token as JWT.
//   - If introspection produced no claims, the token is parsed by the JWT parser;
//     a parsing error yields 401 with ErrCodeAuthenticationFailed.
//   - If verifyAccess is configured and returns false, 403 Forbidden with ErrCodeAuthorizationFailed is returned.
//   - Otherwise, the bearer token and the claims are stored in the request context
//     (see NewContextWithBearerToken and NewContextWithJWTClaims) and the next handler is called.
//
// Every 401 response carries a "WWW-Authenticate: Bearer" challenge header: RFC 7235 (Section 3.1)
// requires a challenge on 401 responses, and RFC 6750 (Section 3) defines "Bearer" as the challenge
// for bearer-token protected resources.
func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	reqCtx := r.Context()

	bearerToken := GetBearerTokenFromRequest(r)
	if bearerToken == "" {
		h.respondUnauthorized(rw, h.logger(reqCtx), ErrCodeBearerTokenMissing, ErrMessageBearerTokenMissing)
		return
	}

	var jwtClaims jwt.Claims
	if h.tokenIntrospector != nil {
		introspectionResult, err := h.tokenIntrospector.IntrospectToken(reqCtx, bearerToken)
		switch {
		case err == nil:
			if !introspectionResult.IsActive() {
				logger := h.logger(reqCtx)
				logger.Warn("token was successfully introspected, but it is not active")
				h.respondUnauthorized(rw, logger, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
				return
			}
			jwtClaims = introspectionResult.GetClaims()
			h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
				logFunc("token was successfully introspected")
			})
		case errors.Is(err, idptoken.ErrTokenIntrospectionNotNeeded):
			// The access token itself carries everything needed for authN/authZ,
			// so it is parsed as JWT below.
			h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
				logFunc("token's introspection is not needed")
			})
		case errors.Is(err, idptoken.ErrTokenNotIntrospectable):
			// The token cannot be introspected for some reason;
			// it is parsed as JWT below and used for authZ as is.
			h.logger(reqCtx).Warn("token is not introspectable, it will be used for authentication and authorization as is",
				log.Error(err))
		default:
			logger := h.logger(reqCtx)
			logger.Error("token's introspection failed", log.Error(err))
			h.respondUnauthorized(rw, logger, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
			return
		}
	}

	// No claims from introspection (not configured, not needed, not possible, or none returned):
	// parse the token as JWT.
	if jwtClaims == nil {
		var err error
		if jwtClaims, err = h.jwtParser.Parse(reqCtx, bearerToken); err != nil {
			logger := h.logger(reqCtx)
			logger.Error("authentication failed", log.Error(err))
			h.respondUnauthorized(rw, logger, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
			return
		}
	}

	if h.verifyAccess != nil && !h.verifyAccess(r, jwtClaims) {
		apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthorizationFailed, ErrMessageAuthorizationFailed)
		restapi.RespondError(rw, http.StatusForbidden, apiErr, h.logger(reqCtx))
		return
	}

	reqCtx = NewContextWithBearerToken(reqCtx, bearerToken)
	reqCtx = NewContextWithJWTClaims(reqCtx, jwtClaims)
	h.next.ServeHTTP(rw, r.WithContext(reqCtx))
}

// respondUnauthorized writes a 401 Unauthorized error response with the given error code and message.
// The "WWW-Authenticate: Bearer" challenge header is set before the response is written, because
// RFC 7235 (Section 3.1) requires it on every 401 response.
func (h *jwtAuthHandler) respondUnauthorized(rw http.ResponseWriter, logger log.FieldLogger, errCode, errMessage string) {
	rw.Header().Set("WWW-Authenticate", "Bearer")
	apiErr := restapi.NewError(h.errorDomain, errCode, errMessage)
	restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger)
}
// logger returns the logger for the given request context, resolved from the handler's
// logger provider via idputil.GetLoggerFromProvider.
func (h *jwtAuthHandler) logger(ctx context.Context) log.FieldLogger {
	provider := h.loggerProvider
	return idputil.GetLoggerFromProvider(ctx, provider)
}
// GetBearerTokenFromRequest extracts jwt token from request headers.
func GetBearerTokenFromRequest(r *http.Request) string {
authHeader := strings.TrimSpace(r.Header.Get(HeaderAuthorization))
if strings.HasPrefix(authHeader, "Bearer ") || strings.HasPrefix(authHeader, "bearer ") {
return authHeader[7:]
}
return ""
}
// NewContextWithJWTClaims returns a copy of ctx that carries the given JWT claims.
// The claims can be retrieved later with GetJWTClaimsFromContext.
func NewContextWithJWTClaims(ctx context.Context, jwtClaims jwt.Claims) context.Context {
	ctxWithClaims := context.WithValue(ctx, ctxKeyJWTClaims, jwtClaims)
	return ctxWithClaims
}

// GetJWTClaimsFromContext returns the JWT claims stored in ctx by NewContextWithJWTClaims
// (JWTAuthMiddleware stores them for every request it lets through), or nil if there are none.
func GetJWTClaimsFromContext(ctx context.Context) jwt.Claims {
	if value := ctx.Value(ctxKeyJWTClaims); value != nil {
		return value.(jwt.Claims)
	}
	return nil
}
// NewContextWithBearerToken returns a copy of ctx that carries the given bearer token.
// The token can be retrieved later with GetBearerTokenFromContext.
func NewContextWithBearerToken(ctx context.Context, token string) context.Context {
	ctxWithToken := context.WithValue(ctx, ctxKeyBearerToken, token)
	return ctxWithToken
}

// GetBearerTokenFromContext returns the bearer token stored in ctx by NewContextWithBearerToken
// (JWTAuthMiddleware stores it for every request it lets through), or an empty string if there is none.
func GetBearerTokenFromContext(ctx context.Context) string {
	if value := ctx.Value(ctxKeyBearerToken); value != nil {
		return value.(string)
	}
	return ""
}