diff --git a/go-shared/go.mod b/go-shared/go.mod
index 11aaed9a8..93ac5df5c 100644
--- a/go-shared/go.mod
+++ b/go-shared/go.mod
@@ -1,5 +1,5 @@
module github.com/broadinstitute/sherlock/go-shared
-go 1.21
+go 1.22
require github.com/google/go-cmp v0.6.0
diff --git a/sherlock-go-client/go.mod b/sherlock-go-client/go.mod
index 11ef6f1ef..50435f62a 100644
--- a/sherlock-go-client/go.mod
+++ b/sherlock-go-client/go.mod
@@ -1,6 +1,6 @@
module github.com/broadinstitute/sherlock/sherlock-go-client
-go 1.21
+go 1.22
require (
github.com/go-openapi/errors v0.22.0
diff --git a/sherlock-webhook-proxy/go.mod b/sherlock-webhook-proxy/go.mod
index a49047dd8..58fb84bad 100644
--- a/sherlock-webhook-proxy/go.mod
+++ b/sherlock-webhook-proxy/go.mod
@@ -1,6 +1,6 @@
module github.com/broadinstitute/sherlock/sherlock-webhook-proxy
-go 1.21
+go 1.22
require (
github.com/GoogleCloudPlatform/functions-framework-go v1.8.1
diff --git a/sherlock/Dockerfile b/sherlock/Dockerfile
index 87a81e440..4da88328d 100644
--- a/sherlock/Dockerfile
+++ b/sherlock/Dockerfile
@@ -1,11 +1,10 @@
-ARG GO_VERSION='1.21'
-ARG ALPINE_VERSION='3.19'
+ARG GO_VERSION='1.22'
-FROM golang:${GO_VERSION}-alpine${ALPINE_VERSION} AS build
+# https://github.com/microsoft/go-images/tree/microsoft/main
+FROM mcr.microsoft.com/oss/go/microsoft/golang:${GO_VERSION}-fips-cbl-mariner2.0 AS build
ARG BUILD_VERSION='development'
WORKDIR /build/sherlock
-ENV CGO_ENABLED=0
-ENV GOBIN=/bin
+ENV CGO_ENABLED=1
COPY sherlock/go.mod sherlock/go.sum ./
COPY go-shared ../go-shared/
@@ -14,8 +13,13 @@ RUN go mod download && go mod verify
COPY sherlock ./
RUN go build -buildvcs=false -ldflags="-X 'github.com/broadinstitute/sherlock/go-shared/pkg/version.BuildVersion=${BUILD_VERSION}'" -o /bin/sherlock ./cmd/...
-# FROM alpine:${ALPINE_VERSION} as runtime <-- use this if you hit issues
-FROM gcr.io/distroless/static:nonroot AS runtime
+# Check that the binary is FIPS-capable, fail if not
+RUN go get github.com/acardace/fips-detect && \
+ go run github.com/acardace/fips-detect /bin/sherlock \
+ | grep -E 'FIPS-capable Go binary.*Yes'
+
+# https://mcr.microsoft.com/en-us/product/cbl-mariner/distroless/minimal/about
+FROM mcr.microsoft.com/cbl-mariner/distroless/base:2.0-nonroot AS runtime
COPY --from=build /bin/sherlock /bin/sherlock
ENTRYPOINT [ "/bin/sherlock" ]
diff --git a/sherlock/config/default_config.yaml b/sherlock/config/default_config.yaml
index def650945..fd6b1a9d5 100644
--- a/sherlock/config/default_config.yaml
+++ b/sherlock/config/default_config.yaml
@@ -69,6 +69,56 @@ db:
ignoreNotFoundWarning: true
level: warn
+# Configures Sherlock's own OIDC provider, not to be confused with its capability to interpret tokens
+# from IAP or GitHub Actions.
+oidc:
+ enable: true
+ # The issuer URL of Sherlock itself. This should be scheme + host + "/oidc", because Sherlock
+ # serves its own OIDC provider at that sub-path. This should also generally be an in-cluster
+ # address, because the IAP in front of the public endpoints isn't in-spec. This is fine as long
+ # as we just use Sherlock as an in-cluster authorization service.
+ #
+ # An example is https://sherlock-api-service.sherlock.svc.cluster.local/oidc
+ issuerUrl: http://localhost:8080/oidc
+ # The *public* side of Sherlock's OIDC issuer. This should be a normally-accessible URL that should
+ # go to the same destination as issuerUrl above. This is automatically used in the OIDC discovery
+ # config to tell clients how to have *users* authenticate against Sherlock. The downstream system
+ # will talk to issuerUrl in-cluster, but end-users will need to talk to publicIssuerUrl.
+ #
+ # An example is https://sherlock.dsp-devops-prod.broadinstitute.org/oidc
+ publicIssuerUrl: http://localhost:8080/oidc
+
+ # The key that Sherlock should use to AES-256 encrypt internal data it sends to clients. This is
+ # used in two places by the underlying OIDC library:
+ # 1. Encrypting "{Token.ID}:{User.ID}" to create access tokens returned to clients
+ # 2. Encrypting "{AuthRequest.ID}" to create authorization codes returned to clients
+ # This does need to be rotated but doing so is potentially disruptive; Sherlock will cease
+ # respecting access tokens or authorization codes it has issued.
+ #
+ # Sherlock will error on boot if this doesn't parse from a hex string to 32 bytes. You'll probably
+ # want to pass this in the environment with SHERLOCK_oidc_encryptionKeyHex. It should be passed
+ # in hex format.
+ encryptionKeyHex: 7265706c6163652d6d652d776974682d33322d627974652d6b65792d2d2d2d2d # "replace-me-with-32-byte-key----"
+ # The duration that ID and access tokens vended to clients should be valid for.
+ tokenDuration: 15m
+ # The duration that refresh tokens vended to clients should be valid for.
+ refreshTokenDuration: 30m
+ # The duration that a particular signing key should be used before being rotated.
+ signingKeyPrimaryDuration: 4h
+ # The time after which a signing key should be deleted (and its signatures no longer accepted)
+ # after it has been rotated. This should be longer than all token durations so that we
+ # continue to respect our own signatures until they'd expire on their own.
+ signingKeyPostRotationDuration: 2h
+ # When enabled, Sherlock will use Google Cloud KMS to symmetrically encrypt the private keys
+ # it stores in its own database. This is a defense-in-depth measure to prevent key leakage in
+ # the event of SQL injection or other database compromise.
+ #
+ # This must be true when mode is not "debug".
+ signingKeyEncryptionKMSEnable: false
+ # The fully-qualified name of the KMS key to use when signingKeyEncryptionKMSEnable is true.
+ signingKeyEncryptionKMSKeyName: projects/some-project/locations/some-location/keyRings/some-key-ring/cryptoKeys/some-key
+
+
auth:
githubActionsOIDC:
issuer: https://token.actions.githubusercontent.com
@@ -217,6 +267,18 @@ argoCd:
environmentUrlFormatString: https://argocd.dsp-devops-prod.broadinstitute.org/applications?showFavorites=false&proj=&sync=&health=&namespace=&cluster=&labels=env%%253D%s
model:
+ roles:
+ # When set, Sherlock won't ever report an Environment/Cluster RequiredRole field as being null.
+ # Instead, it will substitute this value in its place (even though it won't be stored in the database).
+ # This can be useful in that it means downstream consumers don't need null handling like
+ # `requiredRole ?? "all-users"`. While that's simple, it's actually easier at a security/compliance
+ # level to say that Sherlock defines it and anything else uses it verbatim. (This was a specific
+ # request from appsec for this reason`.)
+ #
+ # Even if this is set, Sherlock will allow setting the field to empty to clear it out -- but will then
+ # respond in the API as if it's been set to this value. Note that the role set here needs to already
+ # exist or downstream consumers could have issues.
+ substituteEmptyRequiredRoleWithValue:
environments:
templates:
# Uses appVersionResolver = "none", chartVersionResolver = "latest", and helmfileRef = "HEAD"
diff --git a/sherlock/config/test_config.yaml b/sherlock/config/test_config.yaml
index 2efb17ba6..7e74fb878 100644
--- a/sherlock/config/test_config.yaml
+++ b/sherlock/config/test_config.yaml
@@ -44,6 +44,8 @@ pactbroker:
enable: false
model:
+ roles:
+ substituteEmptyRequiredRoleWithValue: all-users
environments:
templates:
autoPopulateCharts:
diff --git a/sherlock/db/migrations/000095_oidc.down.sql b/sherlock/db/migrations/000095_oidc.down.sql
new file mode 100644
index 000000000..622c7e409
--- /dev/null
+++ b/sherlock/db/migrations/000095_oidc.down.sql
@@ -0,0 +1,11 @@
+drop table signing_keys;
+
+drop table tokens;
+
+drop table refresh_tokens;
+
+drop table auth_request_codes;
+
+drop table auth_requests;
+
+drop table clients;
diff --git a/sherlock/db/migrations/000095_oidc.up.sql b/sherlock/db/migrations/000095_oidc.up.sql
new file mode 100644
index 000000000..9b8cdd1d6
--- /dev/null
+++ b/sherlock/db/migrations/000095_oidc.up.sql
@@ -0,0 +1,98 @@
+create table clients
+(
+ id text not null
+ primary key,
+ client_secret_hash bytea,
+ client_secret_salt bytea,
+ client_secret_iterations bigint,
+ client_redirect_uris text,
+ client_post_logout_redirect_uris text,
+ client_application_type text,
+ client_auth_method text,
+ client_id_token_lifetime bigint,
+ client_dev_mode boolean,
+ client_clock_skew bigint
+);
+
+create table auth_requests
+(
+ id text not null
+ primary key,
+ created_at timestamp with time zone,
+ done_at timestamp with time zone,
+ client_id text
+ constraint fk_auth_requests_client
+ references clients
+ on update cascade on delete cascade,
+ nonce text,
+ redirect_uri text,
+ response_type text,
+ response_mode text,
+ scopes text,
+ state text,
+ code_challenge text,
+ code_challenge_method text,
+ user_id bigint
+ constraint fk_auth_requests_user
+ references users
+ on update cascade on delete cascade
+);
+
+create table auth_request_codes
+(
+ code text not null
+ primary key,
+ created_at timestamp with time zone,
+ auth_request_id text
+ constraint fk_auth_request_codes_auth_request
+ references auth_requests
+ on update cascade on delete cascade
+);
+
+create table refresh_tokens
+(
+ id text not null
+ primary key,
+ created_at timestamp with time zone,
+ token_hash bytea unique,
+ client_id text
+ constraint fk_refresh_tokens_client
+ references clients
+ on update cascade on delete cascade,
+ scopes text,
+ original_auth_at timestamp with time zone,
+ user_id bigint
+ constraint fk_refresh_tokens_user
+ references users
+ on update cascade on delete cascade
+);
+
+create table tokens
+(
+ id text not null
+ primary key,
+ created_at timestamp with time zone,
+ refresh_token_id text
+ constraint fk_tokens_refresh_token
+ references refresh_tokens
+ on update cascade on delete cascade,
+ client_id text
+ constraint fk_tokens_client
+ references clients
+ on update cascade on delete cascade,
+ scopes text,
+ expiry timestamp with time zone,
+ user_id bigint
+ constraint fk_tokens_user
+ references users
+ on update cascade on delete cascade
+);
+
+create table signing_keys
+(
+ id text not null
+ primary key,
+ created_at timestamp with time zone,
+ public_key bytea,
+ private_key bytea
+);
diff --git a/sherlock/go.mod b/sherlock/go.mod
index ec81b6d3f..0f2e0020f 100644
--- a/sherlock/go.mod
+++ b/sherlock/go.mod
@@ -1,9 +1,12 @@
module github.com/broadinstitute/sherlock/sherlock
-go 1.21
+go 1.22
+
+toolchain go1.22.5
require (
cloud.google.com/go/cloudsqlconn v1.11.1
+ cloud.google.com/go/kms v1.18.3
contrib.go.opencensus.io/exporter/prometheus v0.4.2
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
github.com/PagerDuty/go-pagerduty v1.8.0
@@ -13,6 +16,7 @@ require (
github.com/dustinkirkland/golang-petname v0.0.0-20230626224747-e794b9370d49
github.com/gin-contrib/cors v1.7.2
github.com/gin-gonic/gin v1.10.0
+ github.com/go-jose/go-jose/v4 v4.0.2
github.com/golang-migrate/migrate/v4 v4.17.1
github.com/google/go-cmp v0.6.0
github.com/google/go-github/v58 v58.0.0
@@ -29,10 +33,13 @@ require (
github.com/swaggo/files v1.0.1
github.com/swaggo/gin-swagger v1.6.0
github.com/swaggo/swag v1.16.3
+ github.com/zitadel/oidc/v3 v3.26.0
go.opencensus.io v0.24.0
+ golang.org/x/crypto v0.25.0
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
golang.org/x/net v0.27.0
golang.org/x/oauth2 v0.21.0
+ golang.org/x/text v0.16.0
google.golang.org/api v0.188.0
gorm.io/datatypes v1.2.1
gorm.io/driver/postgres v1.5.9
@@ -42,15 +49,19 @@ require (
replace github.com/broadinstitute/sherlock/go-shared => ../go-shared
require (
+ cloud.google.com/go v0.115.0 // indirect
cloud.google.com/go/auth v0.7.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.4.0 // indirect
+ cloud.google.com/go/iam v1.1.10 // indirect
+ cloud.google.com/go/longrunning v0.5.9 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.12.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.9.0 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
+ github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
@@ -62,9 +73,10 @@ require (
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
+ github.com/go-chi/chi/v5 v5.1.0 // indirect
github.com/go-kit/log v0.2.1 // indirect
github.com/go-logfmt/logfmt v0.5.1 // indirect
- github.com/go-logr/logr v1.4.1 // indirect
+ github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.19.6 // indirect
github.com/go-openapi/jsonreference v0.20.2 // indirect
@@ -82,6 +94,7 @@ require (
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.5 // indirect
+ github.com/gorilla/securecookie v1.1.2 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
@@ -116,6 +129,8 @@ require (
github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
+ github.com/muhlemmer/gu v0.3.1 // indirect
+ github.com/muhlemmer/httpforwarded v0.1.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@@ -125,6 +140,8 @@ require (
github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect
github.com/prometheus/statsd_exporter v0.22.8 // indirect
+ github.com/rs/cors v1.11.0 // indirect
+ github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cobra v1.8.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
@@ -132,19 +149,21 @@ require (
github.com/stretchr/objx v0.5.2 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
+ github.com/zitadel/logging v0.6.0 // indirect
+ github.com/zitadel/schema v1.3.0 // indirect
+ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
- go.opentelemetry.io/otel v1.24.0 // indirect
- go.opentelemetry.io/otel/metric v1.24.0 // indirect
- go.opentelemetry.io/otel/trace v1.24.0 // indirect
+ go.opentelemetry.io/otel v1.28.0 // indirect
+ go.opentelemetry.io/otel/metric v1.28.0 // indirect
+ go.opentelemetry.io/otel/trace v1.28.0 // indirect
go.uber.org/atomic v1.10.0 // indirect
golang.org/x/arch v0.8.0 // indirect
- golang.org/x/crypto v0.25.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.22.0 // indirect
- golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
- google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
+ google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b // indirect
+ google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 // indirect
google.golang.org/grpc v1.64.1 // indirect
google.golang.org/protobuf v1.34.2 // indirect
diff --git a/sherlock/go.sum b/sherlock/go.sum
index 47a2f5c2d..45392da2d 100644
--- a/sherlock/go.sum
+++ b/sherlock/go.sum
@@ -13,6 +13,8 @@ cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKV
cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs=
cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc=
cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY=
+cloud.google.com/go v0.115.0 h1:CnFSK6Xo3lDYRoBKEcAtia6VSC837/ZkJuRduSFnr14=
+cloud.google.com/go v0.115.0/go.mod h1:8jIM5vVgoAEoiVxQ/O4BFTfHqulPZgs/ufEzMcFMdWU=
cloud.google.com/go/auth v0.7.0 h1:kf/x9B3WTbBUHkC+1VS8wwwli9TzhSt0vSTVBmMR8Ts=
cloud.google.com/go/auth v0.7.0/go.mod h1:D+WqdrpcjmiCgWrXmLLxOVq1GACoE36chW6KXoEvuIw=
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
@@ -29,6 +31,12 @@ cloud.google.com/go/compute/metadata v0.4.0 h1:vHzJCWaM4g8XIcm8kopr3XmDA4Gy/lblD
cloud.google.com/go/compute/metadata v0.4.0/go.mod h1:SIQh1Kkb4ZJ8zJ874fqVkslA29PRXuleyj6vOzlbK7M=
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
+cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI=
+cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps=
+cloud.google.com/go/kms v1.18.3 h1:8+Z2S4bQDSCdghB5ZA5dVDDJTLmnkRlowtFiXqMFd74=
+cloud.google.com/go/kms v1.18.3/go.mod h1:y/Lcf6fyhbdn7MrG1VaDqXxM8rhOBc5rWcWAhcvZjQU=
+cloud.google.com/go/longrunning v0.5.9 h1:haH9pAuXdPAMqHvzX0zlWQigXT7B0+CL4/2nXXdBo5k=
+cloud.google.com/go/longrunning v0.5.9/go.mod h1:HD+0l9/OOW0za6UWdKJtXoFAX/BGg/3Wj8p10NeWF7c=
cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw=
cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA=
@@ -87,6 +95,8 @@ github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+Ce
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
+github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwNy7PA4I=
+github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
@@ -158,9 +168,13 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
+github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
+github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
+github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk=
+github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -174,8 +188,8 @@ github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG
github.com/go-logfmt/logfmt v0.5.1 h1:otpy5pqBCBZ1ng9RQ0dPu4PN7ba75Y/aA+UpowDyNVA=
github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
-github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
-github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
+github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
+github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
@@ -233,6 +247,8 @@ github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt
github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
+github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
+github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -275,6 +291,8 @@ github.com/google/go-github/v58 v58.0.0/go.mod h1:k4hxDKEfoWpSqFlc8LTpGd9fu2KrV1
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
+github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
@@ -296,6 +314,8 @@ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA=
github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E=
+github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
+github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=
@@ -487,6 +507,10 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
+github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM=
+github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM=
+github.com/muhlemmer/httpforwarded v0.1.0 h1:x4DLrzXdliq8mprgUMR0olDvHGkou5BJsK/vWUetyzY=
+github.com/muhlemmer/httpforwarded v0.1.0/go.mod h1:yo9czKedo2pdZhoXe+yDkGVbU0TJ0q9oQ90BVoDEtw0=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
@@ -556,6 +580,8 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
+github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po=
+github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
@@ -569,6 +595,8 @@ github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
+github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
+github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/slack-go/slack v0.13.1 h1:6UkM3U1OnbhPsYeb1IMkQ6HSNOSikWluwOncJt4Tz/o=
github.com/slack-go/slack v0.13.1/go.mod h1:hlGi5oXA+Gt+yWTPP0plCdRKmjsDxecdHxYQdlMQKOw=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
@@ -615,6 +643,12 @@ github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
+github.com/zitadel/logging v0.6.0 h1:t5Nnt//r+m2ZhhoTmoPX+c96pbMarqJvW1Vq6xFTank=
+github.com/zitadel/logging v0.6.0/go.mod h1:Y4CyAXHpl3Mig6JOszcV5Rqqsojj+3n7y2F591Mp/ow=
+github.com/zitadel/oidc/v3 v3.26.0 h1:BG3OUK+JpuKz7YHJIyUxL5Sl2JV6ePkG42UP4Xv3J2w=
+github.com/zitadel/oidc/v3 v3.26.0/go.mod h1:Cx6AYPTJO5q2mjqF3jaknbKOUjpq1Xui0SYvVhkKuXU=
+github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0=
+github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc=
go.etcd.io/etcd/api/v3 v3.5.4/go.mod h1:5GB2vv4A4AOn3yk7MftYGHkUfGtDHnEraIjym4dYz5A=
go.etcd.io/etcd/client/pkg/v3 v3.5.4/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g=
go.etcd.io/etcd/client/v3 v3.5.4/go.mod h1:ZaRkVgBZC+L+dLCjTcF1hRXpgZXQPOvnA/Ak/gq3kiY=
@@ -630,12 +664,14 @@ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.4
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
-go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
-go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
-go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI=
-go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco=
-go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
-go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
+go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo=
+go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4=
+go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q=
+go.opentelemetry.io/otel/metric v1.28.0/go.mod h1:Fb1eVBFZmLVTMb6PPohq3TO9IIhUisDsbJoL/+uQW4s=
+go.opentelemetry.io/otel/sdk v1.24.0 h1:YMPPDNymmQN3ZgczicBY3B6sf9n62Dlj9pWD3ucgoDw=
+go.opentelemetry.io/otel/sdk v1.24.0/go.mod h1:KVrIYw6tEubO9E96HQpcmpTKDVn9gdv35HoYiQWGDFg=
+go.opentelemetry.io/otel/trace v1.28.0 h1:GhQ9cUuQGmNDd5BTCP2dAvv75RdMxEfTmYejp+lkx9g=
+go.opentelemetry.io/otel/trace v1.28.0/go.mod h1:jPyXzNPg6da9+38HEwElrQiHlVMTnVfM3/yv2OlIHaI=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
@@ -812,6 +848,7 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -948,8 +985,10 @@ google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6D
google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0=
-google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 h1:MuYw1wJzT+ZkybKfaOXKp5hJiZDn2iHaXRw0mRYdHSc=
-google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4/go.mod h1:px9SlOOZBg1wM1zdnr8jEL4CNGUBZ+ZKYtNPApNQc4c=
+google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b h1:dSTjko30weBaMj3eERKc0ZVXW4GudCswM3m+P++ukU0=
+google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b/go.mod h1:FfBgJBJg9GcpPvKIuHSZ/aE1g2ecGL74upMzGZjiGEY=
+google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 h1:0+ozOGcrp+Y8Aq8TLNN2Aliibms5LEzsq99ZZmAGYm0=
+google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094/go.mod h1:fJ/e3If/Q67Mj99hin0hMhiNyCRmt6BQ2aWIJshUSJw=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 h1:SbSDUWW1PAO24TNpLdeheoYPd7kllICcLU52x6eD4kQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
diff --git a/sherlock/html/close.html b/sherlock/html/close.html
index 78f0cacad..4eb1107fc 100644
--- a/sherlock/html/close.html
+++ b/sherlock/html/close.html
@@ -3,7 +3,7 @@
Sherlock
-
+
diff --git a/sherlock/html/logged-out.html b/sherlock/html/logged-out.html
new file mode 100644
index 000000000..ca6ef777b
--- /dev/null
+++ b/sherlock/html/logged-out.html
@@ -0,0 +1,9 @@
+
+
+
+ Sherlock - Logged Out
+
+
+ You've been logged out. You may close this window.
+
+
diff --git a/sherlock/internal/api/login/handler_test.go b/sherlock/internal/api/login/handler_test.go
new file mode 100644
index 000000000..516e636b5
--- /dev/null
+++ b/sherlock/internal/api/login/handler_test.go
@@ -0,0 +1,37 @@
+package login
+
+import (
+ "github.com/broadinstitute/sherlock/sherlock/internal/middleware/authentication"
+ "github.com/broadinstitute/sherlock/sherlock/internal/middleware/authentication/test_users"
+ "github.com/broadinstitute/sherlock/sherlock/internal/models"
+ "github.com/broadinstitute/sherlock/sherlock/internal/oidc_models"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/suite"
+ "testing"
+)
+
+type handlerSuite struct {
+ suite.Suite
+ test_users.TestUserHelper
+ models.TestSuiteHelper
+ oidc_models.TestClientHelper
+
+ internalRouter *gin.Engine
+}
+
+func TestLoginSuite(t *testing.T) {
+ suite.Run(t, new(handlerSuite))
+}
+
+func (s *handlerSuite) SetupSuite() {
+ s.TestSuiteHelper.SetupSuite()
+ // Reduces console output
+ gin.SetMode(gin.TestMode)
+}
+
+func (s *handlerSuite) SetupTest() {
+ s.TestSuiteHelper.SetupTest()
+ s.internalRouter = gin.New()
+ s.internalRouter.Use(authentication.TestMiddleware(s.DB, s.TestData)...)
+ s.internalRouter.GET("/login", LoginGet)
+}
diff --git a/sherlock/internal/api/login/login.go b/sherlock/internal/api/login/login.go
new file mode 100644
index 000000000..c467c14c6
--- /dev/null
+++ b/sherlock/internal/api/login/login.go
@@ -0,0 +1,55 @@
+package login
+
+import (
+ "database/sql"
+ "fmt"
+ "github.com/broadinstitute/sherlock/sherlock/internal/errors"
+ "github.com/broadinstitute/sherlock/sherlock/internal/middleware/authentication"
+ "github.com/broadinstitute/sherlock/sherlock/internal/oidc_models"
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+ "gorm.io/gorm/clause"
+ "net/http"
+ "time"
+)
+
+// LoginGet is meant to handle redirects to /login?id=... from the OIDC subsystem,
+// read the IAP info from the request, "log the user in", and redirect back to the
+// OIDC subsystem.
+//
+// This isn't in Swagger as it's not an API really -- it's meant to play nice with
+// browsers.
+func LoginGet(ctx *gin.Context) {
+ user, err := authentication.MustUseUser(ctx)
+ if err != nil {
+ return
+ }
+
+ db, err := authentication.MustUseDB(ctx)
+ if err != nil {
+ return
+ }
+
+ authRequestID := ctx.Query("id")
+ if authRequestID == "" {
+ errors.AbortRequest(ctx, fmt.Errorf("(%s) no auth request ID passed", errors.BadRequest))
+ return
+ }
+
+ parsedAuthRequestID, err := uuid.Parse(authRequestID)
+ if err != nil {
+ errors.AbortRequest(ctx, fmt.Errorf("(%s) invalid auth request ID passed", errors.BadRequest))
+ return
+ }
+
+ err = db.Omit(clause.Associations).Where(&oidc_models.AuthRequest{ID: parsedAuthRequestID}).Where("done_at is null").Updates(&oidc_models.AuthRequest{
+ UserID: &user.ID,
+ DoneAt: sql.NullTime{Time: time.Now(), Valid: true},
+ }).Error
+ if err != nil {
+ errors.AbortRequest(ctx, fmt.Errorf("could not update auth request: %w", err))
+ return
+ }
+
+ ctx.Redirect(http.StatusFound, fmt.Sprintf("/oidc/authorize/callback?id=%s", authRequestID))
+}
diff --git a/sherlock/internal/api/login/login_test.go b/sherlock/internal/api/login/login_test.go
new file mode 100644
index 000000000..56fa18fa1
--- /dev/null
+++ b/sherlock/internal/api/login/login_test.go
@@ -0,0 +1,58 @@
+package login
+
+import (
+ "github.com/broadinstitute/sherlock/sherlock/internal/oidc_models"
+ "github.com/google/uuid"
+ "github.com/zitadel/oidc/v3/pkg/oidc"
+ "gorm.io/gorm/clause"
+ "net/http"
+ "net/http/httptest"
+)
+
+func (s *handlerSuite) TestLoginGet_noAuthRequestID() {
+ request, err := http.NewRequest("GET", "/login", nil)
+ s.NoError(err)
+ recorder := httptest.NewRecorder()
+ s.internalRouter.ServeHTTP(recorder, request)
+
+ s.Equal(http.StatusBadRequest, recorder.Code)
+}
+
+func (s *handlerSuite) TestLoginGet_invalidAuthRequestID() {
+ request, err := http.NewRequest("GET", "/login?id=invalid", nil)
+ s.NoError(err)
+ recorder := httptest.NewRecorder()
+ s.internalRouter.ServeHTTP(recorder, request)
+
+ s.Equal(http.StatusBadRequest, recorder.Code)
+}
+
+func (s *handlerSuite) TestLoginGet() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+
+ authRequest := oidc_models.AuthRequest{
+ ID: uuid.New(),
+ ClientID: clientID,
+ Nonce: "some-nonce",
+ RedirectURI: s.GeneratedClientRedirectURI(),
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, "groups"},
+ State: "some-state",
+ }
+
+ s.NoError(s.DB.Omit(clause.Associations).Create(&authRequest).Error)
+
+ request, err := http.NewRequest("GET", "/login?id="+authRequest.GetID(), nil)
+ s.NoError(err)
+ s.UseSuitableUserFor(request)
+ recorder := httptest.NewRecorder()
+ s.internalRouter.ServeHTTP(recorder, request)
+
+ s.Equal(http.StatusFound, recorder.Code)
+
+ // Check that the auth request was marked as done
+ var reloadedAuthRequest oidc_models.AuthRequest
+ s.NoError(s.DB.Where("id = ?", authRequest.ID.String()).First(&reloadedAuthRequest).Error)
+ s.True(reloadedAuthRequest.DoneAt.Valid)
+ s.Equal(s.TestData.User_Suitable().ID, *reloadedAuthRequest.UserID)
+}
diff --git a/sherlock/internal/api/sherlock/chart_releases_v3_test.go b/sherlock/internal/api/sherlock/chart_releases_v3_test.go
index ea9de4a7a..29bca0cdf 100644
--- a/sherlock/internal/api/sherlock/chart_releases_v3_test.go
+++ b/sherlock/internal/api/sherlock/chart_releases_v3_test.go
@@ -3,6 +3,7 @@ package sherlock
import (
"fmt"
"github.com/broadinstitute/sherlock/go-shared/pkg/utils"
+ "github.com/broadinstitute/sherlock/sherlock/internal/config"
"github.com/broadinstitute/sherlock/sherlock/internal/models"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
@@ -345,6 +346,7 @@ func (s *handlerSuite) TestChartReleaseV3_toModel() {
}
func Test_chartReleaseFromModel(t *testing.T) {
+ config.LoadTestConfig()
now := time.Now()
type args struct {
model models.ChartRelease
@@ -411,8 +413,8 @@ func Test_chartReleaseFromModel(t *testing.T) {
},
CiIdentifier: &CiIdentifierV3{CommonFields: CommonFields{ID: 2}},
ChartInfo: &ChartV3{CommonFields: CommonFields{ID: 3}, ChartV3Create: ChartV3Create{Name: "leonardo"}},
- ClusterInfo: &ClusterV3{CommonFields: CommonFields{ID: 4}, ClusterV3Create: ClusterV3Create{Name: "terra-prod"}},
- EnvironmentInfo: &EnvironmentV3{CommonFields: CommonFields{ID: 5}, EnvironmentV3Create: EnvironmentV3Create{Name: "prod"}},
+ ClusterInfo: &ClusterV3{CommonFields: CommonFields{ID: 4}, ClusterV3Create: ClusterV3Create{Name: "terra-prod", ClusterV3Edit: ClusterV3Edit{RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"))}}},
+ EnvironmentInfo: &EnvironmentV3{CommonFields: CommonFields{ID: 5}, EnvironmentV3Create: EnvironmentV3Create{Name: "prod", EnvironmentV3Edit: EnvironmentV3Edit{RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"))}}},
AppVersionReference: "leonardo/v1.0.0",
AppVersionInfo: &AppVersionV3{CommonFields: CommonFields{ID: 8}, AppVersionV3Create: AppVersionV3Create{AppVersion: "v1.0.0"}},
ChartVersionReference: "leonardo/2.0.0",
diff --git a/sherlock/internal/api/sherlock/clusters_v3.go b/sherlock/internal/api/sherlock/clusters_v3.go
index 87e4647fd..e0259e06a 100644
--- a/sherlock/internal/api/sherlock/clusters_v3.go
+++ b/sherlock/internal/api/sherlock/clusters_v3.go
@@ -3,6 +3,7 @@ package sherlock
import (
"fmt"
"github.com/broadinstitute/sherlock/go-shared/pkg/utils"
+ "github.com/broadinstitute/sherlock/sherlock/internal/config"
"github.com/broadinstitute/sherlock/sherlock/internal/models"
"gorm.io/gorm"
)
@@ -90,6 +91,8 @@ func clusterFromModel(model models.Cluster) ClusterV3 {
ret.RequiredRole = model.RequiredRole.Name
} else if model.RequiredRoleID != nil {
ret.RequiredRole = utils.PointerTo(utils.UintToString(*model.RequiredRoleID))
+ } else if substituteEmptyRole := config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"); substituteEmptyRole != "" {
+ ret.RequiredRole = &substituteEmptyRole
}
return ret
}
diff --git a/sherlock/internal/api/sherlock/clusters_v3_edit_test.go b/sherlock/internal/api/sherlock/clusters_v3_edit_test.go
index 46755fd28..4165252b1 100644
--- a/sherlock/internal/api/sherlock/clusters_v3_edit_test.go
+++ b/sherlock/internal/api/sherlock/clusters_v3_edit_test.go
@@ -3,6 +3,7 @@ package sherlock
import (
"fmt"
"github.com/broadinstitute/sherlock/go-shared/pkg/utils"
+ "github.com/broadinstitute/sherlock/sherlock/internal/config"
"github.com/broadinstitute/sherlock/sherlock/internal/errors"
"github.com/broadinstitute/sherlock/sherlock/internal/models"
"github.com/gin-gonic/gin"
@@ -182,7 +183,9 @@ func (s *handlerSuite) TestClusterV3Edit_clearRequiredRole() {
}),
&got)
s.Equal(http.StatusOK, code)
- s.Nil(got.RequiredRole)
+ if s.NotNil(got.RequiredRole) {
+ s.Equal(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"), *got.RequiredRole)
+ }
var inDB models.Cluster
s.NoError(s.DB.First(&inDB, toEdit.ID).Error)
s.Nil(inDB.RequiredRoleID)
diff --git a/sherlock/internal/api/sherlock/environments_v3.go b/sherlock/internal/api/sherlock/environments_v3.go
index 13b3b396c..03d49a14c 100644
--- a/sherlock/internal/api/sherlock/environments_v3.go
+++ b/sherlock/internal/api/sherlock/environments_v3.go
@@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"github.com/broadinstitute/sherlock/go-shared/pkg/utils"
+ "github.com/broadinstitute/sherlock/sherlock/internal/config"
"github.com/broadinstitute/sherlock/sherlock/internal/errors"
"github.com/broadinstitute/sherlock/sherlock/internal/models"
"github.com/google/uuid"
@@ -232,6 +233,8 @@ func environmentFromModel(model models.Environment) EnvironmentV3 {
ret.RequiredRole = model.RequiredRole.Name
} else if model.RequiredRoleID != nil {
ret.RequiredRole = utils.PointerTo(utils.UintToString(*model.RequiredRoleID))
+ } else if substituteEmptyRole := config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"); substituteEmptyRole != "" {
+ ret.RequiredRole = &substituteEmptyRole
}
return ret
}
diff --git a/sherlock/internal/api/sherlock/environments_v3_edit_test.go b/sherlock/internal/api/sherlock/environments_v3_edit_test.go
index 8af9da544..6ff176d16 100644
--- a/sherlock/internal/api/sherlock/environments_v3_edit_test.go
+++ b/sherlock/internal/api/sherlock/environments_v3_edit_test.go
@@ -3,6 +3,7 @@ package sherlock
import (
"fmt"
"github.com/broadinstitute/sherlock/go-shared/pkg/utils"
+ "github.com/broadinstitute/sherlock/sherlock/internal/config"
"github.com/broadinstitute/sherlock/sherlock/internal/errors"
"github.com/broadinstitute/sherlock/sherlock/internal/models"
"github.com/gin-gonic/gin"
@@ -129,7 +130,9 @@ func (s *handlerSuite) TestEnvironmentsV3Edit_clearRequiredRole() {
}),
&got)
s.Equal(http.StatusOK, code)
- s.Nil(got.RequiredRole)
+ if s.NotNil(got.RequiredRole) {
+ s.Equal(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"), *got.RequiredRole)
+ }
var inDB models.Environment
s.NoError(s.DB.First(&inDB, toEdit.ID).Error)
s.Nil(inDB.RequiredRoleID)
diff --git a/sherlock/internal/api/sherlock/environments_v3_test.go b/sherlock/internal/api/sherlock/environments_v3_test.go
index 6aa037036..ff4de6377 100644
--- a/sherlock/internal/api/sherlock/environments_v3_test.go
+++ b/sherlock/internal/api/sherlock/environments_v3_test.go
@@ -3,6 +3,7 @@ package sherlock
import (
"database/sql"
"github.com/broadinstitute/sherlock/go-shared/pkg/utils"
+ "github.com/broadinstitute/sherlock/sherlock/internal/config"
"github.com/broadinstitute/sherlock/sherlock/internal/models"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -310,6 +311,7 @@ func (s *handlerSuite) TestEnvironmentV3_toModel() {
}
func Test_environmentFromModel(t *testing.T) {
+ config.LoadTestConfig()
now := time.Now()
nowString := utils.TimePtrToISO8601(&now)
nowTimeParsedAgain, err := utils.ISO8601PtrToTime(nowString)
@@ -326,7 +328,13 @@ func Test_environmentFromModel(t *testing.T) {
{
name: "empty",
args: args{},
- want: EnvironmentV3{},
+ want: EnvironmentV3{
+ EnvironmentV3Create: EnvironmentV3Create{
+ EnvironmentV3Edit: EnvironmentV3Edit{
+ RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue")),
+ },
+ },
+ },
},
{
name: "full",
@@ -380,8 +388,8 @@ func Test_environmentFromModel(t *testing.T) {
UpdatedAt: now.Add(-time.Minute),
},
CiIdentifier: &CiIdentifierV3{CommonFields: CommonFields{ID: 2}},
- TemplateEnvironmentInfo: &EnvironmentV3{CommonFields: CommonFields{ID: 3}, EnvironmentV3Create: EnvironmentV3Create{Name: "name-3"}},
- DefaultClusterInfo: &ClusterV3{CommonFields: CommonFields{ID: 4}, ClusterV3Create: ClusterV3Create{Name: "name-4"}},
+ TemplateEnvironmentInfo: &EnvironmentV3{CommonFields: CommonFields{ID: 3}, EnvironmentV3Create: EnvironmentV3Create{Name: "name-3", EnvironmentV3Edit: EnvironmentV3Edit{RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"))}}},
+ DefaultClusterInfo: &ClusterV3{CommonFields: CommonFields{ID: 4}, ClusterV3Create: ClusterV3Create{Name: "name-4", ClusterV3Edit: ClusterV3Edit{RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"))}}},
PagerdutyIntegrationInfo: &PagerdutyIntegrationV3{CommonFields: CommonFields{ID: 6}, PagerdutyID: "blah"},
OwnerInfo: &UserV3{CommonFields: CommonFields{ID: 5}, Email: "example@example.com", Suitable: utils.PointerTo(false),
SuitabilityDescription: utils.PointerTo("no matching suitability record found or loaded; assuming unsuitable")},
@@ -428,6 +436,7 @@ func Test_environmentFromModel(t *testing.T) {
EnvironmentV3Create: EnvironmentV3Create{
EnvironmentV3Edit: EnvironmentV3Edit{
PagerdutyIntegration: utils.PointerTo("6"),
+ RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue")),
},
},
},
diff --git a/sherlock/internal/boot/application.go b/sherlock/internal/boot/application.go
index 97c3d09d8..ead5651f7 100644
--- a/sherlock/internal/boot/application.go
+++ b/sherlock/internal/boot/application.go
@@ -12,6 +12,7 @@ import (
"github.com/broadinstitute/sherlock/sherlock/internal/metrics"
"github.com/broadinstitute/sherlock/sherlock/internal/middleware/authentication/gha_oidc"
"github.com/broadinstitute/sherlock/sherlock/internal/models"
+ "github.com/broadinstitute/sherlock/sherlock/internal/oidc_models"
"github.com/broadinstitute/sherlock/sherlock/internal/role_propagation"
"github.com/broadinstitute/sherlock/sherlock/internal/suitability_synchronization"
"github.com/gin-gonic/gin"
@@ -136,6 +137,15 @@ func (a *Application) Start() {
}
}
+ if config.Config.Bool("oidc.enable") {
+ log.Info().Msgf("BOOT | initializing OIDC provider...")
+ if err = oidc_models.Init(ctx, a.gormDB); err != nil {
+ log.Fatal().Err(err).Msgf("oidc_models.Init() error")
+ }
+ go oidc_models.KeepSigningKeysRotated(ctx, a.gormDB)
+ go oidc_models.KeepExpiringRefreshTokens(ctx, a.gormDB)
+ }
+
go models.KeepAutoAssigningRoles(ctx, a.gormDB)
log.Info().Msgf("BOOT | building Gin router...")
diff --git a/sherlock/internal/boot/router.go b/sherlock/internal/boot/router.go
index 26ca31f99..5bd99b83d 100644
--- a/sherlock/internal/boot/router.go
+++ b/sherlock/internal/boot/router.go
@@ -6,6 +6,7 @@ import (
"github.com/broadinstitute/sherlock/go-shared/pkg/version"
"github.com/broadinstitute/sherlock/sherlock/docs"
"github.com/broadinstitute/sherlock/sherlock/html"
+ "github.com/broadinstitute/sherlock/sherlock/internal/api/login"
"github.com/broadinstitute/sherlock/sherlock/internal/api/misc"
"github.com/broadinstitute/sherlock/sherlock/internal/api/sherlock"
"github.com/broadinstitute/sherlock/sherlock/internal/config"
@@ -16,11 +17,13 @@ import (
"github.com/broadinstitute/sherlock/sherlock/internal/middleware/csrf_protection"
"github.com/broadinstitute/sherlock/sherlock/internal/middleware/headers"
"github.com/broadinstitute/sherlock/sherlock/internal/middleware/logger"
+ "github.com/broadinstitute/sherlock/sherlock/internal/oidc_models"
"github.com/gin-gonic/gin"
swaggo_files "github.com/swaggo/files"
swaggo_gin "github.com/swaggo/gin-swagger"
"gorm.io/gorm"
"net/http"
+ "strings"
)
// @title Sherlock
@@ -57,6 +60,10 @@ func BuildRouter(ctx context.Context, db *gorm.DB) *gin.Engine {
cors.Cors(),
headers.Headers())
+ resourceMiddleware := make(gin.HandlersChain, 0)
+ resourceMiddleware = append(resourceMiddleware, csrf_protection.CsrfProtection())
+ resourceMiddleware = append(resourceMiddleware, authentication.Middleware(db)...)
+
// Replace Gin's standard fallback responses with our standard error format for friendlier client behavior
router.NoRoute(func(ctx *gin.Context) {
errors.AbortRequest(ctx, fmt.Errorf("(%s) no handler for %s found", errors.NotFound, ctx.Request.URL.Path))
@@ -80,16 +87,27 @@ func BuildRouter(ctx context.Context, db *gorm.DB) *gin.Engine {
}))
router.GET("", func(ctx *gin.Context) { ctx.Redirect(http.StatusMovedPermanently, "/swagger/index.html") })
+ if config.Config.Bool("oidc.enable") {
+ // delegate /oidc/* to OIDC library, trimming the path prefix because of how the library expects to receive requests
+ // https://broadinstitute.slack.com/archives/CQ6SL4N5T/p1721406732128199
+ router.Any("/oidc/*any", func(ctx *gin.Context) {
+ req := ctx.Request.Clone(ctx)
+ req.RequestURI = strings.TrimPrefix(req.RequestURI, "/oidc")
+ req.URL.Path = strings.TrimPrefix(req.URL.Path, "/oidc")
+ oidc_models.Provider.ServeHTTP(ctx.Writer, req)
+ })
+ // authenticate /login handler to complete OIDC auth requests
+ router.GET("/login", append(resourceMiddleware, login.LoginGet)...)
+ }
+
// routes under /api require authentication and may use the database
- apiRouter := router.Group("api")
- apiRouter.Use(csrf_protection.CsrfProtection())
- apiRouter.Use(authentication.Middleware(db)...)
+ apiRouter := router.Group("/api", resourceMiddleware...)
// refactored sherlock API, under /api/{type}/v3
sherlock.ConfigureRoutes(apiRouter)
// special error for the removed "v2" API, under /api/v2/{type}
- apiRouter.Any("v2/*path", func(ctx *gin.Context) {
+ apiRouter.Any("/v2/*path", func(ctx *gin.Context) {
errors.AbortRequest(ctx, fmt.Errorf("(%s) sherlock's v2 API has been removed; reach out to #dsp-devops-champions for help updating your client", errors.NotFound))
})
diff --git a/sherlock/internal/config/config.go b/sherlock/internal/config/config.go
index faee1afe7..ef8f80e3b 100644
--- a/sherlock/internal/config/config.go
+++ b/sherlock/internal/config/config.go
@@ -112,6 +112,7 @@ func configureLogging(infoMessages ...string) {
// log messages using Go's built-in logging. We can at least format those messages
// correctly by redirecting that into zerolog, though it won't have proper leveling
// information
+ stdlog.SetFlags(0)
stdlog.SetOutput(log.Logger)
if logLevel := Config.String("log.level"); logLevel != "" {
diff --git a/sherlock/internal/models/user.go b/sherlock/internal/models/user.go
index 5e7313134..1e1c240c1 100644
--- a/sherlock/internal/models/user.go
+++ b/sherlock/internal/models/user.go
@@ -153,11 +153,11 @@ func (u *User) AlphaNumericHyphenatedUsername() string {
return string(ret)
}
-func (u *User) NameOrEmailHandle() string {
+func (u *User) NameOrUsername() string {
if u.Name != nil {
return *u.Name
} else {
- return strings.Split(u.Email, "@")[0]
+ return u.AlphaNumericHyphenatedUsername()
}
}
@@ -165,7 +165,7 @@ func (u *User) SlackReference(mention bool) string {
if u.SlackID != nil && mention {
return fmt.Sprintf("<@%s>", *u.SlackID)
} else {
- return fmt.Sprintf("", u.Email, u.NameOrEmailHandle())
+ return fmt.Sprintf("", u.Email, u.NameOrUsername())
}
}
diff --git a/sherlock/internal/oidc_models/auth_request.go b/sherlock/internal/oidc_models/auth_request.go
new file mode 100644
index 000000000..75fa4f151
--- /dev/null
+++ b/sherlock/internal/oidc_models/auth_request.go
@@ -0,0 +1,105 @@
+package oidc_models
+
+import (
+ "database/sql"
+ "github.com/broadinstitute/sherlock/go-shared/pkg/utils"
+ "github.com/broadinstitute/sherlock/sherlock/internal/models"
+ "github.com/google/uuid"
+ "github.com/zitadel/oidc/v3/pkg/oidc"
+ "github.com/zitadel/oidc/v3/pkg/op"
+ "time"
+)
+
+// "AuthRequest implements the op.AuthRequest interface"
+var _ op.AuthRequest = &AuthRequest{}
+
+type AuthRequest struct {
+ ID uuid.UUID `gorm:"primaryKey"`
+ CreatedAt time.Time
+ DoneAt sql.NullTime
+
+ Client *Client `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
+ ClientID string // AKA Audience, Application ID
+ Nonce string
+ RedirectURI string // AKA CallbackURI
+ ResponseType oidc.ResponseType `gorm:"string"`
+ ResponseMode oidc.ResponseMode `gorm:"string"`
+ Scopes oidc.SpaceDelimitedArray
+ State string
+
+ CodeChallenge string
+ CodeChallengeMethod oidc.CodeChallengeMethod `gorm:"string"`
+
+ // The user won't be filled until the request has been logged-in to
+ User *models.User
+ UserID *uint
+}
+
+func (r *AuthRequest) GetID() string {
+ return r.ID.String()
+}
+
+func (r *AuthRequest) GetACR() string {
+ return ""
+}
+
+func (r *AuthRequest) GetAMR() []string {
+ // Return an empty array because we don't know for sure what AMRs IAP enforced on the caller.
+ // https://openid.net/specs/openid-connect-core-1_0.html#IDToken
+ // https://www.rfc-editor.org/info/rfc8176
+ return []string{}
+}
+
+func (r *AuthRequest) GetAudience() []string {
+ return []string{r.ClientID}
+}
+
+func (r *AuthRequest) GetAuthTime() time.Time {
+ return r.CreatedAt
+}
+
+func (r *AuthRequest) GetClientID() string {
+ return r.ClientID
+}
+
+func (r *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge {
+ return &oidc.CodeChallenge{
+ Challenge: r.CodeChallenge,
+ Method: r.CodeChallengeMethod,
+ }
+}
+
+func (r *AuthRequest) GetNonce() string {
+ return r.Nonce
+}
+
+func (r *AuthRequest) GetRedirectURI() string {
+ return r.RedirectURI
+}
+
+func (r *AuthRequest) GetResponseType() oidc.ResponseType {
+ return r.ResponseType
+}
+
+func (r *AuthRequest) GetResponseMode() oidc.ResponseMode {
+ return r.ResponseMode
+}
+
+func (r *AuthRequest) GetScopes() []string {
+ return r.Scopes
+}
+
+func (r *AuthRequest) GetState() string {
+ return r.State
+}
+
+func (r *AuthRequest) GetSubject() string {
+ if r.UserID == nil {
+ return ""
+ }
+ return utils.UintToString(*r.UserID)
+}
+
+func (r *AuthRequest) Done() bool {
+ return r.DoneAt.Valid
+}
diff --git a/sherlock/internal/oidc_models/auth_request_code.go b/sherlock/internal/oidc_models/auth_request_code.go
new file mode 100644
index 000000000..8418c4059
--- /dev/null
+++ b/sherlock/internal/oidc_models/auth_request_code.go
@@ -0,0 +1,13 @@
+package oidc_models
+
+import (
+ "github.com/google/uuid"
+ "time"
+)
+
+type AuthRequestCode struct {
+ Code string `gorm:"primaryKey"`
+ CreatedAt time.Time
+ AuthRequestID uuid.UUID
+ AuthRequest *AuthRequest `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
+}
diff --git a/sherlock/internal/oidc_models/boot.go b/sherlock/internal/oidc_models/boot.go
new file mode 100644
index 000000000..a37351526
--- /dev/null
+++ b/sherlock/internal/oidc_models/boot.go
@@ -0,0 +1,72 @@
+package oidc_models
+
+import (
+ kms "cloud.google.com/go/kms/apiv1"
+ "cloud.google.com/go/kms/apiv1/kmspb"
+ "context"
+ "fmt"
+ "github.com/broadinstitute/sherlock/sherlock/internal/clients/slack"
+ "github.com/broadinstitute/sherlock/sherlock/internal/config"
+ "gorm.io/gorm"
+ "time"
+)
+
+var (
+ kmsKey string
+ kmsClient *kms.KeyManagementClient
+)
+
+func Init(ctx context.Context, db *gorm.DB) error {
+ if config.Config.Bool("oidc.signingKeyEncryptionKMSEnable") {
+ kmsKey = config.Config.String("oidc.signingKeyEncryptionKMSKeyName")
+ var err error
+ kmsClient, err = kms.NewKeyManagementClient(ctx)
+ if err != nil {
+ return fmt.Errorf("error creating KMS client: %w", err)
+ }
+ response, err := kmsClient.GetCryptoKey(ctx, &kmspb.GetCryptoKeyRequest{
+ Name: kmsKey,
+ })
+ if err != nil {
+ return fmt.Errorf("error getting KMS key '%s': %w", kmsKey, err)
+ } else if response.Purpose != kmspb.CryptoKey_ENCRYPT_DECRYPT {
+ return fmt.Errorf("KMS key '%s' is not an encrypt/decrypt key", kmsKey)
+ }
+ } else if config.Config.String("mode") != "debug" {
+ return fmt.Errorf("oidc.signingKeyEncryptionKMSEnable is false, but mode is not debug")
+ }
+
+ if err := rotateSigningKeys(ctx, db); err != nil {
+ return fmt.Errorf("error rotating oidc signing keys: %w", err)
+ }
+
+ return initProvider(db)
+}
+
+func KeepSigningKeysRotated(ctx context.Context, db *gorm.DB) {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-time.After(15 * time.Minute):
+ err := rotateSigningKeys(ctx, db)
+ if err != nil {
+ slack.ReportError(ctx, "error rotating oidc signing keys", err)
+ }
+ }
+ }
+}
+
+func KeepExpiringRefreshTokens(ctx context.Context, db *gorm.DB) {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-time.After(time.Minute):
+ err := expireRefreshTokens(db)
+ if err != nil {
+ slack.ReportError(ctx, "error expiring refresh tokens", err)
+ }
+ }
+ }
+}
diff --git a/sherlock/internal/oidc_models/client.go b/sherlock/internal/oidc_models/client.go
new file mode 100644
index 000000000..25f165808
--- /dev/null
+++ b/sherlock/internal/oidc_models/client.go
@@ -0,0 +1,124 @@
+package oidc_models
+
+import (
+ "fmt"
+ "github.com/zitadel/oidc/v3/pkg/oidc"
+ "github.com/zitadel/oidc/v3/pkg/op"
+ "time"
+)
+
+// "Client implements the op.Client interface"
+var _ op.Client = &Client{}
+
+type Client struct {
+ ID string `gorm:"primaryKey"`
+ ClientSecretHash []byte // PBKDF2 derived key, HMAC-SHA-512; should be empty for PKCE; Sherlock will derive the same number of bytes as the length of this field automatically
+ ClientSecretSalt []byte // Salt for ClientSecretHash; should be empty for PKCE
+ ClientSecretIterations int // Number of iterations for ClientSecretHash; should be empty for PKCE. https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#pbkdf2
+ ClientRedirectURIs oidc.SpaceDelimitedArray // In dev mode, may include globs or use http
+ ClientPostLogoutRedirectURIs oidc.SpaceDelimitedArray // In dev mode, may include globs or use http
+ ClientApplicationType op.ApplicationType // "web", "user_agent", or "native"
+ ClientAuthMethod oidc.AuthMethod // "client_secret_basic", "client_secret_post", "none", or "private_key_jwt"
+ ClientIDTokenLifetime int64 // Nanoseconds
+ ClientDevMode bool // Allow insecure/nonstandard redirect URIs, like globs and http
+ ClientClockSkew int64 // Nanoseconds
+}
+
+func (c *Client) GetID() string {
+ return c.ID
+}
+
+func (c *Client) RedirectURIs() []string {
+ return c.ClientRedirectURIs
+}
+
+func (c *Client) PostLogoutRedirectURIs() []string {
+ return c.ClientPostLogoutRedirectURIs
+}
+
+func (c *Client) ApplicationType() op.ApplicationType {
+ return c.ClientApplicationType
+}
+
+func (c *Client) AuthMethod() oidc.AuthMethod {
+ return c.ClientAuthMethod
+}
+
+func (c *Client) ResponseTypes() []oidc.ResponseType {
+ return []oidc.ResponseType{oidc.ResponseTypeCode, oidc.ResponseTypeIDTokenOnly, oidc.ResponseTypeIDToken}
+}
+
+func (c *Client) GrantTypes() []oidc.GrantType {
+ return []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken}
+}
+
+func (c *Client) LoginURL(authRequestID string) string {
+ // To sherlock/sherlock/internal/api/login/login.go
+ return fmt.Sprintf("/login?id=%s", authRequestID)
+}
+
+func (c *Client) AccessTokenType() op.AccessTokenType {
+ return op.AccessTokenTypeJWT
+}
+
+func (c *Client) IDTokenLifetime() time.Duration {
+ return time.Duration(c.ClientIDTokenLifetime)
+}
+
+func (c *Client) DevMode() bool {
+ return c.ClientDevMode
+}
+
+func (c *Client) IDTokenUserinfoClaimsAssertion() bool {
+ // It technically violates spec to return userinfo claims in ID tokens when an access token is provided.
+ // The caller is expected to call the userinfo endpoint with the access token to get the userinfo data.
+ // But... Dex supports and actually expects these claims in the initial ID token. It can be configured
+ // to call the userinfo endpoint, but why make the extra roundtrip?
+ //
+ // In the future we could always have this be a bool field on Client itself to customize the behavior
+ // per client.
+ return true
+}
+
+func (c *Client) ClockSkew() time.Duration {
+ return time.Duration(c.ClientClockSkew)
+}
+
+func (c *Client) RestrictAdditionalIdTokenScopes() func(scopes []string) []string {
+ return func(scopes []string) []string { return scopes }
+}
+
+func (c *Client) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
+ // Hardcode no client-level extra restrictions on access token scopes
+ return func(scopes []string) []string { return scopes }
+}
+
+func (c *Client) IsScopeAllowed(scope string) bool {
+ // Hardcode that groups are allowed, because currently that's the whole point of Sherlock doing OIDC
+ return scope == groupsClaim
+}
+
+// "devModeGlobClient implements the op.HasRedirectGlobs interface"
+var _ op.HasRedirectGlobs = &devModeGlobClient{}
+
+type devModeGlobClient struct {
+ Client
+}
+
+func (c *devModeGlobClient) RedirectURIGlobs() []string {
+ return c.RedirectURIs()
+}
+
+func (c *devModeGlobClient) PostLogoutRedirectURIGlobs() []string {
+ return c.PostLogoutRedirectURIs()
+}
+
+// wrapPossibleDevModeClient takes a Client from the database returns an op.Client, possibly wrapped with
+// devModeGlobClient based on Client.ClientDevMode.
+func wrapPossibleDevModeClient(client Client) op.Client {
+ if client.DevMode() {
+ return &devModeGlobClient{client}
+ } else {
+ return &client
+ }
+}
diff --git a/sherlock/internal/oidc_models/oidc_models_test.go b/sherlock/internal/oidc_models/oidc_models_test.go
new file mode 100644
index 000000000..fffda2aef
--- /dev/null
+++ b/sherlock/internal/oidc_models/oidc_models_test.go
@@ -0,0 +1,23 @@
+package oidc_models
+
+import (
+ "github.com/broadinstitute/sherlock/sherlock/internal/models"
+ "github.com/stretchr/testify/suite"
+ "testing"
+)
+
+type oidcModelsSuite struct {
+ suite.Suite
+ models.TestSuiteHelper
+ TestClientHelper
+ storage *storageImpl
+}
+
+func TestOidcModelsSuite(t *testing.T) {
+ suite.Run(t, new(oidcModelsSuite))
+}
+
+func (s *oidcModelsSuite) SetupTest() {
+ s.TestSuiteHelper.SetupTest()
+ s.storage = &storageImpl{db: s.DB}
+}
diff --git a/sherlock/internal/oidc_models/provider.go b/sherlock/internal/oidc_models/provider.go
new file mode 100644
index 000000000..7bf7cd3e7
--- /dev/null
+++ b/sherlock/internal/oidc_models/provider.go
@@ -0,0 +1,43 @@
+package oidc_models
+
+import (
+ "encoding/hex"
+ "fmt"
+ "github.com/broadinstitute/sherlock/sherlock/internal/config"
+ "github.com/zitadel/oidc/v3/pkg/op"
+ "golang.org/x/text/language"
+ "gorm.io/gorm"
+)
+
+var Provider op.OpenIDProvider
+
+func initProvider(db *gorm.DB) error {
+ key, err := hex.DecodeString(config.Config.String("oidc.encryptionKeyHex"))
+ if err != nil {
+ return fmt.Errorf("could not decode oidc.encryptionKeyHex: %w", err)
+ } else if len(key) != 32 {
+ return fmt.Errorf("oidc.encryptionKeyHex must be 32 bytes long; got %d", len(key))
+ }
+
+ storage := &storageImpl{db: db}
+ conf := &op.Config{
+ CryptoKey: ([32]byte)(key),
+ DefaultLogoutRedirectURI: "/static/logged-out.html",
+ CodeMethodS256: true, // Enable PKCE and S256 code challenge method
+ GrantTypeRefreshToken: true, // Allow refresh token grant user
+ SupportedUILocales: []language.Tag{language.AmericanEnglish},
+ SupportedClaims: append(op.DefaultSupportedClaims, groupsClaim), // Technically more than we provide but better a superset than subset
+ }
+ options := []op.Option{
+ op.WithCORSOptions(nil),
+ // If we expand our usage of the OIDC subsystem such that users need to connect to other endpoints,
+ // we may need to add custom URLs for those endpoints here.
+ op.WithCustomAuthEndpoint(op.NewEndpointWithURL("/authorize", config.Config.String("oidc.publicIssuerUrl")+"/authorize")),
+ }
+ if config.Config.String("mode") == "debug" {
+ options = append(options, op.WithAllowInsecure())
+ }
+
+ Provider, err = op.NewProvider(conf, storage, op.StaticIssuer(config.Config.String("oidc.issuerUrl")), options...)
+ return err
+}
diff --git a/sherlock/internal/oidc_models/refresh_token.go b/sherlock/internal/oidc_models/refresh_token.go
new file mode 100644
index 000000000..1d584ca9a
--- /dev/null
+++ b/sherlock/internal/oidc_models/refresh_token.go
@@ -0,0 +1,72 @@
+package oidc_models
+
+import (
+ "errors"
+ "github.com/broadinstitute/sherlock/go-shared/pkg/utils"
+ "github.com/broadinstitute/sherlock/sherlock/internal/config"
+ "github.com/broadinstitute/sherlock/sherlock/internal/models"
+ "github.com/google/uuid"
+ "github.com/zitadel/oidc/v3/pkg/oidc"
+ "github.com/zitadel/oidc/v3/pkg/op"
+ "gorm.io/gorm"
+ "gorm.io/gorm/clause"
+ "time"
+)
+
+var _ op.RefreshTokenRequest = &RefreshToken{}
+
+type RefreshToken struct {
+ ID uuid.UUID `gorm:"primaryKey"`
+ CreatedAt time.Time
+
+ TokenHash []byte // SHA-512 hash
+
+ Client *Client `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
+ ClientID string // AKA Audience, Application ID
+ Scopes oidc.SpaceDelimitedArray
+ OriginalAuthAt time.Time
+
+ User *models.User
+ UserID uint
+}
+
+func expireRefreshTokens(db *gorm.DB) error {
+ err := db.
+ Omit(clause.Associations).
+ Where("created_at < ?", time.Now().Add(-config.Config.MustDuration("oidc.refreshTokenDuration"))).
+ Delete(&RefreshToken{}).
+ Error
+ if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
+ return err
+ } else {
+ return nil
+ }
+}
+
+func (r *RefreshToken) GetAMR() []string {
+ return []string{}
+}
+
+func (r *RefreshToken) GetAudience() []string {
+ return []string{r.ClientID}
+}
+
+func (r *RefreshToken) GetAuthTime() time.Time {
+ return r.OriginalAuthAt
+}
+
+func (r *RefreshToken) GetClientID() string {
+ return r.ClientID
+}
+
+func (r *RefreshToken) GetScopes() []string {
+ return r.Scopes
+}
+
+func (r *RefreshToken) GetSubject() string {
+ return utils.UintToString(r.UserID)
+}
+
+func (r *RefreshToken) SetCurrentScopes(scopes []string) {
+ r.Scopes = scopes
+}
diff --git a/sherlock/internal/oidc_models/signing_key.go b/sherlock/internal/oidc_models/signing_key.go
new file mode 100644
index 000000000..566563998
--- /dev/null
+++ b/sherlock/internal/oidc_models/signing_key.go
@@ -0,0 +1,175 @@
+package oidc_models
+
+import (
+ "cloud.google.com/go/kms/apiv1/kmspb"
+ "context"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/x509"
+ "fmt"
+ "github.com/broadinstitute/sherlock/sherlock/internal/config"
+ "github.com/go-jose/go-jose/v4"
+ "github.com/google/uuid"
+ "github.com/rs/zerolog/log"
+ "github.com/zitadel/oidc/v3/pkg/op"
+ "gorm.io/gorm"
+ "gorm.io/gorm/clause"
+ "time"
+)
+
+type SigningKey struct {
+ ID uuid.UUID `gorm:"primaryKey"`
+ CreatedAt time.Time
+ PublicKey []byte
+ PrivateKey []byte
+}
+
+func rotateSigningKeys(ctx context.Context, db *gorm.DB) error {
+ var validKeys []SigningKey
+ err := db.
+ Where("created_at > ?", time.Now().
+ Add(-config.Config.MustDuration("oidc.signingKeyPrimaryDuration"))).
+ Omit(clause.Associations).
+ Find(&validKeys).Error
+ if err != nil {
+ return fmt.Errorf("error loading valid signing keys: %w", err)
+ }
+ if len(validKeys) == 0 {
+ newKey, err := saveNewSigningKey(ctx, db)
+ if err != nil {
+ return fmt.Errorf("error generating new signing key: %w", err)
+ } else {
+ log.Info().Msgf("OIDC | generated new signing key with ID %s", newKey.ID)
+ }
+ }
+
+ var expiredKeys []SigningKey
+ err = db.
+ Where("created_at <= ?", time.Now().
+ Add(-(config.Config.MustDuration("oidc.signingKeyPrimaryDuration"))).
+ Add(-config.Config.MustDuration("oidc.signingKeyPostRotationDuration"))).
+ Omit(clause.Associations).
+ Find(&expiredKeys).Error
+ if err != nil {
+ return fmt.Errorf("error loading expired signing keys: %w", err)
+ }
+ for _, key := range expiredKeys {
+ err = db.Omit(clause.Associations).Delete(&key).Error
+ if err != nil {
+ return fmt.Errorf("error deleting expired signing key: %w", err)
+ }
+ log.Info().Msgf("OIDC | deleted expired signing key with ID %s", key.ID)
+ }
+ return nil
+}
+
+func saveNewSigningKey(ctx context.Context, db *gorm.DB) (*SigningKey, error) {
+ keyModel := &SigningKey{
+ ID: uuid.New(),
+ }
+
+ // Generate new private key
+ privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
+ if err != nil {
+ return nil, fmt.Errorf("error generating private key: %w", err)
+ }
+
+ // Store public key plaintext in database for easy access
+ keyModel.PublicKey = elliptic.MarshalCompressed(privateKey.PublicKey.Curve, privateKey.PublicKey.X, privateKey.PublicKey.Y)
+
+ // Store private key, encrypting with KMS if configured
+ privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey)
+ if err != nil {
+ return nil, fmt.Errorf("error marshaling private key: %w", err)
+ }
+ if kmsKey != "" {
+ // Store encrypted
+ response, err := kmsClient.Encrypt(ctx, &kmspb.EncryptRequest{
+ Name: kmsKey,
+ Plaintext: privateKeyBytes,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("error encrypting private key with KMS: %w", err)
+ }
+ keyModel.PrivateKey = response.Ciphertext
+ } else {
+ // Store plaintext
+ keyModel.PrivateKey = privateKeyBytes
+ }
+
+ // Save database row
+ err = db.Create(keyModel).Error
+ if err != nil {
+ return nil, fmt.Errorf("error saving generated signing key to database: %w", err)
+ }
+ return keyModel, nil
+}
+
+var _ op.Key = &publicSigningKey{}
+
+type publicSigningKey struct {
+ SigningKey
+}
+
+func (p *publicSigningKey) ID() string {
+ return p.SigningKey.ID.String()
+}
+
+func (p *publicSigningKey) Algorithm() jose.SignatureAlgorithm {
+ return jose.ES512
+}
+
+func (p *publicSigningKey) Use() string {
+ return "sig"
+}
+
+func (p *publicSigningKey) Key() any {
+ x, y := elliptic.UnmarshalCompressed(elliptic.P521(), p.PublicKey)
+ return &ecdsa.PublicKey{
+ Curve: elliptic.P521(),
+ X: x,
+ Y: y,
+ }
+}
+
+var _ op.SigningKey = &decryptedPrivateSigningKey{}
+
+type decryptedPrivateSigningKey struct {
+ SigningKey
+ unencryptedPrivateKey *ecdsa.PrivateKey
+}
+
+func (p *decryptedPrivateSigningKey) ID() string {
+ return p.SigningKey.ID.String()
+}
+
+func (p *decryptedPrivateSigningKey) SignatureAlgorithm() jose.SignatureAlgorithm {
+ return jose.ES512
+}
+
+func (p *decryptedPrivateSigningKey) Key() any {
+ return p.unencryptedPrivateKey
+}
+
+func decryptPrivateSigningKey(key *SigningKey) (*decryptedPrivateSigningKey, error) {
+ bytesToUnmarshall := key.PrivateKey
+ if kmsKey != "" {
+ response, err := kmsClient.Decrypt(context.Background(), &kmspb.DecryptRequest{
+ Name: kmsKey,
+ Ciphertext: bytesToUnmarshall,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("error decrypting private signing key with KMS: %w", err)
+ }
+ bytesToUnmarshall = response.Plaintext
+ }
+ privateKey, err := x509.ParseECPrivateKey(bytesToUnmarshall)
+ if err != nil {
+ return nil, fmt.Errorf("error unmarshalling private signing key: %w", err)
+ }
+ return &decryptedPrivateSigningKey{
+ SigningKey: *key,
+ unencryptedPrivateKey: privateKey,
+ }, nil
+}
diff --git a/sherlock/internal/oidc_models/storage.go b/sherlock/internal/oidc_models/storage.go
new file mode 100644
index 000000000..aeabb3fe4
--- /dev/null
+++ b/sherlock/internal/oidc_models/storage.go
@@ -0,0 +1,581 @@
+package oidc_models
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "crypto/sha512"
+ "errors"
+ "fmt"
+ "github.com/broadinstitute/sherlock/go-shared/pkg/utils"
+ "github.com/broadinstitute/sherlock/sherlock/internal/config"
+ "github.com/broadinstitute/sherlock/sherlock/internal/models"
+ "github.com/go-jose/go-jose/v4"
+ "github.com/google/uuid"
+ "github.com/zitadel/oidc/v3/pkg/oidc"
+ "github.com/zitadel/oidc/v3/pkg/op"
+ "golang.org/x/crypto/pbkdf2"
+ "golang.org/x/text/language"
+ "gorm.io/gorm"
+ "gorm.io/gorm/clause"
+ "strings"
+ "time"
+)
+
+const groupsClaim = "groups"
+
+var _ op.Storage = &storageImpl{}
+var _ op.CanSetUserinfoFromRequest = &storageImpl{}
+
+type storageImpl struct {
+ db *gorm.DB
+}
+
+func (s *storageImpl) CreateAuthRequest(_ context.Context, authRequest *oidc.AuthRequest, hintedUserID string) (op.AuthRequest, error) {
+
+ requestUUID, err := uuid.NewUUID()
+ if err != nil {
+ return nil, fmt.Errorf("failed to create UUID: %w", err)
+ }
+
+ request := AuthRequest{
+ ID: requestUUID,
+ ClientID: authRequest.ClientID,
+ Nonce: authRequest.Nonce,
+ RedirectURI: authRequest.RedirectURI,
+ ResponseType: authRequest.ResponseType,
+ ResponseMode: authRequest.ResponseMode,
+ Scopes: authRequest.Scopes,
+ State: authRequest.State,
+ CodeChallenge: authRequest.CodeChallenge,
+ CodeChallengeMethod: authRequest.CodeChallengeMethod,
+ }
+
+ if hintedUserID != "" {
+ parsedHintedUserID, err := utils.ParseUint(hintedUserID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse user ID: %w", err)
+ }
+ request.UserID = &parsedHintedUserID
+ }
+
+ err = s.db.Omit(clause.Associations).Create(&request).Error
+ if err != nil {
+ return nil, fmt.Errorf("failed to create auth request: %w", err)
+ }
+
+ return &request, nil
+}
+
+func (s *storageImpl) AuthRequestByID(_ context.Context, requestID string) (op.AuthRequest, error) {
+ validRequestUUID, err := uuid.Parse(requestID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse request UUID: %w", err)
+ }
+ var request AuthRequest
+ err = s.db.Omit(clause.Associations).Where(&AuthRequest{ID: validRequestUUID}).Take(&request).Error
+ if err != nil {
+ return nil, fmt.Errorf("failed to get auth request: %w", err)
+ }
+ return &request, nil
+}
+
+func (s *storageImpl) AuthRequestByCode(_ context.Context, code string) (op.AuthRequest, error) {
+ var requestCode AuthRequestCode
+ err := s.db.Preload("AuthRequest").Where(&AuthRequestCode{Code: code}).Take(&requestCode).Error
+ if err != nil {
+ return nil, fmt.Errorf("failed to get auth request code: %w", err)
+ } else if requestCode.AuthRequest == nil {
+ return nil, fmt.Errorf("auth request not found for code: %s", code)
+ }
+ return requestCode.AuthRequest, nil
+}
+
+func (s *storageImpl) SaveAuthCode(_ context.Context, requestID string, code string) error {
+ validRequestUUID, err := uuid.Parse(requestID)
+ if err != nil {
+ return fmt.Errorf("failed to parse request UUID: %w", err)
+ }
+ requestCode := AuthRequestCode{
+ Code: code,
+ AuthRequestID: validRequestUUID,
+ }
+ err = s.db.Omit(clause.Associations).Create(&requestCode).Error
+ if err != nil {
+ return fmt.Errorf("failed to save auth code: %w", err)
+ }
+ return nil
+}
+
+func (s *storageImpl) DeleteAuthRequest(_ context.Context, requestID string) error {
+ validRequestUUID, err := uuid.Parse(requestID)
+ if err != nil {
+ return fmt.Errorf("failed to parse request UUID: %w", err)
+ }
+ // Auth codes deleted via foreign key cascade
+ err = s.db.Omit(clause.Associations).Where(&AuthRequest{ID: validRequestUUID}).Delete(&AuthRequest{}).Error
+ if err != nil {
+ return fmt.Errorf("failed to delete auth request: %w", err)
+ }
+ return nil
+}
+
+func (s *storageImpl) CreateAccessToken(_ context.Context, request op.TokenRequest) (accessTokenID string, expiration time.Time, err error) {
+ var clientID string
+ var userID uint
+ switch req := request.(type) {
+ case *AuthRequest:
+ if !req.Done() || req.UserID == nil {
+ return "", time.Time{}, oidc.ErrLoginRequired()
+ }
+ clientID = req.ClientID
+ userID = *req.UserID
+ // It is possible for requests to be of another type -- like op.TokenExchangeRequest -- but we don't support
+ // that currently
+ default:
+ return "", time.Time{}, fmt.Errorf("unsupported request type: %T", request)
+ }
+
+ token, err := s.createAccessToken(clientID, nil, request.GetScopes(), userID)
+ if err != nil {
+ return "", time.Time{}, err
+ }
+ return token.ID.String(), token.Expiry, nil
+}
+
+func (s *storageImpl) CreateAccessAndRefreshTokens(_ context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) {
+ var clientID string
+ var userID uint
+ var authTime time.Time
+ switch req := request.(type) {
+ case *AuthRequest:
+ if !req.Done() || req.UserID == nil {
+ return "", "", time.Time{}, oidc.ErrLoginRequired()
+ }
+ clientID = req.ClientID
+ userID = *req.UserID
+ authTime = req.GetAuthTime()
+ case *RefreshToken:
+ clientID = req.ClientID
+ userID = req.UserID
+ authTime = req.GetAuthTime()
+ default:
+ return "", "", time.Time{}, fmt.Errorf("unsupported request type: %T", request)
+ }
+
+ var refreshTokenModel *RefreshToken
+ if currentRefreshToken == "" {
+ refreshTokenModel, newRefreshToken, err = s.createRefreshToken(clientID, request.GetScopes(), userID, authTime)
+ } else {
+ refreshTokenModel, newRefreshToken, err = s.renewRefreshToken(currentRefreshToken)
+ }
+ if err != nil {
+ return "", "", time.Time{}, fmt.Errorf("failed to create refresh token: %w", err)
+ }
+
+ var tokenModel *Token
+ tokenModel, err = s.createAccessToken(clientID, &refreshTokenModel.ID, request.GetScopes(), userID)
+ if err != nil {
+ return "", "", time.Time{}, fmt.Errorf("failed to create access token: %w", err)
+ }
+
+ return tokenModel.ID.String(), newRefreshToken, tokenModel.Expiry, nil
+}
+
+func (s *storageImpl) TokenRequestByRefreshToken(_ context.Context, refreshTokenID string) (op.RefreshTokenRequest, error) {
+ parsedRefreshTokenID, err := uuid.Parse(refreshTokenID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse refresh token ID: %w", err)
+ }
+ var refreshToken RefreshToken
+ err = s.db.Omit(clause.Associations).Where(&RefreshToken{ID: parsedRefreshTokenID}).Take(&refreshToken).Error
+ if err != nil {
+ return nil, fmt.Errorf("failed to get refresh token: %w", err)
+ }
+ return &refreshToken, nil
+}
+
+func (s *storageImpl) TerminateSession(_ context.Context, userID string, clientID string) error {
+ parsedUserID, err := utils.ParseUint(userID)
+ if err != nil {
+ return fmt.Errorf("failed to parse user ID: %w", err)
+ }
+ err1 := s.db.Omit(clause.Associations).Where(&RefreshToken{ClientID: clientID, UserID: parsedUserID}).Delete(&RefreshToken{}).Error
+ // Most tokens should've been caught by the foreign key cascade, but just in case we'll make sure we wipe all matching tokens too
+ err2 := s.db.Omit(clause.Associations).Where(&Token{ClientID: clientID, UserID: parsedUserID}).Delete(&Token{}).Error
+ if err1 != nil && !errors.Is(err1, gorm.ErrRecordNotFound) {
+ return fmt.Errorf("failed to delete refresh tokens: %w", err1)
+ }
+ if err2 != nil && !errors.Is(err2, gorm.ErrRecordNotFound) {
+ return fmt.Errorf("failed to delete access tokens: %w", err2)
+ }
+ return nil
+}
+
+func (s *storageImpl) RevokeToken(_ context.Context, tokenOrTokenID string, userID string, clientID string) *oidc.Error {
+ var tokenUUIDToRevoke, refreshTokenUUIDToRevoke *uuid.UUID
+
+ // UserID can be empty! If it's passed we'll filter on it.
+ var queryUserID uint
+ if userID != "" {
+ var err error
+ queryUserID, err = utils.ParseUint(userID)
+ if err != nil {
+ return oidc.ErrServerError().WithParent(err).WithDescription("failed to parse user ID")
+ }
+ }
+
+ tokenUUID, err := uuid.Parse(tokenOrTokenID)
+ if err != nil {
+ // Token's not a UUID, so all we can do is revoke a matching refresh token
+ hash := sha512.Sum512([]byte(tokenOrTokenID))
+ var refreshToken RefreshToken
+ err = s.db.Omit(clause.Associations).Where(&RefreshToken{TokenHash: hash[:], ClientID: clientID, UserID: queryUserID}).Take(&refreshToken).Error
+ if err == nil {
+ refreshTokenUUIDToRevoke = &refreshToken.ID
+ }
+ } else {
+ // Look up as token UUID
+ var token Token
+ err = s.db.Omit(clause.Associations).Where(&Token{ID: tokenUUID, ClientID: clientID, UserID: queryUserID}).Take(&token).Error
+ if err == nil {
+ tokenUUIDToRevoke = &token.ID
+ refreshTokenUUIDToRevoke = token.RefreshTokenID
+ } else {
+ // Look up as refresh token UUID
+ var refreshToken RefreshToken
+ err = s.db.Omit(clause.Associations).Where(&RefreshToken{ID: tokenUUID, ClientID: clientID, UserID: queryUserID}).Take(&refreshToken).Error
+ if err == nil {
+ refreshTokenUUIDToRevoke = &refreshToken.ID
+ }
+ }
+ }
+
+ if tokenUUIDToRevoke != nil {
+ err = s.db.Omit(clause.Associations).Where(&Token{ID: *tokenUUIDToRevoke}).Delete(&Token{}).Error
+ }
+ if refreshTokenUUIDToRevoke != nil {
+ err = s.revokeRefreshToken(*refreshTokenUUIDToRevoke)
+ }
+ if err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return oidc.ErrInvalidRequest().WithDescription("token not found")
+ } else {
+ return oidc.ErrServerError().WithParent(err).WithDescription("failed to revoke token")
+ }
+ }
+ return nil
+}
+
+func (s *storageImpl) GetRefreshTokenInfo(_ context.Context, clientID string, token string) (userID string, tokenID string, err error) {
+ hash := sha512.Sum512([]byte(token))
+ var refreshToken RefreshToken
+ err = s.db.Omit(clause.Associations).Where(&RefreshToken{TokenHash: hash[:], ClientID: clientID}).Take(&refreshToken).Error
+ if err != nil {
+ return "", "", fmt.Errorf("failed to get refresh token: %w", err)
+ }
+ return utils.UintToString(refreshToken.UserID), refreshToken.ID.String(), nil
+}
+
+func (s *storageImpl) SigningKey(_ context.Context) (op.SigningKey, error) {
+ var key SigningKey
+ err := s.db.Omit(clause.Associations).Order("created_at desc").First(&key).Error
+ if err != nil {
+ return nil, fmt.Errorf("failed to get signing key: %w", err)
+ }
+ return decryptPrivateSigningKey(&key)
+}
+
+func (s *storageImpl) KeySet(_ context.Context) ([]op.Key, error) {
+ var keys []SigningKey
+ err := s.db.Omit(clause.Associations).Order("created_at desc").Find(&keys).Error
+ if err != nil {
+ return nil, fmt.Errorf("failed to get signing keys: %w", err)
+ }
+ return utils.Map(keys, func(rawKey SigningKey) op.Key {
+ return &publicSigningKey{SigningKey: rawKey}
+ }), nil
+}
+
+func (s *storageImpl) GetClientByClientID(_ context.Context, clientID string) (op.Client, error) {
+ var client Client
+ err := s.db.Omit(clause.Associations).Where(&Client{ID: clientID}).Take(&client).Error
+ if err != nil {
+ return nil, fmt.Errorf("failed to get client: %w", err)
+ }
+ return wrapPossibleDevModeClient(client), nil
+}
+
+func (s *storageImpl) AuthorizeClientIDSecret(_ context.Context, clientID, clientSecret string) error {
+ var client Client
+ err := s.db.Omit(clause.Associations).Where(&Client{ID: clientID}).Take(&client).Error
+ if err != nil {
+ return fmt.Errorf("failed to get client: %w", err)
+ }
+ if len(client.ClientSecretSalt) == 0 || len(client.ClientSecretHash) == 0 || client.ClientSecretIterations == 0 {
+ return fmt.Errorf("client secret not configured; perhaps client should be using PKCE")
+ }
+ derivedKey := pbkdf2.Key([]byte(clientSecret), client.ClientSecretSalt, client.ClientSecretIterations, len(client.ClientSecretHash), sha512.New)
+ if !bytes.Equal(derivedKey, client.ClientSecretHash) {
+ return fmt.Errorf("client secret does not match")
+ }
+ return nil
+}
+
+// SetUserinfoFromScopes is an empty implementation because according to Zitadel's example, SetUserinfoFromRequest
+// should be implemented instead.
+func (s *storageImpl) SetUserinfoFromScopes(_ context.Context, _ *oidc.UserInfo, _, _ string, _ []string) error {
+ return nil
+}
+
+func (s *storageImpl) SetUserinfoFromRequest(_ context.Context, userinfo *oidc.UserInfo, request op.IDTokenRequest, scopes []string) error {
+ userID, err := utils.ParseUint(request.GetSubject())
+ if err != nil {
+ return fmt.Errorf("failed to parse user ID: %w", err)
+ }
+ return s.setUserinfo(userinfo, userID, scopes)
+}
+
+func (s *storageImpl) SetUserinfoFromToken(_ context.Context, userinfo *oidc.UserInfo, tokenID, subject, _ string) error {
+ parsedTokenID, err := uuid.Parse(tokenID)
+ if err != nil {
+ return fmt.Errorf("failed to parse token ID: %w", err)
+ }
+ var token Token
+ err = s.db.Omit(clause.Associations).Where(&Token{ID: parsedTokenID}).Take(&token).Error
+ if err != nil {
+ return fmt.Errorf("failed to get token: %w", err)
+ }
+ if utils.UintToString(token.UserID) != subject {
+ // This check is theoretically unnecessary because the library should be covering this,
+ // but we have the info so why not.
+ return fmt.Errorf("token mismatched with subject")
+ }
+
+ // We ignore the origin argument to this function because CORS is handled by Gin middleware, no reason for us
+ // to implement that here.
+
+ return s.setUserinfo(userinfo, token.UserID, token.Scopes)
+}
+
+// SetIntrospectionFromToken has some arguments we ignore. We ignore the subject because we can get the user ID in
+// a more verified way from the token.
+func (s *storageImpl) SetIntrospectionFromToken(_ context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
+ parsedUserID, err := utils.ParseUint(subject)
+ if err != nil {
+ return fmt.Errorf("failed to parse user ID: %w", err)
+ }
+
+ parsedTokenID, err := uuid.Parse(tokenID)
+ if err != nil {
+ return fmt.Errorf("failed to parse token ID: %w", err)
+ }
+ var token Token
+ err = s.db.Omit(clause.Associations).Where(&Token{ID: parsedTokenID, ClientID: clientID, UserID: parsedUserID}).Take(&token).Error
+ if err != nil {
+ return fmt.Errorf("failed to get token: %w", err)
+ }
+
+ // In our case, we basically equate audience, client ID, and application.
+ // This kind of audience == client ID check is "the right way" to do this,
+ // though, so we do it to avoid a gotcha down the line.
+ for _, aud := range token.GetAudience() {
+ if aud == clientID {
+ userInfo := new(oidc.UserInfo)
+ err = s.setUserinfo(userInfo, token.UserID, token.Scopes)
+ if err != nil {
+ return fmt.Errorf("failed to set userinfo: %w", err)
+ }
+ introspection.SetUserInfo(userInfo)
+ introspection.Scope = token.Scopes
+ introspection.ClientID = token.ClientID
+ introspection.Active = true
+ return nil
+ }
+ }
+ return fmt.Errorf("token not valid for client")
+}
+
+func (s *storageImpl) GetPrivateClaimsFromScopes(_ context.Context, userID, _ string, scopes []string) (map[string]any, error) {
+ claims := make(map[string]any)
+ for _, scope := range scopes {
+ switch scope {
+ case groupsClaim:
+ // Only get the user if we need to
+ parsedUserID, err := utils.ParseUint(userID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse user ID: %w", err)
+ }
+ var user models.User
+ err = s.db.Where(&models.User{Model: gorm.Model{ID: parsedUserID}}).Scopes(models.ReadUserScope).Take(&user).Error
+ if err != nil {
+ return nil, fmt.Errorf("failed to get user: %w", err)
+ }
+ groups := make([]string, 0, len(user.Assignments))
+ for _, assignment := range user.Assignments {
+ if assignment != nil && assignment.IsActive() && assignment.Role != nil && assignment.Role.Name != nil {
+ groups = append(groups, *assignment.Role.Name)
+ }
+ }
+ claims[groupsClaim] = groups
+ }
+ }
+ return claims, nil
+}
+
+func (s *storageImpl) SignatureAlgorithms(_ context.Context) ([]jose.SignatureAlgorithm, error) {
+ // See signing_key.go, this is hardcoded
+ return []jose.SignatureAlgorithm{jose.ES512}, nil
+}
+
+func (s *storageImpl) GetKeyByIDAndClientID(_ context.Context, _, _ string) (*jose.JSONWebKey, error) {
+ // What we're meant to do here is define a list of (client ID, key ID, public key) tuples in configuration or something.
+ // The idea is that a client application would have the private key, and it could then sign JWTs *to send to us*, per
+ // RFC 7523 (urn:ietf:params:oauth:grant-type:jwt-bearer). zitadel-oidc would call this function to get the public key,
+ // validate the signature, and then return the requested access token. This allows a client to get an access token for
+ // a user without user interaction.
+ //
+ // We don't support that grant type for Sherlock so we leave this unimplemented. If somehow this does get called,
+ // the error will bubble up as if the assertion was invalid.
+ //
+ // https://datatracker.ietf.org/doc/html/rfc7523
+ return nil, fmt.Errorf("JWT Profile Authorization Grants (RFC 7523) aren't currently implemented")
+}
+
+func (s *storageImpl) ValidateJWTProfileScopes(_ context.Context, _ string, scopes []string) ([]string, error) {
+ allowedScopes := make([]string, 0)
+ for _, scope := range scopes {
+ if scope == oidc.ScopeOpenID || scope == oidc.ScopeProfile || scope == oidc.ScopeEmail {
+ allowedScopes = append(allowedScopes, scope)
+ }
+ }
+ return allowedScopes, nil
+}
+
+// Health isn't something we particularly need to rely on, because this OIDC provider exists in the context of a larger
+// application that already has liveness and readiness probes configured. If the OIDC API is available, the OIDC provider
+// will be healthy.
+func (s *storageImpl) Health(_ context.Context) error {
+ if s.db == nil {
+ return fmt.Errorf("no database connection")
+ }
+ return nil
+}
+
+func (s *storageImpl) createRefreshToken(clientID string, scopes []string, userID uint, authAt time.Time) (*RefreshToken, string, error) {
+ id, err := uuid.NewUUID()
+ if err != nil {
+ return nil, "", fmt.Errorf("failed to create UUID: %w", err)
+ }
+ refreshToken := make([]byte, 256)
+ _, err = rand.Read(refreshToken)
+ if err != nil {
+ return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
+ }
+ hash := sha512.Sum512(refreshToken)
+ refreshTokenModel := &RefreshToken{
+ ID: id,
+ TokenHash: hash[:],
+ ClientID: clientID,
+ Scopes: scopes,
+ OriginalAuthAt: authAt,
+ UserID: userID,
+ }
+ err = s.db.Omit(clause.Associations).Create(refreshTokenModel).Error
+ if err != nil {
+ return nil, "", fmt.Errorf("failed to create refresh token: %w", err)
+ }
+ return refreshTokenModel, string(refreshToken), nil
+}
+
+func (s *storageImpl) revokeRefreshToken(refreshTokenID uuid.UUID) error {
+ // Normal tokens are deleted via foreign key cascade
+ err := s.db.Omit(clause.Associations).Where(&RefreshToken{ID: refreshTokenID}).Delete(&RefreshToken{}).Error
+ if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
+ return fmt.Errorf("failed to delete refresh token: %w", err)
+ }
+ return nil
+}
+
+func (s *storageImpl) renewRefreshToken(refreshToken string) (*RefreshToken, string, error) {
+ hash := sha512.Sum512([]byte(refreshToken))
+ var refreshTokenModel RefreshToken
+ err := s.db.Omit(clause.Associations).Where(&RefreshToken{TokenHash: hash[:]}).Take(&refreshTokenModel).Error
+ if err != nil {
+ return nil, "", fmt.Errorf("failed to get refresh token: %w", err)
+ }
+
+ err = s.revokeRefreshToken(refreshTokenModel.ID)
+ if err != nil {
+ return nil, "", fmt.Errorf("failed to revoke existing refresh token: %w", err)
+ }
+
+ return s.createRefreshToken(refreshTokenModel.ClientID, refreshTokenModel.Scopes, refreshTokenModel.UserID, refreshTokenModel.OriginalAuthAt)
+}
+
+func (s *storageImpl) createAccessToken(clientID string, refreshTokenID *uuid.UUID, scopes []string, userID uint) (*Token, error) {
+ id, err := uuid.NewUUID()
+ if err != nil {
+ return nil, fmt.Errorf("failed to create UUID: %w", err)
+ }
+ tokenModel := &Token{
+ ID: id,
+ RefreshTokenID: refreshTokenID,
+ ClientID: clientID,
+ Scopes: scopes,
+ Expiry: time.Now().Add(config.Config.MustDuration("oidc.tokenDuration")),
+ UserID: userID,
+ }
+
+ err = s.db.Omit(clause.Associations).Create(tokenModel).Error
+ if err != nil {
+ return nil, fmt.Errorf("failed to create access token: %w", err)
+ }
+
+ return tokenModel, nil
+}
+
+func (s *storageImpl) setUserinfo(userInfo *oidc.UserInfo, userID uint, scopes []string) error {
+ var user models.User
+ err := s.db.Where(&models.User{Model: gorm.Model{ID: userID}}).Scopes(models.ReadUserScope).Take(&user).Error
+ if err != nil {
+ return fmt.Errorf("failed to get user: %w", err)
+ }
+ for _, scope := range scopes {
+ switch scope {
+ case oidc.ScopeOpenID:
+ userInfo.Subject = utils.UintToString(user.ID)
+ case oidc.ScopeEmail:
+ userInfo.Email = user.Email
+ userInfo.EmailVerified = true
+ case oidc.ScopeProfile:
+ // Downstream systems expect consistent claims. We can get away with passing GivenName and FamilyName
+ // conditionally, but for Name it's worth it to always pass it, even if we potentially have to pass
+ // an ugly-looking email handle.
+ userInfo.Name = user.NameOrUsername()
+ userInfo.Nickname = user.NameOrUsername()
+ userInfo.PreferredUsername = user.AlphaNumericHyphenatedUsername()
+ userInfo.Locale = oidc.NewLocale(language.AmericanEnglish)
+ userInfo.UpdatedAt = oidc.FromTime(user.UpdatedAt)
+ if user.Name != nil {
+ nameParts := strings.Split(*user.Name, " ")
+ if len(nameParts) > 0 {
+ userInfo.GivenName = nameParts[0]
+ }
+ if len(nameParts) > 1 {
+ userInfo.FamilyName = nameParts[len(nameParts)-1]
+ }
+ }
+ userInfo.Website = "https://broad.io/beehive/r/user/" + user.Email // :shrug:
+ case groupsClaim:
+ groups := make([]string, 0, len(user.Assignments))
+ for _, assignment := range user.Assignments {
+ if assignment != nil && assignment.IsActive() && assignment.Role != nil && assignment.Role.Name != nil {
+ groups = append(groups, *assignment.Role.Name)
+ }
+ }
+ userInfo.AppendClaims(groupsClaim, groups)
+ }
+ }
+ return nil
+}
diff --git a/sherlock/internal/oidc_models/storage_test.go b/sherlock/internal/oidc_models/storage_test.go
new file mode 100644
index 000000000..32e2576bb
--- /dev/null
+++ b/sherlock/internal/oidc_models/storage_test.go
@@ -0,0 +1,695 @@
+package oidc_models
+
+import (
+ "context"
+ "database/sql"
+ "github.com/broadinstitute/sherlock/go-shared/pkg/utils"
+ "github.com/broadinstitute/sherlock/sherlock/internal/models"
+ "github.com/go-jose/go-jose/v4"
+ "github.com/google/uuid"
+ "github.com/zitadel/oidc/v3/pkg/oidc"
+ "golang.org/x/text/language"
+ "gorm.io/gorm/clause"
+ "strings"
+ "time"
+)
+
+func (s *oidcModelsSuite) TestStorageImpl_CreateAuthRequest() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+
+ authRequest, err := s.storage.CreateAuthRequest(context.Background(), &oidc.AuthRequest{
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ ResponseType: oidc.ResponseTypeIDTokenOnly,
+ ClientID: clientID,
+ RedirectURI: s.GeneratedClientRedirectURI(),
+ State: "some-state",
+ Nonce: "some-nonce",
+ ResponseMode: oidc.ResponseModeQuery,
+ CodeChallenge: "code-challenge",
+ CodeChallengeMethod: oidc.CodeChallengeMethodS256,
+ }, "")
+ s.NoError(err)
+
+ var authRequests []AuthRequest
+ s.NoError(s.DB.Find(&authRequests).Error)
+ s.Len(authRequests, 1)
+
+ s.Equal(authRequest.GetID(), authRequests[0].ID.String())
+ s.Equal(clientID, authRequests[0].ClientID)
+ s.Equal(oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, authRequests[0].Scopes)
+ s.Equal(oidc.ResponseTypeIDTokenOnly, authRequests[0].ResponseType)
+ s.Equal(s.GeneratedClientRedirectURI(), authRequests[0].RedirectURI)
+ s.Equal("some-state", authRequests[0].State)
+ s.Equal("some-nonce", authRequests[0].Nonce)
+ s.Equal(oidc.ResponseModeQuery, authRequests[0].ResponseMode)
+ s.Equal("code-challenge", authRequests[0].CodeChallenge)
+ s.Equal(oidc.CodeChallengeMethodS256, authRequests[0].CodeChallengeMethod)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_AuthRequestByID() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+
+ authRequest := &AuthRequest{
+ ID: uuid.New(),
+ DoneAt: sql.NullTime{Time: time.Now(), Valid: true},
+ ClientID: clientID,
+ Nonce: "some-nonce",
+ RedirectURI: s.GeneratedClientRedirectURI(),
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ State: "some-state",
+ UserID: utils.PointerTo(s.TestData.User_Suitable().ID),
+ }
+ s.NoError(s.DB.Omit(clause.Associations).Create(authRequest).Error)
+
+ authRequestByID, err := s.storage.AuthRequestByID(context.Background(), authRequest.ID.String())
+ s.NoError(err)
+ s.Equal(authRequest.ID.String(), authRequestByID.GetID())
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_AuthRequestByCode() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+
+ authRequest := &AuthRequest{
+ ID: uuid.New(),
+ DoneAt: sql.NullTime{Time: time.Now(), Valid: true},
+ ClientID: clientID,
+ Nonce: "some-nonce",
+ RedirectURI: s.GeneratedClientRedirectURI(),
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ State: "some-state",
+ UserID: utils.PointerTo(s.TestData.User_Suitable().ID),
+ }
+ s.NoError(s.DB.Omit(clause.Associations).Create(authRequest).Error)
+
+ code := "some-code"
+ s.NoError(s.storage.SaveAuthCode(context.Background(), authRequest.ID.String(), code))
+
+ authRequestByCode, err := s.storage.AuthRequestByCode(context.Background(), code)
+ s.NoError(err)
+ s.Equal(authRequest.ID.String(), authRequestByCode.GetID())
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_SaveAuthCode() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+
+ authRequest := &AuthRequest{
+ ID: uuid.New(),
+ DoneAt: sql.NullTime{Time: time.Now(), Valid: true},
+ ClientID: clientID,
+ Nonce: "some-nonce",
+ RedirectURI: s.GeneratedClientRedirectURI(),
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ State: "some-state",
+ UserID: utils.PointerTo(s.TestData.User_Suitable().ID),
+ }
+ s.NoError(s.DB.Omit(clause.Associations).Create(authRequest).Error)
+
+ code := "some-code"
+ s.NoError(s.storage.SaveAuthCode(context.Background(), authRequest.ID.String(), code))
+
+ var authRequestCodes []AuthRequestCode
+ s.NoError(s.DB.Find(&authRequestCodes).Error)
+ s.Len(authRequestCodes, 1)
+ s.Equal(code, authRequestCodes[0].Code)
+ s.Equal(authRequest.ID.String(), authRequestCodes[0].AuthRequestID.String())
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_DeleteAuthRequest() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+
+ authRequest := &AuthRequest{
+ ID: uuid.New(),
+ DoneAt: sql.NullTime{Time: time.Now(), Valid: true},
+ ClientID: clientID,
+ Nonce: "some-nonce",
+ RedirectURI: s.GeneratedClientRedirectURI(),
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ State: "some-state",
+ UserID: utils.PointerTo(s.TestData.User_Suitable().ID),
+ }
+ s.NoError(s.DB.Omit(clause.Associations).Create(authRequest).Error)
+
+ s.NoError(s.storage.DeleteAuthRequest(context.Background(), authRequest.ID.String()))
+
+ var authRequests []AuthRequest
+ s.NoError(s.DB.Find(&authRequests).Error)
+ s.Len(authRequests, 0)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_CreateAccessToken() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ accessTokenID, _, err := s.storage.CreateAccessToken(context.Background(), &AuthRequest{
+ ID: uuid.New(),
+ DoneAt: sql.NullTime{Time: time.Now(), Valid: true},
+ ClientID: clientID,
+ Nonce: "some-nonce",
+ RedirectURI: s.GeneratedClientRedirectURI(),
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ State: "some-state",
+ UserID: utils.PointerTo(s.TestData.User_Suitable().ID),
+ })
+ s.NoError(err)
+
+ var tokens []Token
+ s.NoError(s.DB.Find(&tokens).Error)
+ s.Len(tokens, 1)
+ s.Equal(accessTokenID, tokens[0].ID.String())
+ s.Equal(clientID, tokens[0].ClientID)
+ s.Equal(oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, tokens[0].Scopes)
+
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_CreateAccessAndRefreshTokens() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ authRequest := &AuthRequest{
+ ID: uuid.New(),
+ DoneAt: sql.NullTime{Time: time.Now(), Valid: true},
+ ClientID: clientID,
+ Nonce: "some-nonce",
+ RedirectURI: s.GeneratedClientRedirectURI(),
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ State: "some-state",
+ UserID: utils.PointerTo(s.TestData.User_Suitable().ID),
+ }
+ accessTokenID, _, _, err := s.storage.CreateAccessAndRefreshTokens(context.Background(), authRequest, "")
+ s.NoError(err)
+
+ var tokens []Token
+ s.NoError(s.DB.Find(&tokens).Error)
+ s.Len(tokens, 1)
+ s.Equal(accessTokenID, tokens[0].ID.String())
+ s.Equal(clientID, tokens[0].ClientID)
+ s.Equal(oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, tokens[0].Scopes)
+
+ var refreshTokens []RefreshToken
+ s.NoError(s.DB.Find(&refreshTokens).Error)
+ s.Len(refreshTokens, 1)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_TokenRequestByRefreshToken() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ refreshTokenModel, _, err := s.storage.createRefreshToken(clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID, time.Now())
+ s.NoError(err)
+
+ tokenRequest, err := s.storage.TokenRequestByRefreshToken(context.Background(), refreshTokenModel.ID.String())
+ s.NoError(err)
+ s.Equal(refreshTokenModel.OriginalAuthAt.Second(), tokenRequest.GetAuthTime().Second())
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_TerminateSession() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ _, _, _, err = s.storage.CreateAccessAndRefreshTokens(context.Background(), &AuthRequest{
+ ID: uuid.New(),
+ DoneAt: sql.NullTime{Time: time.Now(), Valid: true},
+ ClientID: clientID,
+ Nonce: "some-nonce",
+ RedirectURI: s.GeneratedClientRedirectURI(),
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ State: "some-state",
+ UserID: utils.PointerTo(s.TestData.User_Suitable().ID),
+ }, "")
+ s.NoError(err)
+ s.NoError(s.storage.TerminateSession(context.Background(), utils.UintToString(s.TestData.User_Suitable().ID), clientID))
+ var refreshTokens []RefreshToken
+ s.NoError(s.DB.Find(&refreshTokens).Error)
+ s.Len(refreshTokens, 0)
+ var tokens []Token
+ s.NoError(s.DB.Find(&tokens).Error)
+ s.Len(tokens, 0)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_RevokeToken() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ s.Run("refresh token", func() {
+ _, refreshToken, _, err := s.storage.CreateAccessAndRefreshTokens(context.Background(), &AuthRequest{
+ ID: uuid.New(),
+ DoneAt: sql.NullTime{Time: time.Now(), Valid: true},
+ ClientID: clientID,
+ Nonce: "some-nonce",
+ RedirectURI: s.GeneratedClientRedirectURI(),
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ State: "some-state",
+ UserID: utils.PointerTo(s.TestData.User_Suitable().ID),
+ }, "")
+ s.NoError(err)
+
+ s.Nil(s.storage.RevokeToken(context.Background(), refreshToken, utils.UintToString(s.TestData.User_Suitable().ID), clientID))
+
+ var refreshTokens []RefreshToken
+ s.NoError(s.DB.Find(&refreshTokens).Error)
+ s.Len(refreshTokens, 0)
+ var tokens []Token
+ s.NoError(s.DB.Find(&tokens).Error)
+ s.Len(tokens, 0)
+ })
+ s.Run("token", func() {
+ accessTokenID, _, err := s.storage.CreateAccessToken(context.Background(), &AuthRequest{
+ ID: uuid.New(),
+ DoneAt: sql.NullTime{Time: time.Now(), Valid: true},
+ ClientID: clientID,
+ Nonce: "some-nonce",
+ RedirectURI: s.GeneratedClientRedirectURI(),
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ State: "some-state",
+ UserID: utils.PointerTo(s.TestData.User_Suitable().ID),
+ })
+ s.NoError(err)
+
+ s.Nil(s.storage.RevokeToken(context.Background(), accessTokenID, utils.UintToString(s.TestData.User_Suitable().ID), clientID))
+
+ var refreshTokens []RefreshToken
+ s.NoError(s.DB.Find(&refreshTokens).Error)
+ s.Len(refreshTokens, 0)
+ var tokens []Token
+ s.NoError(s.DB.Find(&tokens).Error)
+ s.Len(tokens, 0)
+ })
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_GetRefreshTokenInfo() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ refreshTokenModel, rawRefreshToken, err := s.storage.createRefreshToken(clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID, time.Now())
+ s.NoError(err)
+
+ userID, tokenID, err := s.storage.GetRefreshTokenInfo(context.Background(), clientID, rawRefreshToken)
+ s.NoError(err)
+ s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), userID)
+ s.Equal(refreshTokenModel.ID.String(), tokenID)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_SigningKey() {
+ key1, err := saveNewSigningKey(context.Background(), s.DB)
+ s.NoError(err)
+ key2, err := saveNewSigningKey(context.Background(), s.DB)
+ s.NoError(err)
+ s.NoError(s.DB.Model(&key1).UpdateColumn("created_at", time.Now().Add(-time.Hour)).Error)
+ signingKey, err := s.storage.SigningKey(context.Background())
+ s.NoError(err)
+ s.Equal(key2.ID.String(), signingKey.ID())
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_KeySet() {
+ key1, err := saveNewSigningKey(context.Background(), s.DB)
+ s.NoError(err)
+ key2, err := saveNewSigningKey(context.Background(), s.DB)
+ s.NoError(err)
+ keyset, err := s.storage.KeySet(context.Background())
+ s.NoError(err)
+ s.Len(keyset, 2)
+ var key1Found, key2Found bool
+ for _, key := range keyset {
+ if key.ID() == key1.ID.String() {
+ key1Found = true
+ }
+ if key.ID() == key2.ID.String() {
+ key2Found = true
+ }
+ }
+ s.True(key1Found)
+ s.True(key2Found)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_GetClientByClientID() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ client, err := s.storage.GetClientByClientID(context.Background(), clientID)
+ s.NoError(err)
+ s.Equal(clientID, client.GetID())
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_AuthorizeClientIDSecret() {
+ clientID, clientSecret, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ s.Run("valid", func() {
+ s.NoError(s.storage.AuthorizeClientIDSecret(context.Background(), clientID, clientSecret))
+ })
+ s.Run("invalid", func() {
+ s.Error(s.storage.AuthorizeClientIDSecret(context.Background(), clientID, "invalid"))
+ })
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_SetUserinfoFromScopes() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ userinfo := &oidc.UserInfo{}
+ s.NoError(s.storage.SetUserinfoFromScopes(context.Background(), userinfo, utils.UintToString(s.TestData.User_Suitable().ID), clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}))
+ s.Empty(userinfo.Subject) // This method does nothing intentionally; the library says to implement the request handling instead
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_SetUserinfoFromRequest() {
+ userinfo := &oidc.UserInfo{}
+ s.NoError(s.storage.SetUserinfoFromRequest(context.Background(), userinfo, &AuthRequest{
+ UserID: utils.PointerTo(s.TestData.User_Suitable().ID),
+ }, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}))
+ s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), userinfo.Subject)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_SetUserinfoFromToken() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ token := Token{
+ ID: uuid.New(),
+ ClientID: clientID,
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ Expiry: time.Now().Add(time.Hour),
+ UserID: s.TestData.User_Suitable().ID,
+ }
+ s.NoError(s.DB.Omit(clause.Associations).Create(&token).Error)
+ userinfo := &oidc.UserInfo{}
+ s.NoError(s.storage.SetUserinfoFromToken(context.Background(), userinfo, token.ID.String(), utils.UintToString(s.TestData.User_Suitable().ID), ""))
+ s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), userinfo.Subject)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_SetIntrospectionFromToken() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ token := Token{
+ ID: uuid.New(),
+ ClientID: clientID,
+ Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim},
+ Expiry: time.Now().Add(time.Hour),
+ UserID: s.TestData.User_Suitable().ID,
+ }
+ s.NoError(s.DB.Omit(clause.Associations).Create(&token).Error)
+ introspection := &oidc.IntrospectionResponse{}
+ s.NoError(s.storage.SetIntrospectionFromToken(context.Background(), introspection, token.ID.String(), utils.UintToString(s.TestData.User_Suitable().ID), clientID))
+ s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), introspection.Subject)
+ s.Equal(token.Scopes, introspection.Scope)
+ s.Equal(clientID, introspection.ClientID)
+ s.True(introspection.Active)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_GetPrivateClaimsFromScopes() {
+ s.Run("empty", func() {
+ claims, err := s.storage.GetPrivateClaimsFromScopes(context.Background(), utils.UintToString(s.TestData.User_Suitable().ID), "", []string{})
+ s.NoError(err)
+ s.Empty(claims)
+ })
+ s.Run("groups", func() {
+ claims, err := s.storage.GetPrivateClaimsFromScopes(context.Background(), utils.UintToString(s.TestData.User_Suitable().ID), "", []string{groupsClaim})
+ s.NoError(err)
+ if s.NotEmpty(claims) && s.Contains(claims, groupsClaim) && s.IsType([]string{}, claims[groupsClaim]) {
+ groups := claims[groupsClaim].([]string)
+ for _, ra := range s.TestData.User_Suitable().Assignments {
+ if ra != nil && ra.IsActive() {
+ s.Contains(groups, *ra.Role.Name)
+ }
+ }
+ }
+ })
+ s.Run("extra", func() {
+ claims, err := s.storage.GetPrivateClaimsFromScopes(context.Background(), utils.UintToString(s.TestData.User_Suitable().ID), "", []string{groupsClaim, "extra"})
+ s.NoError(err)
+ if s.NotEmpty(claims) && s.Contains(claims, groupsClaim) && s.IsType([]string{}, claims[groupsClaim]) {
+ groups := claims[groupsClaim].([]string)
+ for _, ra := range s.TestData.User_Suitable().Assignments {
+ if ra != nil && ra.IsActive() {
+ s.Contains(groups, *ra.Role.Name)
+ }
+ }
+ }
+ })
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_SignatureAlgorithms() {
+ algorithms, err := s.storage.SignatureAlgorithms(context.Background())
+ s.NoError(err)
+ s.Equal([]jose.SignatureAlgorithm{jose.ES512}, algorithms)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_GetKeyByIDAndClientID() {
+ _, err := s.storage.GetKeyByIDAndClientID(context.Background(), "", "")
+ s.ErrorContains(err, "aren't currently implemented")
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_ValidateJWTProfileScopes() {
+ s.Run("empty", func() {
+ scopes, err := s.storage.ValidateJWTProfileScopes(context.Background(), "", []string{})
+ s.NoError(err)
+ s.Empty(scopes)
+ })
+ s.Run("all allowed", func() {
+ scopes, err := s.storage.ValidateJWTProfileScopes(context.Background(), "", []string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProfile})
+ s.NoError(err)
+ s.Equal([]string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProfile}, scopes)
+ })
+ s.Run("extra stuff", func() {
+ scopes, err := s.storage.ValidateJWTProfileScopes(context.Background(), "", []string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProfile, "extra"})
+ s.NoError(err)
+ s.Equal([]string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProfile}, scopes)
+ })
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_Health() {
+ s.Run("error when nil db", func() {
+ storage := &storageImpl{}
+ s.Error(storage.Health(context.Background()))
+ })
+ s.Run("no error", func() {
+ s.NoError(s.storage.Health(context.Background()))
+ })
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_createRefreshToken() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ refreshToken, _, err := s.storage.createRefreshToken(clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID, time.Now())
+ s.NoError(err)
+ s.NotEmpty(refreshToken.ID)
+ s.Equal(clientID, refreshToken.ClientID)
+ s.Equal(oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, refreshToken.Scopes)
+ s.Equal(s.TestData.User_Suitable().ID, refreshToken.UserID)
+
+ var refreshTokens []RefreshToken
+ s.NoError(s.DB.Find(&refreshTokens).Error)
+ s.Len(refreshTokens, 1)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_revokeRefreshToken() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ refreshToken, _, err := s.storage.createRefreshToken(clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID, time.Now())
+ s.NoError(err)
+
+ s.NoError(s.storage.revokeRefreshToken(refreshToken.ID))
+
+ var refreshTokens []RefreshToken
+ s.NoError(s.DB.Find(&refreshTokens).Error)
+ s.Len(refreshTokens, 0)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_renewRefreshToken() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ refreshToken, rawRefreshToken, err := s.storage.createRefreshToken(clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID, time.Now())
+ s.NoError(err)
+
+ newRefreshToken, _, err := s.storage.renewRefreshToken(rawRefreshToken)
+ s.NoError(err)
+
+ s.NotEqual(refreshToken.ID, newRefreshToken.ID)
+ s.Equal(refreshToken.ClientID, newRefreshToken.ClientID)
+ s.Equal(refreshToken.Scopes, newRefreshToken.Scopes)
+ s.Equal(refreshToken.UserID, newRefreshToken.UserID)
+
+ var refreshTokens []RefreshToken
+ s.NoError(s.DB.Find(&refreshTokens).Error)
+ s.Len(refreshTokens, 1)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_createAccessToken() {
+ clientID, _, err := s.GenerateClient(s.DB)
+ s.NoError(err)
+ accessToken, err := s.storage.createAccessToken(clientID, nil, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID)
+ s.NoError(err)
+
+ var tokens []Token
+ s.NoError(s.DB.Find(&tokens).Error)
+ s.Len(tokens, 1)
+ s.Equal(accessToken.ID.String(), tokens[0].ID.String())
+ s.Equal(clientID, tokens[0].ClientID)
+ s.Equal(oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, tokens[0].Scopes)
+ s.Equal(s.TestData.User_Suitable().ID, tokens[0].UserID)
+}
+
+func (s *oidcModelsSuite) TestStorageImpl_setUserInfo() {
+ s.Run("no scopes", func() {
+ userInfo := &oidc.UserInfo{}
+ s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{}))
+ s.Empty(userInfo.Subject)
+ s.Empty(userInfo.Email)
+ s.Empty(userInfo.EmailVerified)
+ s.Empty(userInfo.Name)
+ s.Empty(userInfo.Nickname)
+ s.Empty(userInfo.PreferredUsername)
+ s.Empty(userInfo.Locale)
+ s.Empty(userInfo.UpdatedAt)
+ s.Empty(userInfo.GivenName)
+ s.Empty(userInfo.FamilyName)
+ s.Empty(userInfo.Website)
+ s.Empty(userInfo.Claims)
+ })
+ s.Run("openid", func() {
+ userInfo := &oidc.UserInfo{}
+ s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{oidc.ScopeOpenID}))
+ s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), userInfo.Subject)
+ s.Empty(userInfo.Email)
+ s.Empty(userInfo.EmailVerified)
+ s.Empty(userInfo.Name)
+ s.Empty(userInfo.Nickname)
+ s.Empty(userInfo.PreferredUsername)
+ s.Empty(userInfo.Locale)
+ s.Empty(userInfo.UpdatedAt)
+ s.Empty(userInfo.GivenName)
+ s.Empty(userInfo.FamilyName)
+ s.Empty(userInfo.Website)
+ s.Empty(userInfo.Claims)
+ })
+ s.Run("email", func() {
+ userInfo := &oidc.UserInfo{}
+ s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{oidc.ScopeEmail}))
+ s.Empty(userInfo.Subject)
+ s.Equal(s.TestData.User_Suitable().Email, userInfo.Email)
+ s.True(bool(userInfo.EmailVerified))
+ s.Empty(userInfo.Name)
+ s.Empty(userInfo.Nickname)
+ s.Empty(userInfo.PreferredUsername)
+ s.Empty(userInfo.Locale)
+ s.Empty(userInfo.UpdatedAt)
+ s.Empty(userInfo.GivenName)
+ s.Empty(userInfo.FamilyName)
+ s.Empty(userInfo.Website)
+ s.Empty(userInfo.Claims)
+ })
+ s.Run("profile", func() {
+ userInfo := &oidc.UserInfo{}
+ s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{oidc.ScopeProfile}))
+ s.Empty(userInfo.Subject)
+ s.Empty(userInfo.Email)
+ s.Empty(userInfo.EmailVerified)
+ s.Equal(utils.PointerTo(s.TestData.User_Suitable()).NameOrUsername(), userInfo.Name)
+ s.Equal(utils.PointerTo(s.TestData.User_Suitable()).NameOrUsername(), userInfo.Nickname)
+ s.Equal(utils.PointerTo(s.TestData.User_Suitable()).AlphaNumericHyphenatedUsername(), userInfo.PreferredUsername)
+ s.Equal(oidc.NewLocale(language.AmericanEnglish), userInfo.Locale)
+ s.Equal(oidc.FromTime(s.TestData.User_Suitable().UpdatedAt), userInfo.UpdatedAt)
+ if s.TestData.User_Suitable().Name != nil {
+ nameParts := strings.Split(*s.TestData.User_Suitable().Name, " ")
+ s.Equal(nameParts[0], userInfo.GivenName)
+ s.Equal(nameParts[len(nameParts)-1], userInfo.FamilyName)
+ } else {
+ s.Empty(userInfo.GivenName)
+ s.Empty(userInfo.FamilyName)
+ }
+ s.Equal("https://broad.io/beehive/r/user/"+s.TestData.User_Suitable().Email, userInfo.Website)
+ s.Empty(userInfo.Claims)
+ })
+ s.Run("groups", func() {
+ userInfo := &oidc.UserInfo{}
+ s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{groupsClaim}))
+ s.Empty(userInfo.Subject)
+ s.Empty(userInfo.Email)
+ s.Empty(userInfo.EmailVerified)
+ s.Empty(userInfo.Name)
+ s.Empty(userInfo.Nickname)
+ s.Empty(userInfo.PreferredUsername)
+ s.Empty(userInfo.Locale)
+ s.Empty(userInfo.UpdatedAt)
+ s.Empty(userInfo.GivenName)
+ s.Empty(userInfo.FamilyName)
+ s.Empty(userInfo.Website)
+ if s.NotEmpty(userInfo.Claims) && s.Contains(userInfo.Claims, groupsClaim) && s.IsType([]string{}, userInfo.Claims[groupsClaim]) {
+ groups := userInfo.Claims[groupsClaim].([]string)
+ for _, ra := range s.TestData.User_Suitable().Assignments {
+ if ra != nil && ra.IsActive() {
+ s.Contains(groups, *ra.Role.Name)
+ }
+ }
+ }
+ })
+ s.Run("openid email profile groups", func() {
+ userInfo := &oidc.UserInfo{}
+ s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProfile, groupsClaim}))
+ s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), userInfo.Subject)
+ s.Equal(s.TestData.User_Suitable().Email, userInfo.Email)
+ s.True(bool(userInfo.EmailVerified))
+ s.Equal(utils.PointerTo(s.TestData.User_Suitable()).NameOrUsername(), userInfo.Name)
+ s.Equal(utils.PointerTo(s.TestData.User_Suitable()).NameOrUsername(), userInfo.Nickname)
+ s.Equal(utils.PointerTo(s.TestData.User_Suitable()).AlphaNumericHyphenatedUsername(), userInfo.PreferredUsername)
+ s.Equal(oidc.NewLocale(language.AmericanEnglish), userInfo.Locale)
+ s.Equal(oidc.FromTime(s.TestData.User_Suitable().UpdatedAt), userInfo.UpdatedAt)
+ if s.TestData.User_Suitable().Name != nil {
+ nameParts := strings.Split(*s.TestData.User_Suitable().Name, " ")
+ s.Equal(nameParts[0], userInfo.GivenName)
+ s.Equal(nameParts[len(nameParts)-1], userInfo.FamilyName)
+ } else {
+ s.Empty(userInfo.GivenName)
+ s.Empty(userInfo.FamilyName)
+ }
+ s.Equal("https://broad.io/beehive/r/user/"+s.TestData.User_Suitable().Email, userInfo.Website)
+ if s.NotEmpty(userInfo.Claims) && s.Contains(userInfo.Claims, groupsClaim) && s.IsType([]string{}, userInfo.Claims[groupsClaim]) {
+ groups := userInfo.Claims[groupsClaim].([]string)
+ for _, ra := range s.TestData.User_Suitable().Assignments {
+ if ra != nil && ra.IsActive() {
+ s.Contains(groups, *ra.Role.Name)
+ }
+ }
+ }
+ })
+ s.Run("groups only active role assignments", func() {
+ suspendedRole := models.Role{
+ RoleFields: models.RoleFields{
+ Name: utils.PointerTo("test-role"),
+ },
+ }
+ s.SetSelfSuperAdminForDB()
+ s.NoError(s.DB.Create(&suspendedRole).Error)
+ suspendedRoleAssignment := models.RoleAssignment{
+ UserID: s.TestData.User_Suitable().ID,
+ RoleID: suspendedRole.ID,
+ RoleAssignmentFields: models.RoleAssignmentFields{
+ Suspended: utils.PointerTo(true),
+ },
+ }
+ s.NoError(s.DB.Create(&suspendedRoleAssignment).Error)
+
+ // reload user to get updated assignments
+ var user models.User
+ s.NoError(s.DB.Scopes(models.ReadUserScope).First(&user, s.TestData.User_Suitable().ID).Error)
+
+ // make this test extra explicit -- set these bools to true at key points
+ var suspendedRoleAssignmentOnUser, suspendedRoleAssignmentOmittedFromUserinfo bool
+
+ userInfo := &oidc.UserInfo{}
+ s.NoError(s.storage.setUserinfo(userInfo, user.ID, []string{groupsClaim}))
+ if s.NotEmpty(userInfo.Claims) && s.Contains(userInfo.Claims, groupsClaim) && s.IsType([]string{}, userInfo.Claims[groupsClaim]) {
+ groups := userInfo.Claims[groupsClaim].([]string)
+ for _, ra := range user.Assignments {
+ if ra != nil {
+ if *ra.Role.Name == *suspendedRole.Name {
+ suspendedRoleAssignmentOnUser = true
+ }
+
+ if ra.IsActive() {
+ s.Contains(groups, *ra.Role.Name)
+ } else if s.NotContains(groups, *ra.Role.Name) && *ra.Role.Name == *suspendedRole.Name {
+ suspendedRoleAssignmentOmittedFromUserinfo = true
+ }
+ }
+ }
+ }
+
+ s.True(suspendedRoleAssignmentOnUser)
+ s.True(suspendedRoleAssignmentOmittedFromUserinfo)
+ })
+}
diff --git a/sherlock/internal/oidc_models/test_helper.go b/sherlock/internal/oidc_models/test_helper.go
new file mode 100644
index 000000000..279c58ecd
--- /dev/null
+++ b/sherlock/internal/oidc_models/test_helper.go
@@ -0,0 +1,64 @@
+package oidc_models
+
+import (
+ "crypto/rand"
+ "crypto/sha512"
+ "encoding/hex"
+ "github.com/zitadel/oidc/v3/pkg/oidc"
+ "github.com/zitadel/oidc/v3/pkg/op"
+ "golang.org/x/crypto/pbkdf2"
+ "gorm.io/gorm"
+ "time"
+)
+
+type TestClientHelper struct{}
+
+func (h TestClientHelper) GeneratedClientRedirectURI() string {
+ return "http://localhost:8080/test/fake/redirect"
+}
+
+func (h TestClientHelper) GeneratedClientPostLogoutRedirectURI() string {
+ return "http://localhost:8080/test/fake/postlogout"
+}
+
+func (h TestClientHelper) GenerateClient(db *gorm.DB) (clientID string, clientSecret string, err error) {
+ clientIDBytes := make([]byte, 16)
+ _, err = rand.Read(clientIDBytes)
+ if err != nil {
+ return "", "", err
+ }
+ clientID = hex.EncodeToString(clientIDBytes)
+
+ clientSecretBytes := make([]byte, 32)
+ _, err = rand.Read(clientSecretBytes)
+ if err != nil {
+ return "", "", err
+ }
+ clientSecret = hex.EncodeToString(clientSecretBytes)
+
+ clientSecretSalt := make([]byte, 32)
+ _, err = rand.Read(clientSecretSalt)
+ if err != nil {
+ return "", "", err
+ }
+
+ testClientHashIterations := 1_000 // Low low low value!!! Needs to be 210_000+ in production, this just makes tests run faster
+ clientSecretHash := pbkdf2.Key([]byte(clientSecret), clientSecretSalt, testClientHashIterations, 32, sha512.New)
+
+ client := Client{
+ ID: clientID,
+ ClientSecretHash: clientSecretHash,
+ ClientSecretSalt: clientSecretSalt,
+ ClientSecretIterations: testClientHashIterations,
+ ClientRedirectURIs: oidc.SpaceDelimitedArray{h.GeneratedClientRedirectURI()},
+ ClientPostLogoutRedirectURIs: oidc.SpaceDelimitedArray{h.GeneratedClientPostLogoutRedirectURI()},
+ ClientApplicationType: op.ApplicationTypeWeb,
+ ClientAuthMethod: oidc.AuthMethodBasic,
+ ClientIDTokenLifetime: (15 * time.Minute).Nanoseconds(),
+ ClientDevMode: true,
+ ClientClockSkew: (15 * time.Second).Nanoseconds(),
+ }
+
+ err = db.Create(&client).Error
+ return clientID, clientSecret, err
+}
diff --git a/sherlock/internal/oidc_models/token.go b/sherlock/internal/oidc_models/token.go
new file mode 100644
index 000000000..fa2569ff3
--- /dev/null
+++ b/sherlock/internal/oidc_models/token.go
@@ -0,0 +1,40 @@
+package oidc_models
+
+import (
+ "github.com/broadinstitute/sherlock/go-shared/pkg/utils"
+ "github.com/broadinstitute/sherlock/sherlock/internal/models"
+ "github.com/google/uuid"
+ "github.com/zitadel/oidc/v3/pkg/oidc"
+ "github.com/zitadel/oidc/v3/pkg/op"
+ "time"
+)
+
+var _ op.TokenRequest = &Token{}
+
+type Token struct {
+ ID uuid.UUID `gorm:"primaryKey"`
+ CreatedAt time.Time
+
+ RefreshToken *RefreshToken `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
+ RefreshTokenID *uuid.UUID
+
+ Client *Client `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
+ ClientID string // AKA Audience, Application ID
+ Scopes oidc.SpaceDelimitedArray
+ Expiry time.Time
+
+ User *models.User
+ UserID uint
+}
+
+func (t *Token) GetSubject() string {
+ return utils.UintToString(t.UserID)
+}
+
+func (t *Token) GetAudience() []string {
+ return []string{t.ClientID}
+}
+
+func (t *Token) GetScopes() []string {
+ return t.Scopes
+}
diff --git a/sherlock/internal/suitability_synchronization/suspend_role_assignments.go b/sherlock/internal/suitability_synchronization/suspend_role_assignments.go
index 20721701b..d0f308f8d 100644
--- a/sherlock/internal/suitability_synchronization/suspend_role_assignments.go
+++ b/sherlock/internal/suitability_synchronization/suspend_role_assignments.go
@@ -73,9 +73,9 @@ func suspendRoleAssignments(ctx context.Context, db *gorm.DB) error {
Suspended: utils.PointerTo(false),
},
}).Error; err != nil {
- errors = append(errors, fmt.Errorf("failed to un-suspend %s's assignment for %s: %w", assignment.User.NameOrEmailHandle(), *role.Name, err))
+ errors = append(errors, fmt.Errorf("failed to un-suspend %s's assignment for %s: %w", assignment.User.NameOrUsername(), *role.Name, err))
} else {
- summaries = append(summaries, fmt.Sprintf("un-suspended %s's assignment for %s", assignment.User.NameOrEmailHandle(), *role.Name))
+ summaries = append(summaries, fmt.Sprintf("un-suspended %s's assignment for %s", assignment.User.NameOrUsername(), *role.Name))
}
} else if !suitable && (assignment.Suspended == nil || !*assignment.Suspended) {
roleIDsRequiringPropagation[roleID] = struct{}{}
@@ -84,9 +84,9 @@ func suspendRoleAssignments(ctx context.Context, db *gorm.DB) error {
Suspended: utils.PointerTo(true),
},
}).Error; err != nil {
- errors = append(errors, fmt.Errorf("failed to suspend %s's assignment for %s: %w", assignment.User.NameOrEmailHandle(), *role.Name, err))
+ errors = append(errors, fmt.Errorf("failed to suspend %s's assignment for %s: %w", assignment.User.NameOrUsername(), *role.Name, err))
} else {
- summaries = append(summaries, fmt.Sprintf("suspended %s's assignment for %s", assignment.User.NameOrEmailHandle(), *role.Name))
+ summaries = append(summaries, fmt.Sprintf("suspended %s's assignment for %s", assignment.User.NameOrUsername(), *role.Name))
}
}
}