-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
355 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
package idp | ||
|
||
type IdentityProviderUserInfo struct { | ||
Identifier string | ||
Email string | ||
DisplayName string | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
// Package oauth2 is the plugin for OAuth2 Identity Provider. | ||
package oauth2 | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
|
||
"github.com/pkg/errors" | ||
"golang.org/x/oauth2" | ||
|
||
"github.com/yourselfhosted/slash/plugin/idp" | ||
storepb "github.com/yourselfhosted/slash/proto/gen/store" | ||
) | ||
|
||
// IdentityProvider represents an OAuth2 Identity Provider. | ||
type IdentityProvider struct { | ||
config *storepb.IdentityProviderConfig_OAuth2Config | ||
} | ||
|
||
// NewIdentityProvider initializes a new OAuth2 Identity Provider with the given configuration. | ||
func NewIdentityProvider(config *storepb.IdentityProviderConfig_OAuth2Config) (*IdentityProvider, error) { | ||
for v, field := range map[string]string{ | ||
config.ClientId: "clientId", | ||
config.ClientSecret: "clientSecret", | ||
config.TokenUrl: "tokenUrl", | ||
config.UserInfoUrl: "userInfoUrl", | ||
config.FieldMapping.Identifier: "fieldMapping.identifier", | ||
} { | ||
if v == "" { | ||
return nil, errors.Errorf(`the field "%s" is empty but required`, field) | ||
} | ||
} | ||
|
||
return &IdentityProvider{ | ||
config: config, | ||
}, nil | ||
} | ||
|
||
// ExchangeToken returns the exchanged OAuth2 token using the given authorization code. | ||
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) { | ||
conf := &oauth2.Config{ | ||
ClientID: p.config.ClientId, | ||
ClientSecret: p.config.ClientSecret, | ||
RedirectURL: redirectURL, | ||
Scopes: p.config.Scopes, | ||
Endpoint: oauth2.Endpoint{ | ||
AuthURL: p.config.AuthUrl, | ||
TokenURL: p.config.TokenUrl, | ||
AuthStyle: oauth2.AuthStyleInParams, | ||
}, | ||
} | ||
|
||
token, err := conf.Exchange(ctx, code) | ||
if err != nil { | ||
return "", errors.Wrap(err, "failed to exchange access token") | ||
} | ||
|
||
accessToken, ok := token.Extra("access_token").(string) | ||
if !ok { | ||
return "", errors.New(`missing "access_token" from authorization response`) | ||
} | ||
|
||
return accessToken, nil | ||
} | ||
|
||
// UserInfo returns the parsed user information using the given OAuth2 token. | ||
func (p *IdentityProvider) UserInfo(token string) (*idp.IdentityProviderUserInfo, error) { | ||
client := &http.Client{} | ||
req, err := http.NewRequest(http.MethodGet, p.config.UserInfoUrl, nil) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "failed to new http request") | ||
} | ||
req.Header.Set("Content-Type", "application/json") | ||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) | ||
resp, err := client.Do(req) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "failed to get user information") | ||
} | ||
body, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "failed to read response body") | ||
} | ||
defer resp.Body.Close() | ||
|
||
var claims map[string]any | ||
err = json.Unmarshal(body, &claims) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "failed to unmarshal response body") | ||
} | ||
|
||
userInfo := &idp.IdentityProviderUserInfo{} | ||
if v, ok := claims[p.config.FieldMapping.Identifier].(string); ok { | ||
userInfo.Identifier = v | ||
} | ||
if userInfo.Identifier == "" { | ||
return nil, errors.Errorf("the field %q is not found in claims or has empty value", p.config.FieldMapping.Identifier) | ||
} | ||
|
||
// Best effort to map optional fields | ||
if p.config.FieldMapping.DisplayName != "" { | ||
if v, ok := claims[p.config.FieldMapping.DisplayName].(string); ok { | ||
userInfo.DisplayName = v | ||
} | ||
} | ||
if userInfo.DisplayName == "" { | ||
userInfo.DisplayName = userInfo.Identifier | ||
} | ||
if p.config.FieldMapping.Email != "" { | ||
if v, ok := claims[p.config.FieldMapping.Email].(string); ok { | ||
userInfo.Email = v | ||
} | ||
} | ||
return userInfo, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
package oauth2 | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/yourselfhosted/slash/plugin/idp" | ||
storepb "github.com/yourselfhosted/slash/proto/gen/store" | ||
) | ||
|
||
func TestNewIdentityProvider(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
config *storepb.IdentityProviderConfig_OAuth2Config | ||
containsErr string | ||
}{ | ||
{ | ||
name: "no tokenUrl", | ||
config: &storepb.IdentityProviderConfig_OAuth2Config{ | ||
ClientId: "test-client-id", | ||
ClientSecret: "test-client-secret", | ||
AuthUrl: "", | ||
TokenUrl: "", | ||
UserInfoUrl: "https://example.com/api/user", | ||
FieldMapping: &storepb.IdentityProviderConfig_FieldMapping{ | ||
Identifier: "login", | ||
}, | ||
}, | ||
containsErr: `the field "tokenUrl" is empty but required`, | ||
}, | ||
{ | ||
name: "no userInfoUrl", | ||
config: &storepb.IdentityProviderConfig_OAuth2Config{ | ||
ClientId: "test-client-id", | ||
ClientSecret: "test-client-secret", | ||
AuthUrl: "", | ||
TokenUrl: "https://example.com/token", | ||
UserInfoUrl: "", | ||
FieldMapping: &storepb.IdentityProviderConfig_FieldMapping{ | ||
Identifier: "login", | ||
}, | ||
}, | ||
containsErr: `the field "userInfoUrl" is empty but required`, | ||
}, | ||
{ | ||
name: "no field mapping identifier", | ||
config: &storepb.IdentityProviderConfig_OAuth2Config{ | ||
ClientId: "test-client-id", | ||
ClientSecret: "test-client-secret", | ||
AuthUrl: "", | ||
TokenUrl: "https://example.com/token", | ||
UserInfoUrl: "https://example.com/api/user", | ||
FieldMapping: &storepb.IdentityProviderConfig_FieldMapping{ | ||
Identifier: "", | ||
}, | ||
}, | ||
containsErr: `the field "fieldMapping.identifier" is empty but required`, | ||
}, | ||
} | ||
for _, test := range tests { | ||
t.Run(test.name, func(t *testing.T) { | ||
_, err := NewIdentityProvider(test.config) | ||
assert.ErrorContains(t, err, test.containsErr) | ||
}) | ||
} | ||
} | ||
|
||
func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server { | ||
mux := http.NewServeMux() | ||
|
||
var rawIDToken string | ||
mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) { | ||
require.Equal(t, http.MethodPost, r.Method) | ||
|
||
body, err := io.ReadAll(r.Body) | ||
require.NoError(t, err) | ||
vals, err := url.ParseQuery(string(body)) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, code, vals.Get("code")) | ||
require.Equal(t, "authorization_code", vals.Get("grant_type")) | ||
|
||
w.Header().Set("Content-Type", "application/json") | ||
err = json.NewEncoder(w).Encode(map[string]any{ | ||
"access_token": accessToken, | ||
"token_type": "Bearer", | ||
"expires_in": 3600, | ||
"id_token": rawIDToken, | ||
}) | ||
require.NoError(t, err) | ||
}) | ||
mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, r *http.Request) { | ||
w.Header().Set("Content-Type", "application/json") | ||
_, err := w.Write(userinfo) | ||
require.NoError(t, err) | ||
}) | ||
|
||
s := httptest.NewServer(mux) | ||
|
||
return s | ||
} | ||
|
||
func TestIdentityProvider(t *testing.T) { | ||
ctx := context.Background() | ||
|
||
const ( | ||
testClientID = "test-client-id" | ||
testCode = "test-code" | ||
testAccessToken = "test-access-token" | ||
testSubject = "123456789" | ||
testName = "John Doe" | ||
testEmail = "[email protected]" | ||
) | ||
userInfo, err := json.Marshal( | ||
map[string]any{ | ||
"sub": testSubject, | ||
"name": testName, | ||
"email": testEmail, | ||
}, | ||
) | ||
require.NoError(t, err) | ||
|
||
s := newMockServer(t, testCode, testAccessToken, userInfo) | ||
|
||
oauth2, err := NewIdentityProvider( | ||
&storepb.IdentityProviderConfig_OAuth2Config{ | ||
ClientId: testClientID, | ||
ClientSecret: "test-client-secret", | ||
TokenUrl: fmt.Sprintf("%s/oauth2/token", s.URL), | ||
UserInfoUrl: fmt.Sprintf("%s/oauth2/userinfo", s.URL), | ||
FieldMapping: &storepb.IdentityProviderConfig_FieldMapping{ | ||
Identifier: "sub", | ||
DisplayName: "name", | ||
Email: "email", | ||
}, | ||
}, | ||
) | ||
require.NoError(t, err) | ||
|
||
redirectURL := "https://example.com/oauth/callback" | ||
oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode) | ||
require.NoError(t, err) | ||
require.Equal(t, testAccessToken, oauthToken) | ||
|
||
userInfoResult, err := oauth2.UserInfo(oauthToken) | ||
require.NoError(t, err) | ||
|
||
wantUserInfo := &idp.IdentityProviderUserInfo{ | ||
Identifier: testSubject, | ||
DisplayName: testName, | ||
Email: testEmail, | ||
} | ||
assert.Equal(t, wantUserInfo, userInfoResult) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.