diff --git a/go-shared/go.mod b/go-shared/go.mod index 11aaed9a8..93ac5df5c 100644 --- a/go-shared/go.mod +++ b/go-shared/go.mod @@ -1,5 +1,5 @@ module github.com/broadinstitute/sherlock/go-shared -go 1.21 +go 1.22 require github.com/google/go-cmp v0.6.0 diff --git a/sherlock-go-client/go.mod b/sherlock-go-client/go.mod index 11ef6f1ef..50435f62a 100644 --- a/sherlock-go-client/go.mod +++ b/sherlock-go-client/go.mod @@ -1,6 +1,6 @@ module github.com/broadinstitute/sherlock/sherlock-go-client -go 1.21 +go 1.22 require ( github.com/go-openapi/errors v0.22.0 diff --git a/sherlock-webhook-proxy/go.mod b/sherlock-webhook-proxy/go.mod index a49047dd8..58fb84bad 100644 --- a/sherlock-webhook-proxy/go.mod +++ b/sherlock-webhook-proxy/go.mod @@ -1,6 +1,6 @@ module github.com/broadinstitute/sherlock/sherlock-webhook-proxy -go 1.21 +go 1.22 require ( github.com/GoogleCloudPlatform/functions-framework-go v1.8.1 diff --git a/sherlock/Dockerfile b/sherlock/Dockerfile index 87a81e440..4da88328d 100644 --- a/sherlock/Dockerfile +++ b/sherlock/Dockerfile @@ -1,11 +1,10 @@ -ARG GO_VERSION='1.21' -ARG ALPINE_VERSION='3.19' +ARG GO_VERSION='1.22' -FROM golang:${GO_VERSION}-alpine${ALPINE_VERSION} AS build +# https://github.com/microsoft/go-images/tree/microsoft/main +FROM mcr.microsoft.com/oss/go/microsoft/golang:${GO_VERSION}-fips-cbl-mariner2.0 AS build ARG BUILD_VERSION='development' WORKDIR /build/sherlock -ENV CGO_ENABLED=0 -ENV GOBIN=/bin +ENV CGO_ENABLED=1 COPY sherlock/go.mod sherlock/go.sum ./ COPY go-shared ../go-shared/ @@ -14,8 +13,13 @@ RUN go mod download && go mod verify COPY sherlock ./ RUN go build -buildvcs=false -ldflags="-X 'github.com/broadinstitute/sherlock/go-shared/pkg/version.BuildVersion=${BUILD_VERSION}'" -o /bin/sherlock ./cmd/... -# FROM alpine:${ALPINE_VERSION} as runtime <-- use this if you hit issues -FROM gcr.io/distroless/static:nonroot AS runtime +# Check that the binary is FIPS-capable, fail if not +RUN go get github.com/acardace/fips-detect && \ + go run github.com/acardace/fips-detect /bin/sherlock \ + | grep -E 'FIPS-capable Go binary.*Yes' + +# https://mcr.microsoft.com/en-us/product/cbl-mariner/distroless/minimal/about +FROM mcr.microsoft.com/cbl-mariner/distroless/base:2.0-nonroot AS runtime COPY --from=build /bin/sherlock /bin/sherlock ENTRYPOINT [ "/bin/sherlock" ] diff --git a/sherlock/config/default_config.yaml b/sherlock/config/default_config.yaml index def650945..fd6b1a9d5 100644 --- a/sherlock/config/default_config.yaml +++ b/sherlock/config/default_config.yaml @@ -69,6 +69,56 @@ db: ignoreNotFoundWarning: true level: warn +# Configures Sherlock's own OIDC provider, not to be confused with its capability to interpret tokens +# from IAP or GitHub Actions. +oidc: + enable: true + # The issuer URL of Sherlock itself. This should be scheme + host + "/oidc", because Sherlock + # serves its own OIDC provider at that sub-path. This should also generally be an in-cluster + # address, because the IAP in front of the public endpoints isn't in-spec. This is fine as long + # as we just use Sherlock as an in-cluster authorization service. + # + # An example is https://sherlock-api-service.sherlock.svc.cluster.local/oidc + issuerUrl: http://localhost:8080/oidc + # The *public* side of Sherlock's OIDC issuer. This should be a normally-accessible URL that should + # go to the same destination as issuerUrl above. This is automatically used in the OIDC discovery + # config to tell clients how to have *users* authenticate against Sherlock. The downstream system + # will talk to issuerUrl in-cluster, but end-users will need to talk to publicIssuerUrl. + # + # An example is https://sherlock.dsp-devops-prod.broadinstitute.org/oidc + publicIssuerUrl: http://localhost:8080/oidc + + # The key that Sherlock should use to AES-256 encrypt internal data it sends to clients. This is + # used in two places by the underlying OIDC library: + # 1. Encrypting "{Token.ID}:{User.ID}" to create access tokens returned to clients + # 2. Encrypting "{AuthRequest.ID}" to create authorization codes returned to clients + # This does need to be rotated but doing so is potentially disruptive; Sherlock will cease + # respecting access tokens or authorization codes it has issued. + # + # Sherlock will error on boot if this doesn't parse from a hex string to 32 bytes. You'll probably + # want to pass this in the environment with SHERLOCK_oidc_encryptionKeyHex. It should be passed + # in hex format. + encryptionKeyHex: 7265706c6163652d6d652d776974682d33322d627974652d6b65792d2d2d2d2d # "replace-me-with-32-byte-key----" + # The duration that ID and access tokens vended to clients should be valid for. + tokenDuration: 15m + # The duration that refresh tokens vended to clients should be valid for. + refreshTokenDuration: 30m + # The duration that a particular signing key should be used before being rotated. + signingKeyPrimaryDuration: 4h + # The time after which a signing key should be deleted (and its signatures no longer accepted) + # after it has been rotated. This should be longer than all token durations so that we + # continue to respect our own signatures until they'd expire on their own. + signingKeyPostRotationDuration: 2h + # When enabled, Sherlock will use Google Cloud KMS to symmetrically encrypt the private keys + # it stores in its own database. This is a defense-in-depth measure to prevent key leakage in + # the event of SQL injection or other database compromise. + # + # This must be true when mode is not "debug". + signingKeyEncryptionKMSEnable: false + # The fully-qualified name of the KMS key to use when signingKeyEncryptionKMSEnable is true. + signingKeyEncryptionKMSKeyName: projects/some-project/locations/some-location/keyRings/some-key-ring/cryptoKeys/some-key + + auth: githubActionsOIDC: issuer: https://token.actions.githubusercontent.com @@ -217,6 +267,18 @@ argoCd: environmentUrlFormatString: https://argocd.dsp-devops-prod.broadinstitute.org/applications?showFavorites=false&proj=&sync=&health=&namespace=&cluster=&labels=env%%253D%s model: + roles: + # When set, Sherlock won't ever report an Environment/Cluster RequiredRole field as being null. + # Instead, it will substitute this value in its place (even though it won't be stored in the database). + # This can be useful in that it means downstream consumers don't need null handling like + # `requiredRole ?? "all-users"`. While that's simple, it's actually easier at a security/compliance + # level to say that Sherlock defines it and anything else uses it verbatim. (This was a specific + # request from appsec for this reason`.) + # + # Even if this is set, Sherlock will allow setting the field to empty to clear it out -- but will then + # respond in the API as if it's been set to this value. Note that the role set here needs to already + # exist or downstream consumers could have issues. + substituteEmptyRequiredRoleWithValue: environments: templates: # Uses appVersionResolver = "none", chartVersionResolver = "latest", and helmfileRef = "HEAD" diff --git a/sherlock/config/test_config.yaml b/sherlock/config/test_config.yaml index 2efb17ba6..7e74fb878 100644 --- a/sherlock/config/test_config.yaml +++ b/sherlock/config/test_config.yaml @@ -44,6 +44,8 @@ pactbroker: enable: false model: + roles: + substituteEmptyRequiredRoleWithValue: all-users environments: templates: autoPopulateCharts: diff --git a/sherlock/db/migrations/000095_oidc.down.sql b/sherlock/db/migrations/000095_oidc.down.sql new file mode 100644 index 000000000..622c7e409 --- /dev/null +++ b/sherlock/db/migrations/000095_oidc.down.sql @@ -0,0 +1,11 @@ +drop table signing_keys; + +drop table tokens; + +drop table refresh_tokens; + +drop table auth_request_codes; + +drop table auth_requests; + +drop table clients; diff --git a/sherlock/db/migrations/000095_oidc.up.sql b/sherlock/db/migrations/000095_oidc.up.sql new file mode 100644 index 000000000..9b8cdd1d6 --- /dev/null +++ b/sherlock/db/migrations/000095_oidc.up.sql @@ -0,0 +1,98 @@ +create table clients +( + id text not null + primary key, + client_secret_hash bytea, + client_secret_salt bytea, + client_secret_iterations bigint, + client_redirect_uris text, + client_post_logout_redirect_uris text, + client_application_type text, + client_auth_method text, + client_id_token_lifetime bigint, + client_dev_mode boolean, + client_clock_skew bigint +); + +create table auth_requests +( + id text not null + primary key, + created_at timestamp with time zone, + done_at timestamp with time zone, + client_id text + constraint fk_auth_requests_client + references clients + on update cascade on delete cascade, + nonce text, + redirect_uri text, + response_type text, + response_mode text, + scopes text, + state text, + code_challenge text, + code_challenge_method text, + user_id bigint + constraint fk_auth_requests_user + references users + on update cascade on delete cascade +); + +create table auth_request_codes +( + code text not null + primary key, + created_at timestamp with time zone, + auth_request_id text + constraint fk_auth_request_codes_auth_request + references auth_requests + on update cascade on delete cascade +); + +create table refresh_tokens +( + id text not null + primary key, + created_at timestamp with time zone, + token_hash bytea unique, + client_id text + constraint fk_refresh_tokens_client + references clients + on update cascade on delete cascade, + scopes text, + original_auth_at timestamp with time zone, + user_id bigint + constraint fk_refresh_tokens_user + references users + on update cascade on delete cascade +); + +create table tokens +( + id text not null + primary key, + created_at timestamp with time zone, + refresh_token_id text + constraint fk_tokens_refresh_token + references refresh_tokens + on update cascade on delete cascade, + client_id text + constraint fk_tokens_client + references clients + on update cascade on delete cascade, + scopes text, + expiry timestamp with time zone, + user_id bigint + constraint fk_tokens_user + references users + on update cascade on delete cascade +); + +create table signing_keys +( + id text not null + primary key, + created_at timestamp with time zone, + public_key bytea, + private_key bytea +); diff --git a/sherlock/go.mod b/sherlock/go.mod index ec81b6d3f..0f2e0020f 100644 --- a/sherlock/go.mod +++ b/sherlock/go.mod @@ -1,9 +1,12 @@ module github.com/broadinstitute/sherlock/sherlock -go 1.21 +go 1.22 + +toolchain go1.22.5 require ( cloud.google.com/go/cloudsqlconn v1.11.1 + cloud.google.com/go/kms v1.18.3 contrib.go.opencensus.io/exporter/prometheus v0.4.2 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 github.com/PagerDuty/go-pagerduty v1.8.0 @@ -13,6 +16,7 @@ require ( github.com/dustinkirkland/golang-petname v0.0.0-20230626224747-e794b9370d49 github.com/gin-contrib/cors v1.7.2 github.com/gin-gonic/gin v1.10.0 + github.com/go-jose/go-jose/v4 v4.0.2 github.com/golang-migrate/migrate/v4 v4.17.1 github.com/google/go-cmp v0.6.0 github.com/google/go-github/v58 v58.0.0 @@ -29,10 +33,13 @@ require ( github.com/swaggo/files v1.0.1 github.com/swaggo/gin-swagger v1.6.0 github.com/swaggo/swag v1.16.3 + github.com/zitadel/oidc/v3 v3.26.0 go.opencensus.io v0.24.0 + golang.org/x/crypto v0.25.0 golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa golang.org/x/net v0.27.0 golang.org/x/oauth2 v0.21.0 + golang.org/x/text v0.16.0 google.golang.org/api v0.188.0 gorm.io/datatypes v1.2.1 gorm.io/driver/postgres v1.5.9 @@ -42,15 +49,19 @@ require ( replace github.com/broadinstitute/sherlock/go-shared => ../go-shared require ( + cloud.google.com/go v0.115.0 // indirect cloud.google.com/go/auth v0.7.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.4.0 // indirect + cloud.google.com/go/iam v1.1.10 // indirect + cloud.google.com/go/longrunning v0.5.9 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azcore v1.12.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.9.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect github.com/KyleBanks/depth v1.2.1 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect @@ -62,9 +73,10 @@ require ( github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-chi/chi/v5 v5.1.0 // indirect github.com/go-kit/log v0.2.1 // indirect github.com/go-logfmt/logfmt v0.5.1 // indirect - github.com/go-logr/logr v1.4.1 // indirect + github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.19.6 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect @@ -82,6 +94,7 @@ require ( github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.5 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect github.com/gorilla/websocket v1.4.2 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect @@ -116,6 +129,8 @@ require ( github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/muhlemmer/gu v0.3.1 // indirect + github.com/muhlemmer/httpforwarded v0.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -125,6 +140,8 @@ require ( github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect github.com/prometheus/statsd_exporter v0.22.8 // indirect + github.com/rs/cors v1.11.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cobra v1.8.1 // indirect github.com/spf13/pflag v1.0.5 // indirect @@ -132,19 +149,21 @@ require ( github.com/stretchr/objx v0.5.2 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect + github.com/zitadel/logging v0.6.0 // indirect + github.com/zitadel/schema v1.3.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect - go.opentelemetry.io/otel v1.24.0 // indirect - go.opentelemetry.io/otel/metric v1.24.0 // indirect - go.opentelemetry.io/otel/trace v1.24.0 // indirect + go.opentelemetry.io/otel v1.28.0 // indirect + go.opentelemetry.io/otel/metric v1.28.0 // indirect + go.opentelemetry.io/otel/trace v1.28.0 // indirect go.uber.org/atomic v1.10.0 // indirect golang.org/x/arch v0.8.0 // indirect - golang.org/x/crypto v0.25.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.22.0 // indirect - golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect + google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 // indirect google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.2 // indirect diff --git a/sherlock/go.sum b/sherlock/go.sum index 47a2f5c2d..45392da2d 100644 --- a/sherlock/go.sum +++ b/sherlock/go.sum @@ -13,6 +13,8 @@ cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKV cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= +cloud.google.com/go v0.115.0 h1:CnFSK6Xo3lDYRoBKEcAtia6VSC837/ZkJuRduSFnr14= +cloud.google.com/go v0.115.0/go.mod h1:8jIM5vVgoAEoiVxQ/O4BFTfHqulPZgs/ufEzMcFMdWU= cloud.google.com/go/auth v0.7.0 h1:kf/x9B3WTbBUHkC+1VS8wwwli9TzhSt0vSTVBmMR8Ts= cloud.google.com/go/auth v0.7.0/go.mod h1:D+WqdrpcjmiCgWrXmLLxOVq1GACoE36chW6KXoEvuIw= cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= @@ -29,6 +31,12 @@ cloud.google.com/go/compute/metadata v0.4.0 h1:vHzJCWaM4g8XIcm8kopr3XmDA4Gy/lblD cloud.google.com/go/compute/metadata v0.4.0/go.mod h1:SIQh1Kkb4ZJ8zJ874fqVkslA29PRXuleyj6vOzlbK7M= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= +cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI= +cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps= +cloud.google.com/go/kms v1.18.3 h1:8+Z2S4bQDSCdghB5ZA5dVDDJTLmnkRlowtFiXqMFd74= +cloud.google.com/go/kms v1.18.3/go.mod h1:y/Lcf6fyhbdn7MrG1VaDqXxM8rhOBc5rWcWAhcvZjQU= +cloud.google.com/go/longrunning v0.5.9 h1:haH9pAuXdPAMqHvzX0zlWQigXT7B0+CL4/2nXXdBo5k= +cloud.google.com/go/longrunning v0.5.9/go.mod h1:HD+0l9/OOW0za6UWdKJtXoFAX/BGg/3Wj8p10NeWF7c= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= @@ -87,6 +95,8 @@ github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+Ce github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= +github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwNy7PA4I= +github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= @@ -158,9 +168,13 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk= +github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -174,8 +188,8 @@ github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG github.com/go-logfmt/logfmt v0.5.1 h1:otpy5pqBCBZ1ng9RQ0dPu4PN7ba75Y/aA+UpowDyNVA= github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= -github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= @@ -233,6 +247,8 @@ github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -275,6 +291,8 @@ github.com/google/go-github/v58 v58.0.0/go.mod h1:k4hxDKEfoWpSqFlc8LTpGd9fu2KrV1 github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -296,6 +314,8 @@ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA= github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= @@ -487,6 +507,10 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM= +github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM= +github.com/muhlemmer/httpforwarded v0.1.0 h1:x4DLrzXdliq8mprgUMR0olDvHGkou5BJsK/vWUetyzY= +github.com/muhlemmer/httpforwarded v0.1.0/go.mod h1:yo9czKedo2pdZhoXe+yDkGVbU0TJ0q9oQ90BVoDEtw0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= @@ -556,6 +580,8 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po= +github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= @@ -569,6 +595,8 @@ github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/slack-go/slack v0.13.1 h1:6UkM3U1OnbhPsYeb1IMkQ6HSNOSikWluwOncJt4Tz/o= github.com/slack-go/slack v0.13.1/go.mod h1:hlGi5oXA+Gt+yWTPP0plCdRKmjsDxecdHxYQdlMQKOw= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= @@ -615,6 +643,12 @@ github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zitadel/logging v0.6.0 h1:t5Nnt//r+m2ZhhoTmoPX+c96pbMarqJvW1Vq6xFTank= +github.com/zitadel/logging v0.6.0/go.mod h1:Y4CyAXHpl3Mig6JOszcV5Rqqsojj+3n7y2F591Mp/ow= +github.com/zitadel/oidc/v3 v3.26.0 h1:BG3OUK+JpuKz7YHJIyUxL5Sl2JV6ePkG42UP4Xv3J2w= +github.com/zitadel/oidc/v3 v3.26.0/go.mod h1:Cx6AYPTJO5q2mjqF3jaknbKOUjpq1Xui0SYvVhkKuXU= +github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0= +github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc= go.etcd.io/etcd/api/v3 v3.5.4/go.mod h1:5GB2vv4A4AOn3yk7MftYGHkUfGtDHnEraIjym4dYz5A= go.etcd.io/etcd/client/pkg/v3 v3.5.4/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g= go.etcd.io/etcd/client/v3 v3.5.4/go.mod h1:ZaRkVgBZC+L+dLCjTcF1hRXpgZXQPOvnA/Ak/gq3kiY= @@ -630,12 +664,14 @@ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.4 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw= -go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= -go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= -go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= -go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= -go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= -go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= +go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= +go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4= +go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q= +go.opentelemetry.io/otel/metric v1.28.0/go.mod h1:Fb1eVBFZmLVTMb6PPohq3TO9IIhUisDsbJoL/+uQW4s= +go.opentelemetry.io/otel/sdk v1.24.0 h1:YMPPDNymmQN3ZgczicBY3B6sf9n62Dlj9pWD3ucgoDw= +go.opentelemetry.io/otel/sdk v1.24.0/go.mod h1:KVrIYw6tEubO9E96HQpcmpTKDVn9gdv35HoYiQWGDFg= +go.opentelemetry.io/otel/trace v1.28.0 h1:GhQ9cUuQGmNDd5BTCP2dAvv75RdMxEfTmYejp+lkx9g= +go.opentelemetry.io/otel/trace v1.28.0/go.mod h1:jPyXzNPg6da9+38HEwElrQiHlVMTnVfM3/yv2OlIHaI= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= @@ -812,6 +848,7 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -948,8 +985,10 @@ google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= -google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 h1:MuYw1wJzT+ZkybKfaOXKp5hJiZDn2iHaXRw0mRYdHSc= -google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4/go.mod h1:px9SlOOZBg1wM1zdnr8jEL4CNGUBZ+ZKYtNPApNQc4c= +google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b h1:dSTjko30weBaMj3eERKc0ZVXW4GudCswM3m+P++ukU0= +google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b/go.mod h1:FfBgJBJg9GcpPvKIuHSZ/aE1g2ecGL74upMzGZjiGEY= +google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 h1:0+ozOGcrp+Y8Aq8TLNN2Aliibms5LEzsq99ZZmAGYm0= +google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094/go.mod h1:fJ/e3If/Q67Mj99hin0hMhiNyCRmt6BQ2aWIJshUSJw= google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 h1:SbSDUWW1PAO24TNpLdeheoYPd7kllICcLU52x6eD4kQ= google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= diff --git a/sherlock/html/close.html b/sherlock/html/close.html index 78f0cacad..4eb1107fc 100644 --- a/sherlock/html/close.html +++ b/sherlock/html/close.html @@ -3,7 +3,7 @@ Sherlock - + diff --git a/sherlock/html/logged-out.html b/sherlock/html/logged-out.html new file mode 100644 index 000000000..ca6ef777b --- /dev/null +++ b/sherlock/html/logged-out.html @@ -0,0 +1,9 @@ + + + + Sherlock - Logged Out + +

+ You've been logged out. You may close this window. +

+ diff --git a/sherlock/internal/api/login/handler_test.go b/sherlock/internal/api/login/handler_test.go new file mode 100644 index 000000000..516e636b5 --- /dev/null +++ b/sherlock/internal/api/login/handler_test.go @@ -0,0 +1,37 @@ +package login + +import ( + "github.com/broadinstitute/sherlock/sherlock/internal/middleware/authentication" + "github.com/broadinstitute/sherlock/sherlock/internal/middleware/authentication/test_users" + "github.com/broadinstitute/sherlock/sherlock/internal/models" + "github.com/broadinstitute/sherlock/sherlock/internal/oidc_models" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/suite" + "testing" +) + +type handlerSuite struct { + suite.Suite + test_users.TestUserHelper + models.TestSuiteHelper + oidc_models.TestClientHelper + + internalRouter *gin.Engine +} + +func TestLoginSuite(t *testing.T) { + suite.Run(t, new(handlerSuite)) +} + +func (s *handlerSuite) SetupSuite() { + s.TestSuiteHelper.SetupSuite() + // Reduces console output + gin.SetMode(gin.TestMode) +} + +func (s *handlerSuite) SetupTest() { + s.TestSuiteHelper.SetupTest() + s.internalRouter = gin.New() + s.internalRouter.Use(authentication.TestMiddleware(s.DB, s.TestData)...) + s.internalRouter.GET("/login", LoginGet) +} diff --git a/sherlock/internal/api/login/login.go b/sherlock/internal/api/login/login.go new file mode 100644 index 000000000..c467c14c6 --- /dev/null +++ b/sherlock/internal/api/login/login.go @@ -0,0 +1,55 @@ +package login + +import ( + "database/sql" + "fmt" + "github.com/broadinstitute/sherlock/sherlock/internal/errors" + "github.com/broadinstitute/sherlock/sherlock/internal/middleware/authentication" + "github.com/broadinstitute/sherlock/sherlock/internal/oidc_models" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "gorm.io/gorm/clause" + "net/http" + "time" +) + +// LoginGet is meant to handle redirects to /login?id=... from the OIDC subsystem, +// read the IAP info from the request, "log the user in", and redirect back to the +// OIDC subsystem. +// +// This isn't in Swagger as it's not an API really -- it's meant to play nice with +// browsers. +func LoginGet(ctx *gin.Context) { + user, err := authentication.MustUseUser(ctx) + if err != nil { + return + } + + db, err := authentication.MustUseDB(ctx) + if err != nil { + return + } + + authRequestID := ctx.Query("id") + if authRequestID == "" { + errors.AbortRequest(ctx, fmt.Errorf("(%s) no auth request ID passed", errors.BadRequest)) + return + } + + parsedAuthRequestID, err := uuid.Parse(authRequestID) + if err != nil { + errors.AbortRequest(ctx, fmt.Errorf("(%s) invalid auth request ID passed", errors.BadRequest)) + return + } + + err = db.Omit(clause.Associations).Where(&oidc_models.AuthRequest{ID: parsedAuthRequestID}).Where("done_at is null").Updates(&oidc_models.AuthRequest{ + UserID: &user.ID, + DoneAt: sql.NullTime{Time: time.Now(), Valid: true}, + }).Error + if err != nil { + errors.AbortRequest(ctx, fmt.Errorf("could not update auth request: %w", err)) + return + } + + ctx.Redirect(http.StatusFound, fmt.Sprintf("/oidc/authorize/callback?id=%s", authRequestID)) +} diff --git a/sherlock/internal/api/login/login_test.go b/sherlock/internal/api/login/login_test.go new file mode 100644 index 000000000..56fa18fa1 --- /dev/null +++ b/sherlock/internal/api/login/login_test.go @@ -0,0 +1,58 @@ +package login + +import ( + "github.com/broadinstitute/sherlock/sherlock/internal/oidc_models" + "github.com/google/uuid" + "github.com/zitadel/oidc/v3/pkg/oidc" + "gorm.io/gorm/clause" + "net/http" + "net/http/httptest" +) + +func (s *handlerSuite) TestLoginGet_noAuthRequestID() { + request, err := http.NewRequest("GET", "/login", nil) + s.NoError(err) + recorder := httptest.NewRecorder() + s.internalRouter.ServeHTTP(recorder, request) + + s.Equal(http.StatusBadRequest, recorder.Code) +} + +func (s *handlerSuite) TestLoginGet_invalidAuthRequestID() { + request, err := http.NewRequest("GET", "/login?id=invalid", nil) + s.NoError(err) + recorder := httptest.NewRecorder() + s.internalRouter.ServeHTTP(recorder, request) + + s.Equal(http.StatusBadRequest, recorder.Code) +} + +func (s *handlerSuite) TestLoginGet() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + + authRequest := oidc_models.AuthRequest{ + ID: uuid.New(), + ClientID: clientID, + Nonce: "some-nonce", + RedirectURI: s.GeneratedClientRedirectURI(), + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, "groups"}, + State: "some-state", + } + + s.NoError(s.DB.Omit(clause.Associations).Create(&authRequest).Error) + + request, err := http.NewRequest("GET", "/login?id="+authRequest.GetID(), nil) + s.NoError(err) + s.UseSuitableUserFor(request) + recorder := httptest.NewRecorder() + s.internalRouter.ServeHTTP(recorder, request) + + s.Equal(http.StatusFound, recorder.Code) + + // Check that the auth request was marked as done + var reloadedAuthRequest oidc_models.AuthRequest + s.NoError(s.DB.Where("id = ?", authRequest.ID.String()).First(&reloadedAuthRequest).Error) + s.True(reloadedAuthRequest.DoneAt.Valid) + s.Equal(s.TestData.User_Suitable().ID, *reloadedAuthRequest.UserID) +} diff --git a/sherlock/internal/api/sherlock/chart_releases_v3_test.go b/sherlock/internal/api/sherlock/chart_releases_v3_test.go index ea9de4a7a..29bca0cdf 100644 --- a/sherlock/internal/api/sherlock/chart_releases_v3_test.go +++ b/sherlock/internal/api/sherlock/chart_releases_v3_test.go @@ -3,6 +3,7 @@ package sherlock import ( "fmt" "github.com/broadinstitute/sherlock/go-shared/pkg/utils" + "github.com/broadinstitute/sherlock/sherlock/internal/config" "github.com/broadinstitute/sherlock/sherlock/internal/models" "github.com/stretchr/testify/assert" "gorm.io/gorm" @@ -345,6 +346,7 @@ func (s *handlerSuite) TestChartReleaseV3_toModel() { } func Test_chartReleaseFromModel(t *testing.T) { + config.LoadTestConfig() now := time.Now() type args struct { model models.ChartRelease @@ -411,8 +413,8 @@ func Test_chartReleaseFromModel(t *testing.T) { }, CiIdentifier: &CiIdentifierV3{CommonFields: CommonFields{ID: 2}}, ChartInfo: &ChartV3{CommonFields: CommonFields{ID: 3}, ChartV3Create: ChartV3Create{Name: "leonardo"}}, - ClusterInfo: &ClusterV3{CommonFields: CommonFields{ID: 4}, ClusterV3Create: ClusterV3Create{Name: "terra-prod"}}, - EnvironmentInfo: &EnvironmentV3{CommonFields: CommonFields{ID: 5}, EnvironmentV3Create: EnvironmentV3Create{Name: "prod"}}, + ClusterInfo: &ClusterV3{CommonFields: CommonFields{ID: 4}, ClusterV3Create: ClusterV3Create{Name: "terra-prod", ClusterV3Edit: ClusterV3Edit{RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"))}}}, + EnvironmentInfo: &EnvironmentV3{CommonFields: CommonFields{ID: 5}, EnvironmentV3Create: EnvironmentV3Create{Name: "prod", EnvironmentV3Edit: EnvironmentV3Edit{RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"))}}}, AppVersionReference: "leonardo/v1.0.0", AppVersionInfo: &AppVersionV3{CommonFields: CommonFields{ID: 8}, AppVersionV3Create: AppVersionV3Create{AppVersion: "v1.0.0"}}, ChartVersionReference: "leonardo/2.0.0", diff --git a/sherlock/internal/api/sherlock/clusters_v3.go b/sherlock/internal/api/sherlock/clusters_v3.go index 87e4647fd..e0259e06a 100644 --- a/sherlock/internal/api/sherlock/clusters_v3.go +++ b/sherlock/internal/api/sherlock/clusters_v3.go @@ -3,6 +3,7 @@ package sherlock import ( "fmt" "github.com/broadinstitute/sherlock/go-shared/pkg/utils" + "github.com/broadinstitute/sherlock/sherlock/internal/config" "github.com/broadinstitute/sherlock/sherlock/internal/models" "gorm.io/gorm" ) @@ -90,6 +91,8 @@ func clusterFromModel(model models.Cluster) ClusterV3 { ret.RequiredRole = model.RequiredRole.Name } else if model.RequiredRoleID != nil { ret.RequiredRole = utils.PointerTo(utils.UintToString(*model.RequiredRoleID)) + } else if substituteEmptyRole := config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"); substituteEmptyRole != "" { + ret.RequiredRole = &substituteEmptyRole } return ret } diff --git a/sherlock/internal/api/sherlock/clusters_v3_edit_test.go b/sherlock/internal/api/sherlock/clusters_v3_edit_test.go index 46755fd28..4165252b1 100644 --- a/sherlock/internal/api/sherlock/clusters_v3_edit_test.go +++ b/sherlock/internal/api/sherlock/clusters_v3_edit_test.go @@ -3,6 +3,7 @@ package sherlock import ( "fmt" "github.com/broadinstitute/sherlock/go-shared/pkg/utils" + "github.com/broadinstitute/sherlock/sherlock/internal/config" "github.com/broadinstitute/sherlock/sherlock/internal/errors" "github.com/broadinstitute/sherlock/sherlock/internal/models" "github.com/gin-gonic/gin" @@ -182,7 +183,9 @@ func (s *handlerSuite) TestClusterV3Edit_clearRequiredRole() { }), &got) s.Equal(http.StatusOK, code) - s.Nil(got.RequiredRole) + if s.NotNil(got.RequiredRole) { + s.Equal(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"), *got.RequiredRole) + } var inDB models.Cluster s.NoError(s.DB.First(&inDB, toEdit.ID).Error) s.Nil(inDB.RequiredRoleID) diff --git a/sherlock/internal/api/sherlock/environments_v3.go b/sherlock/internal/api/sherlock/environments_v3.go index 13b3b396c..03d49a14c 100644 --- a/sherlock/internal/api/sherlock/environments_v3.go +++ b/sherlock/internal/api/sherlock/environments_v3.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "github.com/broadinstitute/sherlock/go-shared/pkg/utils" + "github.com/broadinstitute/sherlock/sherlock/internal/config" "github.com/broadinstitute/sherlock/sherlock/internal/errors" "github.com/broadinstitute/sherlock/sherlock/internal/models" "github.com/google/uuid" @@ -232,6 +233,8 @@ func environmentFromModel(model models.Environment) EnvironmentV3 { ret.RequiredRole = model.RequiredRole.Name } else if model.RequiredRoleID != nil { ret.RequiredRole = utils.PointerTo(utils.UintToString(*model.RequiredRoleID)) + } else if substituteEmptyRole := config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"); substituteEmptyRole != "" { + ret.RequiredRole = &substituteEmptyRole } return ret } diff --git a/sherlock/internal/api/sherlock/environments_v3_edit_test.go b/sherlock/internal/api/sherlock/environments_v3_edit_test.go index 8af9da544..6ff176d16 100644 --- a/sherlock/internal/api/sherlock/environments_v3_edit_test.go +++ b/sherlock/internal/api/sherlock/environments_v3_edit_test.go @@ -3,6 +3,7 @@ package sherlock import ( "fmt" "github.com/broadinstitute/sherlock/go-shared/pkg/utils" + "github.com/broadinstitute/sherlock/sherlock/internal/config" "github.com/broadinstitute/sherlock/sherlock/internal/errors" "github.com/broadinstitute/sherlock/sherlock/internal/models" "github.com/gin-gonic/gin" @@ -129,7 +130,9 @@ func (s *handlerSuite) TestEnvironmentsV3Edit_clearRequiredRole() { }), &got) s.Equal(http.StatusOK, code) - s.Nil(got.RequiredRole) + if s.NotNil(got.RequiredRole) { + s.Equal(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"), *got.RequiredRole) + } var inDB models.Environment s.NoError(s.DB.First(&inDB, toEdit.ID).Error) s.Nil(inDB.RequiredRoleID) diff --git a/sherlock/internal/api/sherlock/environments_v3_test.go b/sherlock/internal/api/sherlock/environments_v3_test.go index 6aa037036..ff4de6377 100644 --- a/sherlock/internal/api/sherlock/environments_v3_test.go +++ b/sherlock/internal/api/sherlock/environments_v3_test.go @@ -3,6 +3,7 @@ package sherlock import ( "database/sql" "github.com/broadinstitute/sherlock/go-shared/pkg/utils" + "github.com/broadinstitute/sherlock/sherlock/internal/config" "github.com/broadinstitute/sherlock/sherlock/internal/models" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -310,6 +311,7 @@ func (s *handlerSuite) TestEnvironmentV3_toModel() { } func Test_environmentFromModel(t *testing.T) { + config.LoadTestConfig() now := time.Now() nowString := utils.TimePtrToISO8601(&now) nowTimeParsedAgain, err := utils.ISO8601PtrToTime(nowString) @@ -326,7 +328,13 @@ func Test_environmentFromModel(t *testing.T) { { name: "empty", args: args{}, - want: EnvironmentV3{}, + want: EnvironmentV3{ + EnvironmentV3Create: EnvironmentV3Create{ + EnvironmentV3Edit: EnvironmentV3Edit{ + RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue")), + }, + }, + }, }, { name: "full", @@ -380,8 +388,8 @@ func Test_environmentFromModel(t *testing.T) { UpdatedAt: now.Add(-time.Minute), }, CiIdentifier: &CiIdentifierV3{CommonFields: CommonFields{ID: 2}}, - TemplateEnvironmentInfo: &EnvironmentV3{CommonFields: CommonFields{ID: 3}, EnvironmentV3Create: EnvironmentV3Create{Name: "name-3"}}, - DefaultClusterInfo: &ClusterV3{CommonFields: CommonFields{ID: 4}, ClusterV3Create: ClusterV3Create{Name: "name-4"}}, + TemplateEnvironmentInfo: &EnvironmentV3{CommonFields: CommonFields{ID: 3}, EnvironmentV3Create: EnvironmentV3Create{Name: "name-3", EnvironmentV3Edit: EnvironmentV3Edit{RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"))}}}, + DefaultClusterInfo: &ClusterV3{CommonFields: CommonFields{ID: 4}, ClusterV3Create: ClusterV3Create{Name: "name-4", ClusterV3Edit: ClusterV3Edit{RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue"))}}}, PagerdutyIntegrationInfo: &PagerdutyIntegrationV3{CommonFields: CommonFields{ID: 6}, PagerdutyID: "blah"}, OwnerInfo: &UserV3{CommonFields: CommonFields{ID: 5}, Email: "example@example.com", Suitable: utils.PointerTo(false), SuitabilityDescription: utils.PointerTo("no matching suitability record found or loaded; assuming unsuitable")}, @@ -428,6 +436,7 @@ func Test_environmentFromModel(t *testing.T) { EnvironmentV3Create: EnvironmentV3Create{ EnvironmentV3Edit: EnvironmentV3Edit{ PagerdutyIntegration: utils.PointerTo("6"), + RequiredRole: utils.PointerTo(config.Config.String("model.roles.substituteEmptyRequiredRoleWithValue")), }, }, }, diff --git a/sherlock/internal/boot/application.go b/sherlock/internal/boot/application.go index 97c3d09d8..ead5651f7 100644 --- a/sherlock/internal/boot/application.go +++ b/sherlock/internal/boot/application.go @@ -12,6 +12,7 @@ import ( "github.com/broadinstitute/sherlock/sherlock/internal/metrics" "github.com/broadinstitute/sherlock/sherlock/internal/middleware/authentication/gha_oidc" "github.com/broadinstitute/sherlock/sherlock/internal/models" + "github.com/broadinstitute/sherlock/sherlock/internal/oidc_models" "github.com/broadinstitute/sherlock/sherlock/internal/role_propagation" "github.com/broadinstitute/sherlock/sherlock/internal/suitability_synchronization" "github.com/gin-gonic/gin" @@ -136,6 +137,15 @@ func (a *Application) Start() { } } + if config.Config.Bool("oidc.enable") { + log.Info().Msgf("BOOT | initializing OIDC provider...") + if err = oidc_models.Init(ctx, a.gormDB); err != nil { + log.Fatal().Err(err).Msgf("oidc_models.Init() error") + } + go oidc_models.KeepSigningKeysRotated(ctx, a.gormDB) + go oidc_models.KeepExpiringRefreshTokens(ctx, a.gormDB) + } + go models.KeepAutoAssigningRoles(ctx, a.gormDB) log.Info().Msgf("BOOT | building Gin router...") diff --git a/sherlock/internal/boot/router.go b/sherlock/internal/boot/router.go index 26ca31f99..5bd99b83d 100644 --- a/sherlock/internal/boot/router.go +++ b/sherlock/internal/boot/router.go @@ -6,6 +6,7 @@ import ( "github.com/broadinstitute/sherlock/go-shared/pkg/version" "github.com/broadinstitute/sherlock/sherlock/docs" "github.com/broadinstitute/sherlock/sherlock/html" + "github.com/broadinstitute/sherlock/sherlock/internal/api/login" "github.com/broadinstitute/sherlock/sherlock/internal/api/misc" "github.com/broadinstitute/sherlock/sherlock/internal/api/sherlock" "github.com/broadinstitute/sherlock/sherlock/internal/config" @@ -16,11 +17,13 @@ import ( "github.com/broadinstitute/sherlock/sherlock/internal/middleware/csrf_protection" "github.com/broadinstitute/sherlock/sherlock/internal/middleware/headers" "github.com/broadinstitute/sherlock/sherlock/internal/middleware/logger" + "github.com/broadinstitute/sherlock/sherlock/internal/oidc_models" "github.com/gin-gonic/gin" swaggo_files "github.com/swaggo/files" swaggo_gin "github.com/swaggo/gin-swagger" "gorm.io/gorm" "net/http" + "strings" ) // @title Sherlock @@ -57,6 +60,10 @@ func BuildRouter(ctx context.Context, db *gorm.DB) *gin.Engine { cors.Cors(), headers.Headers()) + resourceMiddleware := make(gin.HandlersChain, 0) + resourceMiddleware = append(resourceMiddleware, csrf_protection.CsrfProtection()) + resourceMiddleware = append(resourceMiddleware, authentication.Middleware(db)...) + // Replace Gin's standard fallback responses with our standard error format for friendlier client behavior router.NoRoute(func(ctx *gin.Context) { errors.AbortRequest(ctx, fmt.Errorf("(%s) no handler for %s found", errors.NotFound, ctx.Request.URL.Path)) @@ -80,16 +87,27 @@ func BuildRouter(ctx context.Context, db *gorm.DB) *gin.Engine { })) router.GET("", func(ctx *gin.Context) { ctx.Redirect(http.StatusMovedPermanently, "/swagger/index.html") }) + if config.Config.Bool("oidc.enable") { + // delegate /oidc/* to OIDC library, trimming the path prefix because of how the library expects to receive requests + // https://broadinstitute.slack.com/archives/CQ6SL4N5T/p1721406732128199 + router.Any("/oidc/*any", func(ctx *gin.Context) { + req := ctx.Request.Clone(ctx) + req.RequestURI = strings.TrimPrefix(req.RequestURI, "/oidc") + req.URL.Path = strings.TrimPrefix(req.URL.Path, "/oidc") + oidc_models.Provider.ServeHTTP(ctx.Writer, req) + }) + // authenticate /login handler to complete OIDC auth requests + router.GET("/login", append(resourceMiddleware, login.LoginGet)...) + } + // routes under /api require authentication and may use the database - apiRouter := router.Group("api") - apiRouter.Use(csrf_protection.CsrfProtection()) - apiRouter.Use(authentication.Middleware(db)...) + apiRouter := router.Group("/api", resourceMiddleware...) // refactored sherlock API, under /api/{type}/v3 sherlock.ConfigureRoutes(apiRouter) // special error for the removed "v2" API, under /api/v2/{type} - apiRouter.Any("v2/*path", func(ctx *gin.Context) { + apiRouter.Any("/v2/*path", func(ctx *gin.Context) { errors.AbortRequest(ctx, fmt.Errorf("(%s) sherlock's v2 API has been removed; reach out to #dsp-devops-champions for help updating your client", errors.NotFound)) }) diff --git a/sherlock/internal/config/config.go b/sherlock/internal/config/config.go index faee1afe7..ef8f80e3b 100644 --- a/sherlock/internal/config/config.go +++ b/sherlock/internal/config/config.go @@ -112,6 +112,7 @@ func configureLogging(infoMessages ...string) { // log messages using Go's built-in logging. We can at least format those messages // correctly by redirecting that into zerolog, though it won't have proper leveling // information + stdlog.SetFlags(0) stdlog.SetOutput(log.Logger) if logLevel := Config.String("log.level"); logLevel != "" { diff --git a/sherlock/internal/models/user.go b/sherlock/internal/models/user.go index 5e7313134..1e1c240c1 100644 --- a/sherlock/internal/models/user.go +++ b/sherlock/internal/models/user.go @@ -153,11 +153,11 @@ func (u *User) AlphaNumericHyphenatedUsername() string { return string(ret) } -func (u *User) NameOrEmailHandle() string { +func (u *User) NameOrUsername() string { if u.Name != nil { return *u.Name } else { - return strings.Split(u.Email, "@")[0] + return u.AlphaNumericHyphenatedUsername() } } @@ -165,7 +165,7 @@ func (u *User) SlackReference(mention bool) string { if u.SlackID != nil && mention { return fmt.Sprintf("<@%s>", *u.SlackID) } else { - return fmt.Sprintf("", u.Email, u.NameOrEmailHandle()) + return fmt.Sprintf("", u.Email, u.NameOrUsername()) } } diff --git a/sherlock/internal/oidc_models/auth_request.go b/sherlock/internal/oidc_models/auth_request.go new file mode 100644 index 000000000..75fa4f151 --- /dev/null +++ b/sherlock/internal/oidc_models/auth_request.go @@ -0,0 +1,105 @@ +package oidc_models + +import ( + "database/sql" + "github.com/broadinstitute/sherlock/go-shared/pkg/utils" + "github.com/broadinstitute/sherlock/sherlock/internal/models" + "github.com/google/uuid" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "time" +) + +// "AuthRequest implements the op.AuthRequest interface" +var _ op.AuthRequest = &AuthRequest{} + +type AuthRequest struct { + ID uuid.UUID `gorm:"primaryKey"` + CreatedAt time.Time + DoneAt sql.NullTime + + Client *Client `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` + ClientID string // AKA Audience, Application ID + Nonce string + RedirectURI string // AKA CallbackURI + ResponseType oidc.ResponseType `gorm:"string"` + ResponseMode oidc.ResponseMode `gorm:"string"` + Scopes oidc.SpaceDelimitedArray + State string + + CodeChallenge string + CodeChallengeMethod oidc.CodeChallengeMethod `gorm:"string"` + + // The user won't be filled until the request has been logged-in to + User *models.User + UserID *uint +} + +func (r *AuthRequest) GetID() string { + return r.ID.String() +} + +func (r *AuthRequest) GetACR() string { + return "" +} + +func (r *AuthRequest) GetAMR() []string { + // Return an empty array because we don't know for sure what AMRs IAP enforced on the caller. + // https://openid.net/specs/openid-connect-core-1_0.html#IDToken + // https://www.rfc-editor.org/info/rfc8176 + return []string{} +} + +func (r *AuthRequest) GetAudience() []string { + return []string{r.ClientID} +} + +func (r *AuthRequest) GetAuthTime() time.Time { + return r.CreatedAt +} + +func (r *AuthRequest) GetClientID() string { + return r.ClientID +} + +func (r *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge { + return &oidc.CodeChallenge{ + Challenge: r.CodeChallenge, + Method: r.CodeChallengeMethod, + } +} + +func (r *AuthRequest) GetNonce() string { + return r.Nonce +} + +func (r *AuthRequest) GetRedirectURI() string { + return r.RedirectURI +} + +func (r *AuthRequest) GetResponseType() oidc.ResponseType { + return r.ResponseType +} + +func (r *AuthRequest) GetResponseMode() oidc.ResponseMode { + return r.ResponseMode +} + +func (r *AuthRequest) GetScopes() []string { + return r.Scopes +} + +func (r *AuthRequest) GetState() string { + return r.State +} + +func (r *AuthRequest) GetSubject() string { + if r.UserID == nil { + return "" + } + return utils.UintToString(*r.UserID) +} + +func (r *AuthRequest) Done() bool { + return r.DoneAt.Valid +} diff --git a/sherlock/internal/oidc_models/auth_request_code.go b/sherlock/internal/oidc_models/auth_request_code.go new file mode 100644 index 000000000..8418c4059 --- /dev/null +++ b/sherlock/internal/oidc_models/auth_request_code.go @@ -0,0 +1,13 @@ +package oidc_models + +import ( + "github.com/google/uuid" + "time" +) + +type AuthRequestCode struct { + Code string `gorm:"primaryKey"` + CreatedAt time.Time + AuthRequestID uuid.UUID + AuthRequest *AuthRequest `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` +} diff --git a/sherlock/internal/oidc_models/boot.go b/sherlock/internal/oidc_models/boot.go new file mode 100644 index 000000000..a37351526 --- /dev/null +++ b/sherlock/internal/oidc_models/boot.go @@ -0,0 +1,72 @@ +package oidc_models + +import ( + kms "cloud.google.com/go/kms/apiv1" + "cloud.google.com/go/kms/apiv1/kmspb" + "context" + "fmt" + "github.com/broadinstitute/sherlock/sherlock/internal/clients/slack" + "github.com/broadinstitute/sherlock/sherlock/internal/config" + "gorm.io/gorm" + "time" +) + +var ( + kmsKey string + kmsClient *kms.KeyManagementClient +) + +func Init(ctx context.Context, db *gorm.DB) error { + if config.Config.Bool("oidc.signingKeyEncryptionKMSEnable") { + kmsKey = config.Config.String("oidc.signingKeyEncryptionKMSKeyName") + var err error + kmsClient, err = kms.NewKeyManagementClient(ctx) + if err != nil { + return fmt.Errorf("error creating KMS client: %w", err) + } + response, err := kmsClient.GetCryptoKey(ctx, &kmspb.GetCryptoKeyRequest{ + Name: kmsKey, + }) + if err != nil { + return fmt.Errorf("error getting KMS key '%s': %w", kmsKey, err) + } else if response.Purpose != kmspb.CryptoKey_ENCRYPT_DECRYPT { + return fmt.Errorf("KMS key '%s' is not an encrypt/decrypt key", kmsKey) + } + } else if config.Config.String("mode") != "debug" { + return fmt.Errorf("oidc.signingKeyEncryptionKMSEnable is false, but mode is not debug") + } + + if err := rotateSigningKeys(ctx, db); err != nil { + return fmt.Errorf("error rotating oidc signing keys: %w", err) + } + + return initProvider(db) +} + +func KeepSigningKeysRotated(ctx context.Context, db *gorm.DB) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(15 * time.Minute): + err := rotateSigningKeys(ctx, db) + if err != nil { + slack.ReportError(ctx, "error rotating oidc signing keys", err) + } + } + } +} + +func KeepExpiringRefreshTokens(ctx context.Context, db *gorm.DB) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(time.Minute): + err := expireRefreshTokens(db) + if err != nil { + slack.ReportError(ctx, "error expiring refresh tokens", err) + } + } + } +} diff --git a/sherlock/internal/oidc_models/client.go b/sherlock/internal/oidc_models/client.go new file mode 100644 index 000000000..25f165808 --- /dev/null +++ b/sherlock/internal/oidc_models/client.go @@ -0,0 +1,124 @@ +package oidc_models + +import ( + "fmt" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "time" +) + +// "Client implements the op.Client interface" +var _ op.Client = &Client{} + +type Client struct { + ID string `gorm:"primaryKey"` + ClientSecretHash []byte // PBKDF2 derived key, HMAC-SHA-512; should be empty for PKCE; Sherlock will derive the same number of bytes as the length of this field automatically + ClientSecretSalt []byte // Salt for ClientSecretHash; should be empty for PKCE + ClientSecretIterations int // Number of iterations for ClientSecretHash; should be empty for PKCE. https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#pbkdf2 + ClientRedirectURIs oidc.SpaceDelimitedArray // In dev mode, may include globs or use http + ClientPostLogoutRedirectURIs oidc.SpaceDelimitedArray // In dev mode, may include globs or use http + ClientApplicationType op.ApplicationType // "web", "user_agent", or "native" + ClientAuthMethod oidc.AuthMethod // "client_secret_basic", "client_secret_post", "none", or "private_key_jwt" + ClientIDTokenLifetime int64 // Nanoseconds + ClientDevMode bool // Allow insecure/nonstandard redirect URIs, like globs and http + ClientClockSkew int64 // Nanoseconds +} + +func (c *Client) GetID() string { + return c.ID +} + +func (c *Client) RedirectURIs() []string { + return c.ClientRedirectURIs +} + +func (c *Client) PostLogoutRedirectURIs() []string { + return c.ClientPostLogoutRedirectURIs +} + +func (c *Client) ApplicationType() op.ApplicationType { + return c.ClientApplicationType +} + +func (c *Client) AuthMethod() oidc.AuthMethod { + return c.ClientAuthMethod +} + +func (c *Client) ResponseTypes() []oidc.ResponseType { + return []oidc.ResponseType{oidc.ResponseTypeCode, oidc.ResponseTypeIDTokenOnly, oidc.ResponseTypeIDToken} +} + +func (c *Client) GrantTypes() []oidc.GrantType { + return []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken} +} + +func (c *Client) LoginURL(authRequestID string) string { + // To sherlock/sherlock/internal/api/login/login.go + return fmt.Sprintf("/login?id=%s", authRequestID) +} + +func (c *Client) AccessTokenType() op.AccessTokenType { + return op.AccessTokenTypeJWT +} + +func (c *Client) IDTokenLifetime() time.Duration { + return time.Duration(c.ClientIDTokenLifetime) +} + +func (c *Client) DevMode() bool { + return c.ClientDevMode +} + +func (c *Client) IDTokenUserinfoClaimsAssertion() bool { + // It technically violates spec to return userinfo claims in ID tokens when an access token is provided. + // The caller is expected to call the userinfo endpoint with the access token to get the userinfo data. + // But... Dex supports and actually expects these claims in the initial ID token. It can be configured + // to call the userinfo endpoint, but why make the extra roundtrip? + // + // In the future we could always have this be a bool field on Client itself to customize the behavior + // per client. + return true +} + +func (c *Client) ClockSkew() time.Duration { + return time.Duration(c.ClientClockSkew) +} + +func (c *Client) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { return scopes } +} + +func (c *Client) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { + // Hardcode no client-level extra restrictions on access token scopes + return func(scopes []string) []string { return scopes } +} + +func (c *Client) IsScopeAllowed(scope string) bool { + // Hardcode that groups are allowed, because currently that's the whole point of Sherlock doing OIDC + return scope == groupsClaim +} + +// "devModeGlobClient implements the op.HasRedirectGlobs interface" +var _ op.HasRedirectGlobs = &devModeGlobClient{} + +type devModeGlobClient struct { + Client +} + +func (c *devModeGlobClient) RedirectURIGlobs() []string { + return c.RedirectURIs() +} + +func (c *devModeGlobClient) PostLogoutRedirectURIGlobs() []string { + return c.PostLogoutRedirectURIs() +} + +// wrapPossibleDevModeClient takes a Client from the database returns an op.Client, possibly wrapped with +// devModeGlobClient based on Client.ClientDevMode. +func wrapPossibleDevModeClient(client Client) op.Client { + if client.DevMode() { + return &devModeGlobClient{client} + } else { + return &client + } +} diff --git a/sherlock/internal/oidc_models/oidc_models_test.go b/sherlock/internal/oidc_models/oidc_models_test.go new file mode 100644 index 000000000..fffda2aef --- /dev/null +++ b/sherlock/internal/oidc_models/oidc_models_test.go @@ -0,0 +1,23 @@ +package oidc_models + +import ( + "github.com/broadinstitute/sherlock/sherlock/internal/models" + "github.com/stretchr/testify/suite" + "testing" +) + +type oidcModelsSuite struct { + suite.Suite + models.TestSuiteHelper + TestClientHelper + storage *storageImpl +} + +func TestOidcModelsSuite(t *testing.T) { + suite.Run(t, new(oidcModelsSuite)) +} + +func (s *oidcModelsSuite) SetupTest() { + s.TestSuiteHelper.SetupTest() + s.storage = &storageImpl{db: s.DB} +} diff --git a/sherlock/internal/oidc_models/provider.go b/sherlock/internal/oidc_models/provider.go new file mode 100644 index 000000000..7bf7cd3e7 --- /dev/null +++ b/sherlock/internal/oidc_models/provider.go @@ -0,0 +1,43 @@ +package oidc_models + +import ( + "encoding/hex" + "fmt" + "github.com/broadinstitute/sherlock/sherlock/internal/config" + "github.com/zitadel/oidc/v3/pkg/op" + "golang.org/x/text/language" + "gorm.io/gorm" +) + +var Provider op.OpenIDProvider + +func initProvider(db *gorm.DB) error { + key, err := hex.DecodeString(config.Config.String("oidc.encryptionKeyHex")) + if err != nil { + return fmt.Errorf("could not decode oidc.encryptionKeyHex: %w", err) + } else if len(key) != 32 { + return fmt.Errorf("oidc.encryptionKeyHex must be 32 bytes long; got %d", len(key)) + } + + storage := &storageImpl{db: db} + conf := &op.Config{ + CryptoKey: ([32]byte)(key), + DefaultLogoutRedirectURI: "/static/logged-out.html", + CodeMethodS256: true, // Enable PKCE and S256 code challenge method + GrantTypeRefreshToken: true, // Allow refresh token grant user + SupportedUILocales: []language.Tag{language.AmericanEnglish}, + SupportedClaims: append(op.DefaultSupportedClaims, groupsClaim), // Technically more than we provide but better a superset than subset + } + options := []op.Option{ + op.WithCORSOptions(nil), + // If we expand our usage of the OIDC subsystem such that users need to connect to other endpoints, + // we may need to add custom URLs for those endpoints here. + op.WithCustomAuthEndpoint(op.NewEndpointWithURL("/authorize", config.Config.String("oidc.publicIssuerUrl")+"/authorize")), + } + if config.Config.String("mode") == "debug" { + options = append(options, op.WithAllowInsecure()) + } + + Provider, err = op.NewProvider(conf, storage, op.StaticIssuer(config.Config.String("oidc.issuerUrl")), options...) + return err +} diff --git a/sherlock/internal/oidc_models/refresh_token.go b/sherlock/internal/oidc_models/refresh_token.go new file mode 100644 index 000000000..1d584ca9a --- /dev/null +++ b/sherlock/internal/oidc_models/refresh_token.go @@ -0,0 +1,72 @@ +package oidc_models + +import ( + "errors" + "github.com/broadinstitute/sherlock/go-shared/pkg/utils" + "github.com/broadinstitute/sherlock/sherlock/internal/config" + "github.com/broadinstitute/sherlock/sherlock/internal/models" + "github.com/google/uuid" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "time" +) + +var _ op.RefreshTokenRequest = &RefreshToken{} + +type RefreshToken struct { + ID uuid.UUID `gorm:"primaryKey"` + CreatedAt time.Time + + TokenHash []byte // SHA-512 hash + + Client *Client `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` + ClientID string // AKA Audience, Application ID + Scopes oidc.SpaceDelimitedArray + OriginalAuthAt time.Time + + User *models.User + UserID uint +} + +func expireRefreshTokens(db *gorm.DB) error { + err := db. + Omit(clause.Associations). + Where("created_at < ?", time.Now().Add(-config.Config.MustDuration("oidc.refreshTokenDuration"))). + Delete(&RefreshToken{}). + Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } else { + return nil + } +} + +func (r *RefreshToken) GetAMR() []string { + return []string{} +} + +func (r *RefreshToken) GetAudience() []string { + return []string{r.ClientID} +} + +func (r *RefreshToken) GetAuthTime() time.Time { + return r.OriginalAuthAt +} + +func (r *RefreshToken) GetClientID() string { + return r.ClientID +} + +func (r *RefreshToken) GetScopes() []string { + return r.Scopes +} + +func (r *RefreshToken) GetSubject() string { + return utils.UintToString(r.UserID) +} + +func (r *RefreshToken) SetCurrentScopes(scopes []string) { + r.Scopes = scopes +} diff --git a/sherlock/internal/oidc_models/signing_key.go b/sherlock/internal/oidc_models/signing_key.go new file mode 100644 index 000000000..566563998 --- /dev/null +++ b/sherlock/internal/oidc_models/signing_key.go @@ -0,0 +1,175 @@ +package oidc_models + +import ( + "cloud.google.com/go/kms/apiv1/kmspb" + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "fmt" + "github.com/broadinstitute/sherlock/sherlock/internal/config" + "github.com/go-jose/go-jose/v4" + "github.com/google/uuid" + "github.com/rs/zerolog/log" + "github.com/zitadel/oidc/v3/pkg/op" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "time" +) + +type SigningKey struct { + ID uuid.UUID `gorm:"primaryKey"` + CreatedAt time.Time + PublicKey []byte + PrivateKey []byte +} + +func rotateSigningKeys(ctx context.Context, db *gorm.DB) error { + var validKeys []SigningKey + err := db. + Where("created_at > ?", time.Now(). + Add(-config.Config.MustDuration("oidc.signingKeyPrimaryDuration"))). + Omit(clause.Associations). + Find(&validKeys).Error + if err != nil { + return fmt.Errorf("error loading valid signing keys: %w", err) + } + if len(validKeys) == 0 { + newKey, err := saveNewSigningKey(ctx, db) + if err != nil { + return fmt.Errorf("error generating new signing key: %w", err) + } else { + log.Info().Msgf("OIDC | generated new signing key with ID %s", newKey.ID) + } + } + + var expiredKeys []SigningKey + err = db. + Where("created_at <= ?", time.Now(). + Add(-(config.Config.MustDuration("oidc.signingKeyPrimaryDuration"))). + Add(-config.Config.MustDuration("oidc.signingKeyPostRotationDuration"))). + Omit(clause.Associations). + Find(&expiredKeys).Error + if err != nil { + return fmt.Errorf("error loading expired signing keys: %w", err) + } + for _, key := range expiredKeys { + err = db.Omit(clause.Associations).Delete(&key).Error + if err != nil { + return fmt.Errorf("error deleting expired signing key: %w", err) + } + log.Info().Msgf("OIDC | deleted expired signing key with ID %s", key.ID) + } + return nil +} + +func saveNewSigningKey(ctx context.Context, db *gorm.DB) (*SigningKey, error) { + keyModel := &SigningKey{ + ID: uuid.New(), + } + + // Generate new private key + privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, fmt.Errorf("error generating private key: %w", err) + } + + // Store public key plaintext in database for easy access + keyModel.PublicKey = elliptic.MarshalCompressed(privateKey.PublicKey.Curve, privateKey.PublicKey.X, privateKey.PublicKey.Y) + + // Store private key, encrypting with KMS if configured + privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return nil, fmt.Errorf("error marshaling private key: %w", err) + } + if kmsKey != "" { + // Store encrypted + response, err := kmsClient.Encrypt(ctx, &kmspb.EncryptRequest{ + Name: kmsKey, + Plaintext: privateKeyBytes, + }) + if err != nil { + return nil, fmt.Errorf("error encrypting private key with KMS: %w", err) + } + keyModel.PrivateKey = response.Ciphertext + } else { + // Store plaintext + keyModel.PrivateKey = privateKeyBytes + } + + // Save database row + err = db.Create(keyModel).Error + if err != nil { + return nil, fmt.Errorf("error saving generated signing key to database: %w", err) + } + return keyModel, nil +} + +var _ op.Key = &publicSigningKey{} + +type publicSigningKey struct { + SigningKey +} + +func (p *publicSigningKey) ID() string { + return p.SigningKey.ID.String() +} + +func (p *publicSigningKey) Algorithm() jose.SignatureAlgorithm { + return jose.ES512 +} + +func (p *publicSigningKey) Use() string { + return "sig" +} + +func (p *publicSigningKey) Key() any { + x, y := elliptic.UnmarshalCompressed(elliptic.P521(), p.PublicKey) + return &ecdsa.PublicKey{ + Curve: elliptic.P521(), + X: x, + Y: y, + } +} + +var _ op.SigningKey = &decryptedPrivateSigningKey{} + +type decryptedPrivateSigningKey struct { + SigningKey + unencryptedPrivateKey *ecdsa.PrivateKey +} + +func (p *decryptedPrivateSigningKey) ID() string { + return p.SigningKey.ID.String() +} + +func (p *decryptedPrivateSigningKey) SignatureAlgorithm() jose.SignatureAlgorithm { + return jose.ES512 +} + +func (p *decryptedPrivateSigningKey) Key() any { + return p.unencryptedPrivateKey +} + +func decryptPrivateSigningKey(key *SigningKey) (*decryptedPrivateSigningKey, error) { + bytesToUnmarshall := key.PrivateKey + if kmsKey != "" { + response, err := kmsClient.Decrypt(context.Background(), &kmspb.DecryptRequest{ + Name: kmsKey, + Ciphertext: bytesToUnmarshall, + }) + if err != nil { + return nil, fmt.Errorf("error decrypting private signing key with KMS: %w", err) + } + bytesToUnmarshall = response.Plaintext + } + privateKey, err := x509.ParseECPrivateKey(bytesToUnmarshall) + if err != nil { + return nil, fmt.Errorf("error unmarshalling private signing key: %w", err) + } + return &decryptedPrivateSigningKey{ + SigningKey: *key, + unencryptedPrivateKey: privateKey, + }, nil +} diff --git a/sherlock/internal/oidc_models/storage.go b/sherlock/internal/oidc_models/storage.go new file mode 100644 index 000000000..aeabb3fe4 --- /dev/null +++ b/sherlock/internal/oidc_models/storage.go @@ -0,0 +1,581 @@ +package oidc_models + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha512" + "errors" + "fmt" + "github.com/broadinstitute/sherlock/go-shared/pkg/utils" + "github.com/broadinstitute/sherlock/sherlock/internal/config" + "github.com/broadinstitute/sherlock/sherlock/internal/models" + "github.com/go-jose/go-jose/v4" + "github.com/google/uuid" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/text/language" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "strings" + "time" +) + +const groupsClaim = "groups" + +var _ op.Storage = &storageImpl{} +var _ op.CanSetUserinfoFromRequest = &storageImpl{} + +type storageImpl struct { + db *gorm.DB +} + +func (s *storageImpl) CreateAuthRequest(_ context.Context, authRequest *oidc.AuthRequest, hintedUserID string) (op.AuthRequest, error) { + + requestUUID, err := uuid.NewUUID() + if err != nil { + return nil, fmt.Errorf("failed to create UUID: %w", err) + } + + request := AuthRequest{ + ID: requestUUID, + ClientID: authRequest.ClientID, + Nonce: authRequest.Nonce, + RedirectURI: authRequest.RedirectURI, + ResponseType: authRequest.ResponseType, + ResponseMode: authRequest.ResponseMode, + Scopes: authRequest.Scopes, + State: authRequest.State, + CodeChallenge: authRequest.CodeChallenge, + CodeChallengeMethod: authRequest.CodeChallengeMethod, + } + + if hintedUserID != "" { + parsedHintedUserID, err := utils.ParseUint(hintedUserID) + if err != nil { + return nil, fmt.Errorf("failed to parse user ID: %w", err) + } + request.UserID = &parsedHintedUserID + } + + err = s.db.Omit(clause.Associations).Create(&request).Error + if err != nil { + return nil, fmt.Errorf("failed to create auth request: %w", err) + } + + return &request, nil +} + +func (s *storageImpl) AuthRequestByID(_ context.Context, requestID string) (op.AuthRequest, error) { + validRequestUUID, err := uuid.Parse(requestID) + if err != nil { + return nil, fmt.Errorf("failed to parse request UUID: %w", err) + } + var request AuthRequest + err = s.db.Omit(clause.Associations).Where(&AuthRequest{ID: validRequestUUID}).Take(&request).Error + if err != nil { + return nil, fmt.Errorf("failed to get auth request: %w", err) + } + return &request, nil +} + +func (s *storageImpl) AuthRequestByCode(_ context.Context, code string) (op.AuthRequest, error) { + var requestCode AuthRequestCode + err := s.db.Preload("AuthRequest").Where(&AuthRequestCode{Code: code}).Take(&requestCode).Error + if err != nil { + return nil, fmt.Errorf("failed to get auth request code: %w", err) + } else if requestCode.AuthRequest == nil { + return nil, fmt.Errorf("auth request not found for code: %s", code) + } + return requestCode.AuthRequest, nil +} + +func (s *storageImpl) SaveAuthCode(_ context.Context, requestID string, code string) error { + validRequestUUID, err := uuid.Parse(requestID) + if err != nil { + return fmt.Errorf("failed to parse request UUID: %w", err) + } + requestCode := AuthRequestCode{ + Code: code, + AuthRequestID: validRequestUUID, + } + err = s.db.Omit(clause.Associations).Create(&requestCode).Error + if err != nil { + return fmt.Errorf("failed to save auth code: %w", err) + } + return nil +} + +func (s *storageImpl) DeleteAuthRequest(_ context.Context, requestID string) error { + validRequestUUID, err := uuid.Parse(requestID) + if err != nil { + return fmt.Errorf("failed to parse request UUID: %w", err) + } + // Auth codes deleted via foreign key cascade + err = s.db.Omit(clause.Associations).Where(&AuthRequest{ID: validRequestUUID}).Delete(&AuthRequest{}).Error + if err != nil { + return fmt.Errorf("failed to delete auth request: %w", err) + } + return nil +} + +func (s *storageImpl) CreateAccessToken(_ context.Context, request op.TokenRequest) (accessTokenID string, expiration time.Time, err error) { + var clientID string + var userID uint + switch req := request.(type) { + case *AuthRequest: + if !req.Done() || req.UserID == nil { + return "", time.Time{}, oidc.ErrLoginRequired() + } + clientID = req.ClientID + userID = *req.UserID + // It is possible for requests to be of another type -- like op.TokenExchangeRequest -- but we don't support + // that currently + default: + return "", time.Time{}, fmt.Errorf("unsupported request type: %T", request) + } + + token, err := s.createAccessToken(clientID, nil, request.GetScopes(), userID) + if err != nil { + return "", time.Time{}, err + } + return token.ID.String(), token.Expiry, nil +} + +func (s *storageImpl) CreateAccessAndRefreshTokens(_ context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) { + var clientID string + var userID uint + var authTime time.Time + switch req := request.(type) { + case *AuthRequest: + if !req.Done() || req.UserID == nil { + return "", "", time.Time{}, oidc.ErrLoginRequired() + } + clientID = req.ClientID + userID = *req.UserID + authTime = req.GetAuthTime() + case *RefreshToken: + clientID = req.ClientID + userID = req.UserID + authTime = req.GetAuthTime() + default: + return "", "", time.Time{}, fmt.Errorf("unsupported request type: %T", request) + } + + var refreshTokenModel *RefreshToken + if currentRefreshToken == "" { + refreshTokenModel, newRefreshToken, err = s.createRefreshToken(clientID, request.GetScopes(), userID, authTime) + } else { + refreshTokenModel, newRefreshToken, err = s.renewRefreshToken(currentRefreshToken) + } + if err != nil { + return "", "", time.Time{}, fmt.Errorf("failed to create refresh token: %w", err) + } + + var tokenModel *Token + tokenModel, err = s.createAccessToken(clientID, &refreshTokenModel.ID, request.GetScopes(), userID) + if err != nil { + return "", "", time.Time{}, fmt.Errorf("failed to create access token: %w", err) + } + + return tokenModel.ID.String(), newRefreshToken, tokenModel.Expiry, nil +} + +func (s *storageImpl) TokenRequestByRefreshToken(_ context.Context, refreshTokenID string) (op.RefreshTokenRequest, error) { + parsedRefreshTokenID, err := uuid.Parse(refreshTokenID) + if err != nil { + return nil, fmt.Errorf("failed to parse refresh token ID: %w", err) + } + var refreshToken RefreshToken + err = s.db.Omit(clause.Associations).Where(&RefreshToken{ID: parsedRefreshTokenID}).Take(&refreshToken).Error + if err != nil { + return nil, fmt.Errorf("failed to get refresh token: %w", err) + } + return &refreshToken, nil +} + +func (s *storageImpl) TerminateSession(_ context.Context, userID string, clientID string) error { + parsedUserID, err := utils.ParseUint(userID) + if err != nil { + return fmt.Errorf("failed to parse user ID: %w", err) + } + err1 := s.db.Omit(clause.Associations).Where(&RefreshToken{ClientID: clientID, UserID: parsedUserID}).Delete(&RefreshToken{}).Error + // Most tokens should've been caught by the foreign key cascade, but just in case we'll make sure we wipe all matching tokens too + err2 := s.db.Omit(clause.Associations).Where(&Token{ClientID: clientID, UserID: parsedUserID}).Delete(&Token{}).Error + if err1 != nil && !errors.Is(err1, gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to delete refresh tokens: %w", err1) + } + if err2 != nil && !errors.Is(err2, gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to delete access tokens: %w", err2) + } + return nil +} + +func (s *storageImpl) RevokeToken(_ context.Context, tokenOrTokenID string, userID string, clientID string) *oidc.Error { + var tokenUUIDToRevoke, refreshTokenUUIDToRevoke *uuid.UUID + + // UserID can be empty! If it's passed we'll filter on it. + var queryUserID uint + if userID != "" { + var err error + queryUserID, err = utils.ParseUint(userID) + if err != nil { + return oidc.ErrServerError().WithParent(err).WithDescription("failed to parse user ID") + } + } + + tokenUUID, err := uuid.Parse(tokenOrTokenID) + if err != nil { + // Token's not a UUID, so all we can do is revoke a matching refresh token + hash := sha512.Sum512([]byte(tokenOrTokenID)) + var refreshToken RefreshToken + err = s.db.Omit(clause.Associations).Where(&RefreshToken{TokenHash: hash[:], ClientID: clientID, UserID: queryUserID}).Take(&refreshToken).Error + if err == nil { + refreshTokenUUIDToRevoke = &refreshToken.ID + } + } else { + // Look up as token UUID + var token Token + err = s.db.Omit(clause.Associations).Where(&Token{ID: tokenUUID, ClientID: clientID, UserID: queryUserID}).Take(&token).Error + if err == nil { + tokenUUIDToRevoke = &token.ID + refreshTokenUUIDToRevoke = token.RefreshTokenID + } else { + // Look up as refresh token UUID + var refreshToken RefreshToken + err = s.db.Omit(clause.Associations).Where(&RefreshToken{ID: tokenUUID, ClientID: clientID, UserID: queryUserID}).Take(&refreshToken).Error + if err == nil { + refreshTokenUUIDToRevoke = &refreshToken.ID + } + } + } + + if tokenUUIDToRevoke != nil { + err = s.db.Omit(clause.Associations).Where(&Token{ID: *tokenUUIDToRevoke}).Delete(&Token{}).Error + } + if refreshTokenUUIDToRevoke != nil { + err = s.revokeRefreshToken(*refreshTokenUUIDToRevoke) + } + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return oidc.ErrInvalidRequest().WithDescription("token not found") + } else { + return oidc.ErrServerError().WithParent(err).WithDescription("failed to revoke token") + } + } + return nil +} + +func (s *storageImpl) GetRefreshTokenInfo(_ context.Context, clientID string, token string) (userID string, tokenID string, err error) { + hash := sha512.Sum512([]byte(token)) + var refreshToken RefreshToken + err = s.db.Omit(clause.Associations).Where(&RefreshToken{TokenHash: hash[:], ClientID: clientID}).Take(&refreshToken).Error + if err != nil { + return "", "", fmt.Errorf("failed to get refresh token: %w", err) + } + return utils.UintToString(refreshToken.UserID), refreshToken.ID.String(), nil +} + +func (s *storageImpl) SigningKey(_ context.Context) (op.SigningKey, error) { + var key SigningKey + err := s.db.Omit(clause.Associations).Order("created_at desc").First(&key).Error + if err != nil { + return nil, fmt.Errorf("failed to get signing key: %w", err) + } + return decryptPrivateSigningKey(&key) +} + +func (s *storageImpl) KeySet(_ context.Context) ([]op.Key, error) { + var keys []SigningKey + err := s.db.Omit(clause.Associations).Order("created_at desc").Find(&keys).Error + if err != nil { + return nil, fmt.Errorf("failed to get signing keys: %w", err) + } + return utils.Map(keys, func(rawKey SigningKey) op.Key { + return &publicSigningKey{SigningKey: rawKey} + }), nil +} + +func (s *storageImpl) GetClientByClientID(_ context.Context, clientID string) (op.Client, error) { + var client Client + err := s.db.Omit(clause.Associations).Where(&Client{ID: clientID}).Take(&client).Error + if err != nil { + return nil, fmt.Errorf("failed to get client: %w", err) + } + return wrapPossibleDevModeClient(client), nil +} + +func (s *storageImpl) AuthorizeClientIDSecret(_ context.Context, clientID, clientSecret string) error { + var client Client + err := s.db.Omit(clause.Associations).Where(&Client{ID: clientID}).Take(&client).Error + if err != nil { + return fmt.Errorf("failed to get client: %w", err) + } + if len(client.ClientSecretSalt) == 0 || len(client.ClientSecretHash) == 0 || client.ClientSecretIterations == 0 { + return fmt.Errorf("client secret not configured; perhaps client should be using PKCE") + } + derivedKey := pbkdf2.Key([]byte(clientSecret), client.ClientSecretSalt, client.ClientSecretIterations, len(client.ClientSecretHash), sha512.New) + if !bytes.Equal(derivedKey, client.ClientSecretHash) { + return fmt.Errorf("client secret does not match") + } + return nil +} + +// SetUserinfoFromScopes is an empty implementation because according to Zitadel's example, SetUserinfoFromRequest +// should be implemented instead. +func (s *storageImpl) SetUserinfoFromScopes(_ context.Context, _ *oidc.UserInfo, _, _ string, _ []string) error { + return nil +} + +func (s *storageImpl) SetUserinfoFromRequest(_ context.Context, userinfo *oidc.UserInfo, request op.IDTokenRequest, scopes []string) error { + userID, err := utils.ParseUint(request.GetSubject()) + if err != nil { + return fmt.Errorf("failed to parse user ID: %w", err) + } + return s.setUserinfo(userinfo, userID, scopes) +} + +func (s *storageImpl) SetUserinfoFromToken(_ context.Context, userinfo *oidc.UserInfo, tokenID, subject, _ string) error { + parsedTokenID, err := uuid.Parse(tokenID) + if err != nil { + return fmt.Errorf("failed to parse token ID: %w", err) + } + var token Token + err = s.db.Omit(clause.Associations).Where(&Token{ID: parsedTokenID}).Take(&token).Error + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + if utils.UintToString(token.UserID) != subject { + // This check is theoretically unnecessary because the library should be covering this, + // but we have the info so why not. + return fmt.Errorf("token mismatched with subject") + } + + // We ignore the origin argument to this function because CORS is handled by Gin middleware, no reason for us + // to implement that here. + + return s.setUserinfo(userinfo, token.UserID, token.Scopes) +} + +// SetIntrospectionFromToken has some arguments we ignore. We ignore the subject because we can get the user ID in +// a more verified way from the token. +func (s *storageImpl) SetIntrospectionFromToken(_ context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error { + parsedUserID, err := utils.ParseUint(subject) + if err != nil { + return fmt.Errorf("failed to parse user ID: %w", err) + } + + parsedTokenID, err := uuid.Parse(tokenID) + if err != nil { + return fmt.Errorf("failed to parse token ID: %w", err) + } + var token Token + err = s.db.Omit(clause.Associations).Where(&Token{ID: parsedTokenID, ClientID: clientID, UserID: parsedUserID}).Take(&token).Error + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + + // In our case, we basically equate audience, client ID, and application. + // This kind of audience == client ID check is "the right way" to do this, + // though, so we do it to avoid a gotcha down the line. + for _, aud := range token.GetAudience() { + if aud == clientID { + userInfo := new(oidc.UserInfo) + err = s.setUserinfo(userInfo, token.UserID, token.Scopes) + if err != nil { + return fmt.Errorf("failed to set userinfo: %w", err) + } + introspection.SetUserInfo(userInfo) + introspection.Scope = token.Scopes + introspection.ClientID = token.ClientID + introspection.Active = true + return nil + } + } + return fmt.Errorf("token not valid for client") +} + +func (s *storageImpl) GetPrivateClaimsFromScopes(_ context.Context, userID, _ string, scopes []string) (map[string]any, error) { + claims := make(map[string]any) + for _, scope := range scopes { + switch scope { + case groupsClaim: + // Only get the user if we need to + parsedUserID, err := utils.ParseUint(userID) + if err != nil { + return nil, fmt.Errorf("failed to parse user ID: %w", err) + } + var user models.User + err = s.db.Where(&models.User{Model: gorm.Model{ID: parsedUserID}}).Scopes(models.ReadUserScope).Take(&user).Error + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + groups := make([]string, 0, len(user.Assignments)) + for _, assignment := range user.Assignments { + if assignment != nil && assignment.IsActive() && assignment.Role != nil && assignment.Role.Name != nil { + groups = append(groups, *assignment.Role.Name) + } + } + claims[groupsClaim] = groups + } + } + return claims, nil +} + +func (s *storageImpl) SignatureAlgorithms(_ context.Context) ([]jose.SignatureAlgorithm, error) { + // See signing_key.go, this is hardcoded + return []jose.SignatureAlgorithm{jose.ES512}, nil +} + +func (s *storageImpl) GetKeyByIDAndClientID(_ context.Context, _, _ string) (*jose.JSONWebKey, error) { + // What we're meant to do here is define a list of (client ID, key ID, public key) tuples in configuration or something. + // The idea is that a client application would have the private key, and it could then sign JWTs *to send to us*, per + // RFC 7523 (urn:ietf:params:oauth:grant-type:jwt-bearer). zitadel-oidc would call this function to get the public key, + // validate the signature, and then return the requested access token. This allows a client to get an access token for + // a user without user interaction. + // + // We don't support that grant type for Sherlock so we leave this unimplemented. If somehow this does get called, + // the error will bubble up as if the assertion was invalid. + // + // https://datatracker.ietf.org/doc/html/rfc7523 + return nil, fmt.Errorf("JWT Profile Authorization Grants (RFC 7523) aren't currently implemented") +} + +func (s *storageImpl) ValidateJWTProfileScopes(_ context.Context, _ string, scopes []string) ([]string, error) { + allowedScopes := make([]string, 0) + for _, scope := range scopes { + if scope == oidc.ScopeOpenID || scope == oidc.ScopeProfile || scope == oidc.ScopeEmail { + allowedScopes = append(allowedScopes, scope) + } + } + return allowedScopes, nil +} + +// Health isn't something we particularly need to rely on, because this OIDC provider exists in the context of a larger +// application that already has liveness and readiness probes configured. If the OIDC API is available, the OIDC provider +// will be healthy. +func (s *storageImpl) Health(_ context.Context) error { + if s.db == nil { + return fmt.Errorf("no database connection") + } + return nil +} + +func (s *storageImpl) createRefreshToken(clientID string, scopes []string, userID uint, authAt time.Time) (*RefreshToken, string, error) { + id, err := uuid.NewUUID() + if err != nil { + return nil, "", fmt.Errorf("failed to create UUID: %w", err) + } + refreshToken := make([]byte, 256) + _, err = rand.Read(refreshToken) + if err != nil { + return nil, "", fmt.Errorf("failed to generate refresh token: %w", err) + } + hash := sha512.Sum512(refreshToken) + refreshTokenModel := &RefreshToken{ + ID: id, + TokenHash: hash[:], + ClientID: clientID, + Scopes: scopes, + OriginalAuthAt: authAt, + UserID: userID, + } + err = s.db.Omit(clause.Associations).Create(refreshTokenModel).Error + if err != nil { + return nil, "", fmt.Errorf("failed to create refresh token: %w", err) + } + return refreshTokenModel, string(refreshToken), nil +} + +func (s *storageImpl) revokeRefreshToken(refreshTokenID uuid.UUID) error { + // Normal tokens are deleted via foreign key cascade + err := s.db.Omit(clause.Associations).Where(&RefreshToken{ID: refreshTokenID}).Delete(&RefreshToken{}).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to delete refresh token: %w", err) + } + return nil +} + +func (s *storageImpl) renewRefreshToken(refreshToken string) (*RefreshToken, string, error) { + hash := sha512.Sum512([]byte(refreshToken)) + var refreshTokenModel RefreshToken + err := s.db.Omit(clause.Associations).Where(&RefreshToken{TokenHash: hash[:]}).Take(&refreshTokenModel).Error + if err != nil { + return nil, "", fmt.Errorf("failed to get refresh token: %w", err) + } + + err = s.revokeRefreshToken(refreshTokenModel.ID) + if err != nil { + return nil, "", fmt.Errorf("failed to revoke existing refresh token: %w", err) + } + + return s.createRefreshToken(refreshTokenModel.ClientID, refreshTokenModel.Scopes, refreshTokenModel.UserID, refreshTokenModel.OriginalAuthAt) +} + +func (s *storageImpl) createAccessToken(clientID string, refreshTokenID *uuid.UUID, scopes []string, userID uint) (*Token, error) { + id, err := uuid.NewUUID() + if err != nil { + return nil, fmt.Errorf("failed to create UUID: %w", err) + } + tokenModel := &Token{ + ID: id, + RefreshTokenID: refreshTokenID, + ClientID: clientID, + Scopes: scopes, + Expiry: time.Now().Add(config.Config.MustDuration("oidc.tokenDuration")), + UserID: userID, + } + + err = s.db.Omit(clause.Associations).Create(tokenModel).Error + if err != nil { + return nil, fmt.Errorf("failed to create access token: %w", err) + } + + return tokenModel, nil +} + +func (s *storageImpl) setUserinfo(userInfo *oidc.UserInfo, userID uint, scopes []string) error { + var user models.User + err := s.db.Where(&models.User{Model: gorm.Model{ID: userID}}).Scopes(models.ReadUserScope).Take(&user).Error + if err != nil { + return fmt.Errorf("failed to get user: %w", err) + } + for _, scope := range scopes { + switch scope { + case oidc.ScopeOpenID: + userInfo.Subject = utils.UintToString(user.ID) + case oidc.ScopeEmail: + userInfo.Email = user.Email + userInfo.EmailVerified = true + case oidc.ScopeProfile: + // Downstream systems expect consistent claims. We can get away with passing GivenName and FamilyName + // conditionally, but for Name it's worth it to always pass it, even if we potentially have to pass + // an ugly-looking email handle. + userInfo.Name = user.NameOrUsername() + userInfo.Nickname = user.NameOrUsername() + userInfo.PreferredUsername = user.AlphaNumericHyphenatedUsername() + userInfo.Locale = oidc.NewLocale(language.AmericanEnglish) + userInfo.UpdatedAt = oidc.FromTime(user.UpdatedAt) + if user.Name != nil { + nameParts := strings.Split(*user.Name, " ") + if len(nameParts) > 0 { + userInfo.GivenName = nameParts[0] + } + if len(nameParts) > 1 { + userInfo.FamilyName = nameParts[len(nameParts)-1] + } + } + userInfo.Website = "https://broad.io/beehive/r/user/" + user.Email // :shrug: + case groupsClaim: + groups := make([]string, 0, len(user.Assignments)) + for _, assignment := range user.Assignments { + if assignment != nil && assignment.IsActive() && assignment.Role != nil && assignment.Role.Name != nil { + groups = append(groups, *assignment.Role.Name) + } + } + userInfo.AppendClaims(groupsClaim, groups) + } + } + return nil +} diff --git a/sherlock/internal/oidc_models/storage_test.go b/sherlock/internal/oidc_models/storage_test.go new file mode 100644 index 000000000..32e2576bb --- /dev/null +++ b/sherlock/internal/oidc_models/storage_test.go @@ -0,0 +1,695 @@ +package oidc_models + +import ( + "context" + "database/sql" + "github.com/broadinstitute/sherlock/go-shared/pkg/utils" + "github.com/broadinstitute/sherlock/sherlock/internal/models" + "github.com/go-jose/go-jose/v4" + "github.com/google/uuid" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/text/language" + "gorm.io/gorm/clause" + "strings" + "time" +) + +func (s *oidcModelsSuite) TestStorageImpl_CreateAuthRequest() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + + authRequest, err := s.storage.CreateAuthRequest(context.Background(), &oidc.AuthRequest{ + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + ResponseType: oidc.ResponseTypeIDTokenOnly, + ClientID: clientID, + RedirectURI: s.GeneratedClientRedirectURI(), + State: "some-state", + Nonce: "some-nonce", + ResponseMode: oidc.ResponseModeQuery, + CodeChallenge: "code-challenge", + CodeChallengeMethod: oidc.CodeChallengeMethodS256, + }, "") + s.NoError(err) + + var authRequests []AuthRequest + s.NoError(s.DB.Find(&authRequests).Error) + s.Len(authRequests, 1) + + s.Equal(authRequest.GetID(), authRequests[0].ID.String()) + s.Equal(clientID, authRequests[0].ClientID) + s.Equal(oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, authRequests[0].Scopes) + s.Equal(oidc.ResponseTypeIDTokenOnly, authRequests[0].ResponseType) + s.Equal(s.GeneratedClientRedirectURI(), authRequests[0].RedirectURI) + s.Equal("some-state", authRequests[0].State) + s.Equal("some-nonce", authRequests[0].Nonce) + s.Equal(oidc.ResponseModeQuery, authRequests[0].ResponseMode) + s.Equal("code-challenge", authRequests[0].CodeChallenge) + s.Equal(oidc.CodeChallengeMethodS256, authRequests[0].CodeChallengeMethod) +} + +func (s *oidcModelsSuite) TestStorageImpl_AuthRequestByID() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + + authRequest := &AuthRequest{ + ID: uuid.New(), + DoneAt: sql.NullTime{Time: time.Now(), Valid: true}, + ClientID: clientID, + Nonce: "some-nonce", + RedirectURI: s.GeneratedClientRedirectURI(), + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + State: "some-state", + UserID: utils.PointerTo(s.TestData.User_Suitable().ID), + } + s.NoError(s.DB.Omit(clause.Associations).Create(authRequest).Error) + + authRequestByID, err := s.storage.AuthRequestByID(context.Background(), authRequest.ID.String()) + s.NoError(err) + s.Equal(authRequest.ID.String(), authRequestByID.GetID()) +} + +func (s *oidcModelsSuite) TestStorageImpl_AuthRequestByCode() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + + authRequest := &AuthRequest{ + ID: uuid.New(), + DoneAt: sql.NullTime{Time: time.Now(), Valid: true}, + ClientID: clientID, + Nonce: "some-nonce", + RedirectURI: s.GeneratedClientRedirectURI(), + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + State: "some-state", + UserID: utils.PointerTo(s.TestData.User_Suitable().ID), + } + s.NoError(s.DB.Omit(clause.Associations).Create(authRequest).Error) + + code := "some-code" + s.NoError(s.storage.SaveAuthCode(context.Background(), authRequest.ID.String(), code)) + + authRequestByCode, err := s.storage.AuthRequestByCode(context.Background(), code) + s.NoError(err) + s.Equal(authRequest.ID.String(), authRequestByCode.GetID()) +} + +func (s *oidcModelsSuite) TestStorageImpl_SaveAuthCode() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + + authRequest := &AuthRequest{ + ID: uuid.New(), + DoneAt: sql.NullTime{Time: time.Now(), Valid: true}, + ClientID: clientID, + Nonce: "some-nonce", + RedirectURI: s.GeneratedClientRedirectURI(), + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + State: "some-state", + UserID: utils.PointerTo(s.TestData.User_Suitable().ID), + } + s.NoError(s.DB.Omit(clause.Associations).Create(authRequest).Error) + + code := "some-code" + s.NoError(s.storage.SaveAuthCode(context.Background(), authRequest.ID.String(), code)) + + var authRequestCodes []AuthRequestCode + s.NoError(s.DB.Find(&authRequestCodes).Error) + s.Len(authRequestCodes, 1) + s.Equal(code, authRequestCodes[0].Code) + s.Equal(authRequest.ID.String(), authRequestCodes[0].AuthRequestID.String()) +} + +func (s *oidcModelsSuite) TestStorageImpl_DeleteAuthRequest() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + + authRequest := &AuthRequest{ + ID: uuid.New(), + DoneAt: sql.NullTime{Time: time.Now(), Valid: true}, + ClientID: clientID, + Nonce: "some-nonce", + RedirectURI: s.GeneratedClientRedirectURI(), + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + State: "some-state", + UserID: utils.PointerTo(s.TestData.User_Suitable().ID), + } + s.NoError(s.DB.Omit(clause.Associations).Create(authRequest).Error) + + s.NoError(s.storage.DeleteAuthRequest(context.Background(), authRequest.ID.String())) + + var authRequests []AuthRequest + s.NoError(s.DB.Find(&authRequests).Error) + s.Len(authRequests, 0) +} + +func (s *oidcModelsSuite) TestStorageImpl_CreateAccessToken() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + accessTokenID, _, err := s.storage.CreateAccessToken(context.Background(), &AuthRequest{ + ID: uuid.New(), + DoneAt: sql.NullTime{Time: time.Now(), Valid: true}, + ClientID: clientID, + Nonce: "some-nonce", + RedirectURI: s.GeneratedClientRedirectURI(), + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + State: "some-state", + UserID: utils.PointerTo(s.TestData.User_Suitable().ID), + }) + s.NoError(err) + + var tokens []Token + s.NoError(s.DB.Find(&tokens).Error) + s.Len(tokens, 1) + s.Equal(accessTokenID, tokens[0].ID.String()) + s.Equal(clientID, tokens[0].ClientID) + s.Equal(oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, tokens[0].Scopes) + +} + +func (s *oidcModelsSuite) TestStorageImpl_CreateAccessAndRefreshTokens() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + authRequest := &AuthRequest{ + ID: uuid.New(), + DoneAt: sql.NullTime{Time: time.Now(), Valid: true}, + ClientID: clientID, + Nonce: "some-nonce", + RedirectURI: s.GeneratedClientRedirectURI(), + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + State: "some-state", + UserID: utils.PointerTo(s.TestData.User_Suitable().ID), + } + accessTokenID, _, _, err := s.storage.CreateAccessAndRefreshTokens(context.Background(), authRequest, "") + s.NoError(err) + + var tokens []Token + s.NoError(s.DB.Find(&tokens).Error) + s.Len(tokens, 1) + s.Equal(accessTokenID, tokens[0].ID.String()) + s.Equal(clientID, tokens[0].ClientID) + s.Equal(oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, tokens[0].Scopes) + + var refreshTokens []RefreshToken + s.NoError(s.DB.Find(&refreshTokens).Error) + s.Len(refreshTokens, 1) +} + +func (s *oidcModelsSuite) TestStorageImpl_TokenRequestByRefreshToken() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + refreshTokenModel, _, err := s.storage.createRefreshToken(clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID, time.Now()) + s.NoError(err) + + tokenRequest, err := s.storage.TokenRequestByRefreshToken(context.Background(), refreshTokenModel.ID.String()) + s.NoError(err) + s.Equal(refreshTokenModel.OriginalAuthAt.Second(), tokenRequest.GetAuthTime().Second()) +} + +func (s *oidcModelsSuite) TestStorageImpl_TerminateSession() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + _, _, _, err = s.storage.CreateAccessAndRefreshTokens(context.Background(), &AuthRequest{ + ID: uuid.New(), + DoneAt: sql.NullTime{Time: time.Now(), Valid: true}, + ClientID: clientID, + Nonce: "some-nonce", + RedirectURI: s.GeneratedClientRedirectURI(), + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + State: "some-state", + UserID: utils.PointerTo(s.TestData.User_Suitable().ID), + }, "") + s.NoError(err) + s.NoError(s.storage.TerminateSession(context.Background(), utils.UintToString(s.TestData.User_Suitable().ID), clientID)) + var refreshTokens []RefreshToken + s.NoError(s.DB.Find(&refreshTokens).Error) + s.Len(refreshTokens, 0) + var tokens []Token + s.NoError(s.DB.Find(&tokens).Error) + s.Len(tokens, 0) +} + +func (s *oidcModelsSuite) TestStorageImpl_RevokeToken() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + s.Run("refresh token", func() { + _, refreshToken, _, err := s.storage.CreateAccessAndRefreshTokens(context.Background(), &AuthRequest{ + ID: uuid.New(), + DoneAt: sql.NullTime{Time: time.Now(), Valid: true}, + ClientID: clientID, + Nonce: "some-nonce", + RedirectURI: s.GeneratedClientRedirectURI(), + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + State: "some-state", + UserID: utils.PointerTo(s.TestData.User_Suitable().ID), + }, "") + s.NoError(err) + + s.Nil(s.storage.RevokeToken(context.Background(), refreshToken, utils.UintToString(s.TestData.User_Suitable().ID), clientID)) + + var refreshTokens []RefreshToken + s.NoError(s.DB.Find(&refreshTokens).Error) + s.Len(refreshTokens, 0) + var tokens []Token + s.NoError(s.DB.Find(&tokens).Error) + s.Len(tokens, 0) + }) + s.Run("token", func() { + accessTokenID, _, err := s.storage.CreateAccessToken(context.Background(), &AuthRequest{ + ID: uuid.New(), + DoneAt: sql.NullTime{Time: time.Now(), Valid: true}, + ClientID: clientID, + Nonce: "some-nonce", + RedirectURI: s.GeneratedClientRedirectURI(), + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + State: "some-state", + UserID: utils.PointerTo(s.TestData.User_Suitable().ID), + }) + s.NoError(err) + + s.Nil(s.storage.RevokeToken(context.Background(), accessTokenID, utils.UintToString(s.TestData.User_Suitable().ID), clientID)) + + var refreshTokens []RefreshToken + s.NoError(s.DB.Find(&refreshTokens).Error) + s.Len(refreshTokens, 0) + var tokens []Token + s.NoError(s.DB.Find(&tokens).Error) + s.Len(tokens, 0) + }) +} + +func (s *oidcModelsSuite) TestStorageImpl_GetRefreshTokenInfo() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + refreshTokenModel, rawRefreshToken, err := s.storage.createRefreshToken(clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID, time.Now()) + s.NoError(err) + + userID, tokenID, err := s.storage.GetRefreshTokenInfo(context.Background(), clientID, rawRefreshToken) + s.NoError(err) + s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), userID) + s.Equal(refreshTokenModel.ID.String(), tokenID) +} + +func (s *oidcModelsSuite) TestStorageImpl_SigningKey() { + key1, err := saveNewSigningKey(context.Background(), s.DB) + s.NoError(err) + key2, err := saveNewSigningKey(context.Background(), s.DB) + s.NoError(err) + s.NoError(s.DB.Model(&key1).UpdateColumn("created_at", time.Now().Add(-time.Hour)).Error) + signingKey, err := s.storage.SigningKey(context.Background()) + s.NoError(err) + s.Equal(key2.ID.String(), signingKey.ID()) +} + +func (s *oidcModelsSuite) TestStorageImpl_KeySet() { + key1, err := saveNewSigningKey(context.Background(), s.DB) + s.NoError(err) + key2, err := saveNewSigningKey(context.Background(), s.DB) + s.NoError(err) + keyset, err := s.storage.KeySet(context.Background()) + s.NoError(err) + s.Len(keyset, 2) + var key1Found, key2Found bool + for _, key := range keyset { + if key.ID() == key1.ID.String() { + key1Found = true + } + if key.ID() == key2.ID.String() { + key2Found = true + } + } + s.True(key1Found) + s.True(key2Found) +} + +func (s *oidcModelsSuite) TestStorageImpl_GetClientByClientID() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + client, err := s.storage.GetClientByClientID(context.Background(), clientID) + s.NoError(err) + s.Equal(clientID, client.GetID()) +} + +func (s *oidcModelsSuite) TestStorageImpl_AuthorizeClientIDSecret() { + clientID, clientSecret, err := s.GenerateClient(s.DB) + s.NoError(err) + s.Run("valid", func() { + s.NoError(s.storage.AuthorizeClientIDSecret(context.Background(), clientID, clientSecret)) + }) + s.Run("invalid", func() { + s.Error(s.storage.AuthorizeClientIDSecret(context.Background(), clientID, "invalid")) + }) +} + +func (s *oidcModelsSuite) TestStorageImpl_SetUserinfoFromScopes() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + userinfo := &oidc.UserInfo{} + s.NoError(s.storage.SetUserinfoFromScopes(context.Background(), userinfo, utils.UintToString(s.TestData.User_Suitable().ID), clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim})) + s.Empty(userinfo.Subject) // This method does nothing intentionally; the library says to implement the request handling instead +} + +func (s *oidcModelsSuite) TestStorageImpl_SetUserinfoFromRequest() { + userinfo := &oidc.UserInfo{} + s.NoError(s.storage.SetUserinfoFromRequest(context.Background(), userinfo, &AuthRequest{ + UserID: utils.PointerTo(s.TestData.User_Suitable().ID), + }, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim})) + s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), userinfo.Subject) +} + +func (s *oidcModelsSuite) TestStorageImpl_SetUserinfoFromToken() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + token := Token{ + ID: uuid.New(), + ClientID: clientID, + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + Expiry: time.Now().Add(time.Hour), + UserID: s.TestData.User_Suitable().ID, + } + s.NoError(s.DB.Omit(clause.Associations).Create(&token).Error) + userinfo := &oidc.UserInfo{} + s.NoError(s.storage.SetUserinfoFromToken(context.Background(), userinfo, token.ID.String(), utils.UintToString(s.TestData.User_Suitable().ID), "")) + s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), userinfo.Subject) +} + +func (s *oidcModelsSuite) TestStorageImpl_SetIntrospectionFromToken() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + token := Token{ + ID: uuid.New(), + ClientID: clientID, + Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, + Expiry: time.Now().Add(time.Hour), + UserID: s.TestData.User_Suitable().ID, + } + s.NoError(s.DB.Omit(clause.Associations).Create(&token).Error) + introspection := &oidc.IntrospectionResponse{} + s.NoError(s.storage.SetIntrospectionFromToken(context.Background(), introspection, token.ID.String(), utils.UintToString(s.TestData.User_Suitable().ID), clientID)) + s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), introspection.Subject) + s.Equal(token.Scopes, introspection.Scope) + s.Equal(clientID, introspection.ClientID) + s.True(introspection.Active) +} + +func (s *oidcModelsSuite) TestStorageImpl_GetPrivateClaimsFromScopes() { + s.Run("empty", func() { + claims, err := s.storage.GetPrivateClaimsFromScopes(context.Background(), utils.UintToString(s.TestData.User_Suitable().ID), "", []string{}) + s.NoError(err) + s.Empty(claims) + }) + s.Run("groups", func() { + claims, err := s.storage.GetPrivateClaimsFromScopes(context.Background(), utils.UintToString(s.TestData.User_Suitable().ID), "", []string{groupsClaim}) + s.NoError(err) + if s.NotEmpty(claims) && s.Contains(claims, groupsClaim) && s.IsType([]string{}, claims[groupsClaim]) { + groups := claims[groupsClaim].([]string) + for _, ra := range s.TestData.User_Suitable().Assignments { + if ra != nil && ra.IsActive() { + s.Contains(groups, *ra.Role.Name) + } + } + } + }) + s.Run("extra", func() { + claims, err := s.storage.GetPrivateClaimsFromScopes(context.Background(), utils.UintToString(s.TestData.User_Suitable().ID), "", []string{groupsClaim, "extra"}) + s.NoError(err) + if s.NotEmpty(claims) && s.Contains(claims, groupsClaim) && s.IsType([]string{}, claims[groupsClaim]) { + groups := claims[groupsClaim].([]string) + for _, ra := range s.TestData.User_Suitable().Assignments { + if ra != nil && ra.IsActive() { + s.Contains(groups, *ra.Role.Name) + } + } + } + }) +} + +func (s *oidcModelsSuite) TestStorageImpl_SignatureAlgorithms() { + algorithms, err := s.storage.SignatureAlgorithms(context.Background()) + s.NoError(err) + s.Equal([]jose.SignatureAlgorithm{jose.ES512}, algorithms) +} + +func (s *oidcModelsSuite) TestStorageImpl_GetKeyByIDAndClientID() { + _, err := s.storage.GetKeyByIDAndClientID(context.Background(), "", "") + s.ErrorContains(err, "aren't currently implemented") +} + +func (s *oidcModelsSuite) TestStorageImpl_ValidateJWTProfileScopes() { + s.Run("empty", func() { + scopes, err := s.storage.ValidateJWTProfileScopes(context.Background(), "", []string{}) + s.NoError(err) + s.Empty(scopes) + }) + s.Run("all allowed", func() { + scopes, err := s.storage.ValidateJWTProfileScopes(context.Background(), "", []string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProfile}) + s.NoError(err) + s.Equal([]string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProfile}, scopes) + }) + s.Run("extra stuff", func() { + scopes, err := s.storage.ValidateJWTProfileScopes(context.Background(), "", []string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProfile, "extra"}) + s.NoError(err) + s.Equal([]string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProfile}, scopes) + }) +} + +func (s *oidcModelsSuite) TestStorageImpl_Health() { + s.Run("error when nil db", func() { + storage := &storageImpl{} + s.Error(storage.Health(context.Background())) + }) + s.Run("no error", func() { + s.NoError(s.storage.Health(context.Background())) + }) +} + +func (s *oidcModelsSuite) TestStorageImpl_createRefreshToken() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + refreshToken, _, err := s.storage.createRefreshToken(clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID, time.Now()) + s.NoError(err) + s.NotEmpty(refreshToken.ID) + s.Equal(clientID, refreshToken.ClientID) + s.Equal(oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, refreshToken.Scopes) + s.Equal(s.TestData.User_Suitable().ID, refreshToken.UserID) + + var refreshTokens []RefreshToken + s.NoError(s.DB.Find(&refreshTokens).Error) + s.Len(refreshTokens, 1) +} + +func (s *oidcModelsSuite) TestStorageImpl_revokeRefreshToken() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + refreshToken, _, err := s.storage.createRefreshToken(clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID, time.Now()) + s.NoError(err) + + s.NoError(s.storage.revokeRefreshToken(refreshToken.ID)) + + var refreshTokens []RefreshToken + s.NoError(s.DB.Find(&refreshTokens).Error) + s.Len(refreshTokens, 0) +} + +func (s *oidcModelsSuite) TestStorageImpl_renewRefreshToken() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + refreshToken, rawRefreshToken, err := s.storage.createRefreshToken(clientID, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID, time.Now()) + s.NoError(err) + + newRefreshToken, _, err := s.storage.renewRefreshToken(rawRefreshToken) + s.NoError(err) + + s.NotEqual(refreshToken.ID, newRefreshToken.ID) + s.Equal(refreshToken.ClientID, newRefreshToken.ClientID) + s.Equal(refreshToken.Scopes, newRefreshToken.Scopes) + s.Equal(refreshToken.UserID, newRefreshToken.UserID) + + var refreshTokens []RefreshToken + s.NoError(s.DB.Find(&refreshTokens).Error) + s.Len(refreshTokens, 1) +} + +func (s *oidcModelsSuite) TestStorageImpl_createAccessToken() { + clientID, _, err := s.GenerateClient(s.DB) + s.NoError(err) + accessToken, err := s.storage.createAccessToken(clientID, nil, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, s.TestData.User_Suitable().ID) + s.NoError(err) + + var tokens []Token + s.NoError(s.DB.Find(&tokens).Error) + s.Len(tokens, 1) + s.Equal(accessToken.ID.String(), tokens[0].ID.String()) + s.Equal(clientID, tokens[0].ClientID) + s.Equal(oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, groupsClaim}, tokens[0].Scopes) + s.Equal(s.TestData.User_Suitable().ID, tokens[0].UserID) +} + +func (s *oidcModelsSuite) TestStorageImpl_setUserInfo() { + s.Run("no scopes", func() { + userInfo := &oidc.UserInfo{} + s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{})) + s.Empty(userInfo.Subject) + s.Empty(userInfo.Email) + s.Empty(userInfo.EmailVerified) + s.Empty(userInfo.Name) + s.Empty(userInfo.Nickname) + s.Empty(userInfo.PreferredUsername) + s.Empty(userInfo.Locale) + s.Empty(userInfo.UpdatedAt) + s.Empty(userInfo.GivenName) + s.Empty(userInfo.FamilyName) + s.Empty(userInfo.Website) + s.Empty(userInfo.Claims) + }) + s.Run("openid", func() { + userInfo := &oidc.UserInfo{} + s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{oidc.ScopeOpenID})) + s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), userInfo.Subject) + s.Empty(userInfo.Email) + s.Empty(userInfo.EmailVerified) + s.Empty(userInfo.Name) + s.Empty(userInfo.Nickname) + s.Empty(userInfo.PreferredUsername) + s.Empty(userInfo.Locale) + s.Empty(userInfo.UpdatedAt) + s.Empty(userInfo.GivenName) + s.Empty(userInfo.FamilyName) + s.Empty(userInfo.Website) + s.Empty(userInfo.Claims) + }) + s.Run("email", func() { + userInfo := &oidc.UserInfo{} + s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{oidc.ScopeEmail})) + s.Empty(userInfo.Subject) + s.Equal(s.TestData.User_Suitable().Email, userInfo.Email) + s.True(bool(userInfo.EmailVerified)) + s.Empty(userInfo.Name) + s.Empty(userInfo.Nickname) + s.Empty(userInfo.PreferredUsername) + s.Empty(userInfo.Locale) + s.Empty(userInfo.UpdatedAt) + s.Empty(userInfo.GivenName) + s.Empty(userInfo.FamilyName) + s.Empty(userInfo.Website) + s.Empty(userInfo.Claims) + }) + s.Run("profile", func() { + userInfo := &oidc.UserInfo{} + s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{oidc.ScopeProfile})) + s.Empty(userInfo.Subject) + s.Empty(userInfo.Email) + s.Empty(userInfo.EmailVerified) + s.Equal(utils.PointerTo(s.TestData.User_Suitable()).NameOrUsername(), userInfo.Name) + s.Equal(utils.PointerTo(s.TestData.User_Suitable()).NameOrUsername(), userInfo.Nickname) + s.Equal(utils.PointerTo(s.TestData.User_Suitable()).AlphaNumericHyphenatedUsername(), userInfo.PreferredUsername) + s.Equal(oidc.NewLocale(language.AmericanEnglish), userInfo.Locale) + s.Equal(oidc.FromTime(s.TestData.User_Suitable().UpdatedAt), userInfo.UpdatedAt) + if s.TestData.User_Suitable().Name != nil { + nameParts := strings.Split(*s.TestData.User_Suitable().Name, " ") + s.Equal(nameParts[0], userInfo.GivenName) + s.Equal(nameParts[len(nameParts)-1], userInfo.FamilyName) + } else { + s.Empty(userInfo.GivenName) + s.Empty(userInfo.FamilyName) + } + s.Equal("https://broad.io/beehive/r/user/"+s.TestData.User_Suitable().Email, userInfo.Website) + s.Empty(userInfo.Claims) + }) + s.Run("groups", func() { + userInfo := &oidc.UserInfo{} + s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{groupsClaim})) + s.Empty(userInfo.Subject) + s.Empty(userInfo.Email) + s.Empty(userInfo.EmailVerified) + s.Empty(userInfo.Name) + s.Empty(userInfo.Nickname) + s.Empty(userInfo.PreferredUsername) + s.Empty(userInfo.Locale) + s.Empty(userInfo.UpdatedAt) + s.Empty(userInfo.GivenName) + s.Empty(userInfo.FamilyName) + s.Empty(userInfo.Website) + if s.NotEmpty(userInfo.Claims) && s.Contains(userInfo.Claims, groupsClaim) && s.IsType([]string{}, userInfo.Claims[groupsClaim]) { + groups := userInfo.Claims[groupsClaim].([]string) + for _, ra := range s.TestData.User_Suitable().Assignments { + if ra != nil && ra.IsActive() { + s.Contains(groups, *ra.Role.Name) + } + } + } + }) + s.Run("openid email profile groups", func() { + userInfo := &oidc.UserInfo{} + s.NoError(s.storage.setUserinfo(userInfo, s.TestData.User_Suitable().ID, []string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProfile, groupsClaim})) + s.Equal(utils.UintToString(s.TestData.User_Suitable().ID), userInfo.Subject) + s.Equal(s.TestData.User_Suitable().Email, userInfo.Email) + s.True(bool(userInfo.EmailVerified)) + s.Equal(utils.PointerTo(s.TestData.User_Suitable()).NameOrUsername(), userInfo.Name) + s.Equal(utils.PointerTo(s.TestData.User_Suitable()).NameOrUsername(), userInfo.Nickname) + s.Equal(utils.PointerTo(s.TestData.User_Suitable()).AlphaNumericHyphenatedUsername(), userInfo.PreferredUsername) + s.Equal(oidc.NewLocale(language.AmericanEnglish), userInfo.Locale) + s.Equal(oidc.FromTime(s.TestData.User_Suitable().UpdatedAt), userInfo.UpdatedAt) + if s.TestData.User_Suitable().Name != nil { + nameParts := strings.Split(*s.TestData.User_Suitable().Name, " ") + s.Equal(nameParts[0], userInfo.GivenName) + s.Equal(nameParts[len(nameParts)-1], userInfo.FamilyName) + } else { + s.Empty(userInfo.GivenName) + s.Empty(userInfo.FamilyName) + } + s.Equal("https://broad.io/beehive/r/user/"+s.TestData.User_Suitable().Email, userInfo.Website) + if s.NotEmpty(userInfo.Claims) && s.Contains(userInfo.Claims, groupsClaim) && s.IsType([]string{}, userInfo.Claims[groupsClaim]) { + groups := userInfo.Claims[groupsClaim].([]string) + for _, ra := range s.TestData.User_Suitable().Assignments { + if ra != nil && ra.IsActive() { + s.Contains(groups, *ra.Role.Name) + } + } + } + }) + s.Run("groups only active role assignments", func() { + suspendedRole := models.Role{ + RoleFields: models.RoleFields{ + Name: utils.PointerTo("test-role"), + }, + } + s.SetSelfSuperAdminForDB() + s.NoError(s.DB.Create(&suspendedRole).Error) + suspendedRoleAssignment := models.RoleAssignment{ + UserID: s.TestData.User_Suitable().ID, + RoleID: suspendedRole.ID, + RoleAssignmentFields: models.RoleAssignmentFields{ + Suspended: utils.PointerTo(true), + }, + } + s.NoError(s.DB.Create(&suspendedRoleAssignment).Error) + + // reload user to get updated assignments + var user models.User + s.NoError(s.DB.Scopes(models.ReadUserScope).First(&user, s.TestData.User_Suitable().ID).Error) + + // make this test extra explicit -- set these bools to true at key points + var suspendedRoleAssignmentOnUser, suspendedRoleAssignmentOmittedFromUserinfo bool + + userInfo := &oidc.UserInfo{} + s.NoError(s.storage.setUserinfo(userInfo, user.ID, []string{groupsClaim})) + if s.NotEmpty(userInfo.Claims) && s.Contains(userInfo.Claims, groupsClaim) && s.IsType([]string{}, userInfo.Claims[groupsClaim]) { + groups := userInfo.Claims[groupsClaim].([]string) + for _, ra := range user.Assignments { + if ra != nil { + if *ra.Role.Name == *suspendedRole.Name { + suspendedRoleAssignmentOnUser = true + } + + if ra.IsActive() { + s.Contains(groups, *ra.Role.Name) + } else if s.NotContains(groups, *ra.Role.Name) && *ra.Role.Name == *suspendedRole.Name { + suspendedRoleAssignmentOmittedFromUserinfo = true + } + } + } + } + + s.True(suspendedRoleAssignmentOnUser) + s.True(suspendedRoleAssignmentOmittedFromUserinfo) + }) +} diff --git a/sherlock/internal/oidc_models/test_helper.go b/sherlock/internal/oidc_models/test_helper.go new file mode 100644 index 000000000..279c58ecd --- /dev/null +++ b/sherlock/internal/oidc_models/test_helper.go @@ -0,0 +1,64 @@ +package oidc_models + +import ( + "crypto/rand" + "crypto/sha512" + "encoding/hex" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "golang.org/x/crypto/pbkdf2" + "gorm.io/gorm" + "time" +) + +type TestClientHelper struct{} + +func (h TestClientHelper) GeneratedClientRedirectURI() string { + return "http://localhost:8080/test/fake/redirect" +} + +func (h TestClientHelper) GeneratedClientPostLogoutRedirectURI() string { + return "http://localhost:8080/test/fake/postlogout" +} + +func (h TestClientHelper) GenerateClient(db *gorm.DB) (clientID string, clientSecret string, err error) { + clientIDBytes := make([]byte, 16) + _, err = rand.Read(clientIDBytes) + if err != nil { + return "", "", err + } + clientID = hex.EncodeToString(clientIDBytes) + + clientSecretBytes := make([]byte, 32) + _, err = rand.Read(clientSecretBytes) + if err != nil { + return "", "", err + } + clientSecret = hex.EncodeToString(clientSecretBytes) + + clientSecretSalt := make([]byte, 32) + _, err = rand.Read(clientSecretSalt) + if err != nil { + return "", "", err + } + + testClientHashIterations := 1_000 // Low low low value!!! Needs to be 210_000+ in production, this just makes tests run faster + clientSecretHash := pbkdf2.Key([]byte(clientSecret), clientSecretSalt, testClientHashIterations, 32, sha512.New) + + client := Client{ + ID: clientID, + ClientSecretHash: clientSecretHash, + ClientSecretSalt: clientSecretSalt, + ClientSecretIterations: testClientHashIterations, + ClientRedirectURIs: oidc.SpaceDelimitedArray{h.GeneratedClientRedirectURI()}, + ClientPostLogoutRedirectURIs: oidc.SpaceDelimitedArray{h.GeneratedClientPostLogoutRedirectURI()}, + ClientApplicationType: op.ApplicationTypeWeb, + ClientAuthMethod: oidc.AuthMethodBasic, + ClientIDTokenLifetime: (15 * time.Minute).Nanoseconds(), + ClientDevMode: true, + ClientClockSkew: (15 * time.Second).Nanoseconds(), + } + + err = db.Create(&client).Error + return clientID, clientSecret, err +} diff --git a/sherlock/internal/oidc_models/token.go b/sherlock/internal/oidc_models/token.go new file mode 100644 index 000000000..fa2569ff3 --- /dev/null +++ b/sherlock/internal/oidc_models/token.go @@ -0,0 +1,40 @@ +package oidc_models + +import ( + "github.com/broadinstitute/sherlock/go-shared/pkg/utils" + "github.com/broadinstitute/sherlock/sherlock/internal/models" + "github.com/google/uuid" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "time" +) + +var _ op.TokenRequest = &Token{} + +type Token struct { + ID uuid.UUID `gorm:"primaryKey"` + CreatedAt time.Time + + RefreshToken *RefreshToken `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` + RefreshTokenID *uuid.UUID + + Client *Client `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` + ClientID string // AKA Audience, Application ID + Scopes oidc.SpaceDelimitedArray + Expiry time.Time + + User *models.User + UserID uint +} + +func (t *Token) GetSubject() string { + return utils.UintToString(t.UserID) +} + +func (t *Token) GetAudience() []string { + return []string{t.ClientID} +} + +func (t *Token) GetScopes() []string { + return t.Scopes +} diff --git a/sherlock/internal/suitability_synchronization/suspend_role_assignments.go b/sherlock/internal/suitability_synchronization/suspend_role_assignments.go index 20721701b..d0f308f8d 100644 --- a/sherlock/internal/suitability_synchronization/suspend_role_assignments.go +++ b/sherlock/internal/suitability_synchronization/suspend_role_assignments.go @@ -73,9 +73,9 @@ func suspendRoleAssignments(ctx context.Context, db *gorm.DB) error { Suspended: utils.PointerTo(false), }, }).Error; err != nil { - errors = append(errors, fmt.Errorf("failed to un-suspend %s's assignment for %s: %w", assignment.User.NameOrEmailHandle(), *role.Name, err)) + errors = append(errors, fmt.Errorf("failed to un-suspend %s's assignment for %s: %w", assignment.User.NameOrUsername(), *role.Name, err)) } else { - summaries = append(summaries, fmt.Sprintf("un-suspended %s's assignment for %s", assignment.User.NameOrEmailHandle(), *role.Name)) + summaries = append(summaries, fmt.Sprintf("un-suspended %s's assignment for %s", assignment.User.NameOrUsername(), *role.Name)) } } else if !suitable && (assignment.Suspended == nil || !*assignment.Suspended) { roleIDsRequiringPropagation[roleID] = struct{}{} @@ -84,9 +84,9 @@ func suspendRoleAssignments(ctx context.Context, db *gorm.DB) error { Suspended: utils.PointerTo(true), }, }).Error; err != nil { - errors = append(errors, fmt.Errorf("failed to suspend %s's assignment for %s: %w", assignment.User.NameOrEmailHandle(), *role.Name, err)) + errors = append(errors, fmt.Errorf("failed to suspend %s's assignment for %s: %w", assignment.User.NameOrUsername(), *role.Name, err)) } else { - summaries = append(summaries, fmt.Sprintf("suspended %s's assignment for %s", assignment.User.NameOrEmailHandle(), *role.Name)) + summaries = append(summaries, fmt.Sprintf("suspended %s's assignment for %s", assignment.User.NameOrUsername(), *role.Name)) } } }