Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF] Manually verify ID token using PyJWT instead of google_auth #386

Merged
merged 6 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion app/api/routers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
router = APIRouter(prefix="/query", tags=["query"])

# Adapted from info in https://github.com/tiangolo/fastapi/discussions/9137#discussioncomment-5157382
# I believe for us this is purely for documentatation/a nice looking interactive API docs page,
# and doesn't actually have any bearing on the ID token validation process.
oauth2_scheme = OAuth2(
flows={
"implicit": {
"authorizationUrl": "https://accounts.google.com/o/oauth2/auth",
"authorizationUrl": "https://neurobagel.ca.auth0.com/authorize",
}
},
# Don't automatically error out when request is not authenticated, to support optional authentication
Expand Down
43 changes: 28 additions & 15 deletions app/api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,55 @@

import os

import jwt
from fastapi import HTTPException, status
from fastapi.security.utils import get_authorization_scheme_param
from google.auth.exceptions import GoogleAuthError
from google.auth.transport import requests
from google.oauth2 import id_token
from jwt import PyJWKClient, PyJWTError

AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true"
CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None)

KEYS_URL = "https://neurobagel.ca.auth0.com/.well-known/jwks.json"
ISSUER = "https://neurobagel.ca.auth0.com/"
# We only need to define the JWKS client once because get_signing_key_from_jwt will handle key rotations
# by automatically fetching updated keys when needed
# See https://github.com/jpadilla/pyjwt/blob/3ebbb22f30f2b1b41727b269a08b427e9a85d6bb/jwt/jwks_client.py#L96-L115
JWKS_CLIENT = PyJWKClient(KEYS_URL)


def check_client_id():
"""Check if the CLIENT_ID environment variable is set."""
# By default, if CLIENT_ID is not provided to verify_oauth2_token,
# Google will simply skip verifying the audience claim of ID tokens.
# This however can be a security risk, so we mandate that CLIENT_ID is set.
# The CLIENT_ID is needed to verify the audience claim of ID tokens.
if AUTH_ENABLED and CLIENT_ID is None:
raise ValueError(
"Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. "
"Please set NB_QUERY_CLIENT_ID to the Google client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens."
"Please set NB_QUERY_CLIENT_ID to the client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens."
)


def verify_token(token: str):
"""Verify the Google ID token. Raise an HTTPException if the token is invalid."""
# Adapted from https://developers.google.com/identity/gsi/web/guides/verify-google-id-token#python
"""Verify the provided ID token. Raise an HTTPException if the token is invalid."""
try:
# Extract the token from the "Bearer" scheme
# (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485)
# TODO: Check also if scheme of token is "Bearer"?
alyssadai marked this conversation as resolved.
Show resolved Hide resolved
_, param = get_authorization_scheme_param(token)
id_info = id_token.verify_oauth2_token(
param, requests.Request(), CLIENT_ID
_, extracted_token = get_authorization_scheme_param(token)

# Determine which key was used to sign the token
# Adapted from https://pyjwt.readthedocs.io/en/stable/usage.html#retrieve-rsa-signing-keys-from-a-jwks-endpoint
signing_key = JWKS_CLIENT.get_signing_key_from_jwt(extracted_token)

jwt.decode(

Check warning on line 43 in app/api/security.py

View check run for this annotation

Codecov / codecov/patch

app/api/security.py#L43

Added line #L43 was not covered by tests
jwt=extracted_token,
key=signing_key,
options={
"verify_signature": True,
"require": ["aud", "iss", "exp", "iat"],
},
audience=CLIENT_ID,
issuer=ISSUER,
)
# TODO: Remove print statement or turn into logging
print("Token verified: ", id_info)
except (GoogleAuthError, ValueError) as exc:
except (PyJWTError, ValueError) as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Invalid token: {exc}",
Expand Down
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ anyio==3.6.2
attrs==22.1.0
cachetools==5.3.3
certifi==2024.7.4
cffi==1.17.1
cfgv==3.3.1
charset-normalizer==3.3.2
click==8.1.3
colorama==0.4.6
coverage==7.0.0
cryptography==44.0.0
distlib==0.3.6
exceptiongroup==1.0.4
fastapi==0.115.4
filelock==3.8.0
google-auth==2.32.0
h11==0.14.0
httpcore==0.16.2
httpx==0.23.1
Expand All @@ -28,7 +29,9 @@ pluggy==1.0.0
pre-commit==3.6.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycparser==2.22
pydantic==1.10.13
PyJWT==2.10.1
pyparsing==3.0.9
pytest==7.2.0
python-dateutil==2.8.2
Expand Down
Loading