Skip to content

Commit

Permalink
IAM: fix OAuth2 error handling (nuts-foundation#3004)
Browse files Browse the repository at this point in the history
* IAM: fix user-page error rendering

* Render actual HTML page
  • Loading branch information
reinkrul authored Apr 18, 2024
1 parent 79f4c38 commit d1b26b2
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 24 deletions.
35 changes: 28 additions & 7 deletions auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package iam
import (
"context"
crypto2 "crypto"
"embed"
"encoding/json"
"errors"
"fmt"
Expand All @@ -30,6 +31,7 @@ import (
"github.com/nuts-foundation/go-did/vc"
"github.com/nuts-foundation/nuts-node/audit"
"github.com/nuts-foundation/nuts-node/auth"
"github.com/nuts-foundation/nuts-node/auth/api/iam/assets"
"github.com/nuts-foundation/nuts-node/auth/log"
"github.com/nuts-foundation/nuts-node/auth/oauth"
"github.com/nuts-foundation/nuts-node/core"
Expand All @@ -43,6 +45,7 @@ import (
"github.com/nuts-foundation/nuts-node/vdr"
"github.com/nuts-foundation/nuts-node/vdr/didweb"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"html/template"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -73,6 +76,9 @@ const oid4vciSessionValidity = 15 * time.Minute
// - https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies
const userSessionCookieName = "__Host-SID"

//go:embed assets
var assetsFS embed.FS

// Wrapper handles OAuth2 flows.
type Wrapper struct {
auth auth.AuthenticationServices
Expand All @@ -84,6 +90,11 @@ type Wrapper struct {
}

func New(authInstance auth.AuthenticationServices, vcrInstance vcr.VCR, vdrInstance vdr.VDR, storageEngine storage.Engine, policyBackend policy.PDPBackend) *Wrapper {
templates := template.New("oauth2 templates")
_, err := templates.ParseFS(assetsFS, "assets/*.html")
if err != nil {
panic(err)
}
return &Wrapper{
auth: authInstance,
policyBackend: policyBackend,
Expand All @@ -97,29 +108,39 @@ func (r Wrapper) Routes(router core.EchoRouter) {
RegisterHandlers(router, NewStrictHandler(r, []StrictMiddlewareFunc{
func(f StrictHandlerFunc, operationID string) StrictHandlerFunc {
return func(ctx echo.Context, request interface{}) (response interface{}, err error) {
return r.middleware(ctx, request, operationID, f)
return r.strictMiddleware(ctx, request, operationID, f)
}
},
func(f StrictHandlerFunc, operationID string) StrictHandlerFunc {
return audit.StrictMiddleware(f, apiModuleName, operationID)
},
}))
// The following handlers are used for the user facing OAuth2 flows.
router.GET("/oauth2/:did/user", r.handleUserLanding, audit.Middleware(apiModuleName))
router.GET("/oauth2/:did/user", r.handleUserLanding, func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
middleware(c, "handleUserLanding")
return next(c)
}
}, audit.Middleware(apiModuleName))
}

func (r Wrapper) middleware(ctx echo.Context, request interface{}, operationID string, f StrictHandlerFunc) (interface{}, error) {
func (r Wrapper) strictMiddleware(ctx echo.Context, request interface{}, operationID string, f StrictHandlerFunc) (interface{}, error) {
middleware(ctx, operationID)
return f(ctx, request)
}

func middleware(ctx echo.Context, operationID string) {
ctx.Set(core.OperationIDContextKey, operationID)
ctx.Set(core.ModuleNameContextKey, apiModuleName)

// Add http.Request to context, to allow reading URL query parameters
requestCtx := context.WithValue(ctx.Request().Context(), httpRequestContextKey, ctx.Request())
ctx.SetRequest(ctx.Request().WithContext(requestCtx))
if strings.HasPrefix(ctx.Request().URL.Path, "/iam/") {
ctx.Set(core.ErrorWriterContextKey, &oauth.Oauth2ErrorWriter{})
if strings.HasPrefix(ctx.Request().URL.Path, "/oauth2/") {
ctx.Set(core.ErrorWriterContextKey, &oauth.Oauth2ErrorWriter{
HtmlPageTemplate: assets.ErrorTemplate,
})
}

return f(ctx, request)
}

// ResolveStatusCode maps errors returned by this API to specific HTTP status codes.
Expand Down
6 changes: 3 additions & 3 deletions auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -767,14 +767,14 @@ func TestWrapper_middleware(t *testing.T) {
t.Run("OAuth2 error handling", func(t *testing.T) {
var handler strictServerCallCapturer
t.Run("OAuth2 path", func(t *testing.T) {
ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder())
_, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", handler.handle)
ctx := server.NewContext(httptest.NewRequest("GET", "/oauth2/foo", nil), httptest.NewRecorder())
_, _ = Wrapper{auth: authService}.strictMiddleware(ctx, nil, "Test", handler.handle)

assert.IsType(t, &oauth.Oauth2ErrorWriter{}, ctx.Get(core.ErrorWriterContextKey))
})
t.Run("other path", func(t *testing.T) {
ctx := server.NewContext(httptest.NewRequest("GET", "/internal/foo", nil), httptest.NewRecorder())
_, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", handler.handle)
_, _ = Wrapper{auth: authService}.strictMiddleware(ctx, nil, "Test", handler.handle)

assert.Nil(t, ctx.Get(core.ErrorWriterContextKey))
})
Expand Down
14 changes: 14 additions & 0 deletions auth/api/iam/assets/error.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Authorization error</title>
</head>
<body>
<h1>Authorization error</h1>
<p>Couldn't fulfill the request. You can try again, or contact your service provider if the error persists.</p>
<p>You should provide them with the error details shown below.</p>
<pre style="margin: 5px; padding: 10px; background-color: #eff1f5">
{{ if .Code }}Code: {{ .Code }}{{ end }}
{{ if .Description }}Description: {{ .Description }}{{ end }}</pre>
</body>
</html>
35 changes: 35 additions & 0 deletions auth/api/iam/assets/templates.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (C) 2024 Nuts community
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/

package assets

import (
"embed"
"html/template"
)

//go:embed *.html
var assets embed.FS

// ErrorTemplate is the template used to render error pages.
var ErrorTemplate *template.Template

func init() {
templates := template.Must(template.ParseFS(assets, "*.html"))
ErrorTemplate = templates.Lookup("error.html")
}
30 changes: 30 additions & 0 deletions auth/api/iam/assets/templates_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright (C) 2024 Nuts community
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/

package assets

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestTemplateLoading(t *testing.T) {
// This test is here to make sure the templates are loaded correctly.
// It doesn't test the actual content of the templates.
assert.NotNil(t, ErrorTemplate)
}
15 changes: 14 additions & 1 deletion auth/oauth/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
package oauth

import (
"bytes"
"encoding/json"
"errors"
"github.com/labstack/echo/v4"
"github.com/nuts-foundation/nuts-node/auth/log"
"github.com/nuts-foundation/nuts-node/core"
"html/template"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -96,7 +99,9 @@ func (e OAuth2Error) Error() string {
}

// Oauth2ErrorWriter is a HTTP response writer for OAuth errors
type Oauth2ErrorWriter struct{}
type Oauth2ErrorWriter struct {
HtmlPageTemplate *template.Template
}

func (p Oauth2ErrorWriter) Write(echoContext echo.Context, _ int, _ string, err error) error {
var oauthErr OAuth2Error
Expand All @@ -123,6 +128,14 @@ func (p Oauth2ErrorWriter) Write(echoContext echo.Context, _ int, _ string, err
// Return JSON response
return echoContext.JSON(oauthErr.StatusCode(), oauthErr)
}
if p.HtmlPageTemplate != nil {
buf := new(bytes.Buffer)
err = p.HtmlPageTemplate.Execute(buf, oauthErr)
if err != nil {
log.Logger().WithError(err).Warnf("unable to render error page for error: %s", oauthErr.Error())
}
return echoContext.HTMLBlob(oauthErr.StatusCode(), buf.Bytes())
}
// Return plain text response
parts := []string{string(oauthErr.Code)}
if oauthErr.Description != "" {
Expand Down
49 changes: 36 additions & 13 deletions auth/oauth/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/nuts-foundation/nuts-node/test"
"github.com/stretchr/testify/assert"
"html/template"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -57,22 +58,44 @@ func Test_oauth2ErrorWriter_Write(t *testing.T) {
assert.Equal(t, "https://example.com?error=invalid_request&error_description=failure", rec.Header().Get("Location"))
})
t.Run("user-agent is browser without redirect URI", func(t *testing.T) {
server := echo.New()
httpRequest := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
ctx := server.NewContext(httpRequest, rec)
t.Run("text/html (from template)", func(t *testing.T) {
server := echo.New()
httpRequest := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
ctx := server.NewContext(httpRequest, rec)

err := Oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{
Code: InvalidRequest,
Description: "failure",
err := Oauth2ErrorWriter{
HtmlPageTemplate: template.Must(template.New("error").Parse("Error: {{.Code}} - {{.Description}}")),
}.Write(ctx, 0, "", OAuth2Error{
Code: InvalidRequest,
Description: "failure",
})

assert.NoError(t, err)
body, _ := io.ReadAll(rec.Body)
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "text/html; charset=UTF-8", rec.Header().Get("Content-Type"))
assert.Equal(t, "Error: invalid_request - failure", string(body))
assert.Empty(t, rec.Header().Get("Location"))
})
t.Run("text/plain (no template)", func(t *testing.T) {
server := echo.New()
httpRequest := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
ctx := server.NewContext(httpRequest, rec)

assert.NoError(t, err)
body, _ := io.ReadAll(rec.Body)
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get("Content-Type"))
assert.Equal(t, "invalid_request - failure", string(body))
assert.Empty(t, rec.Header().Get("Location"))
err := Oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{
Code: InvalidRequest,
Description: "failure",
})

assert.NoError(t, err)
body, _ := io.ReadAll(rec.Body)
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get("Content-Type"))
assert.Equal(t, "invalid_request - failure", string(body))
assert.Empty(t, rec.Header().Get("Location"))
})
})
t.Run("user-agent is API client (sent JSON)", func(t *testing.T) {
server := echo.New()
Expand Down

0 comments on commit d1b26b2

Please sign in to comment.