Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Authorize Generic user for Streamlit #3672

Merged
merged 5 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion vro-streamlit/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ testpaths = [
# Environment variables to use in pytests
env = [
"ENV=test-environment",
"DEBUG=True"
"GITHUB_CLIENT_ID=github-client-id",
"GITHUB_CLIENT_SECRET=github-client-secret",
]

[tool.coverage.run]
Expand Down
18 changes: 10 additions & 8 deletions vro-streamlit/src/dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
isort>=5.13
mypy>=1.11
pre-commit>=4.0
pytest>=7.4
pytest-cov>=5.0
pytest-env>=1.0
pytest-mock>=3.11
ruff>=0.6
isort==5.13.*
mypy==1.11.*
pre-commit==4.0.*
pytest==7.4.*
pytest-cov==5.0.*
pytest-env==1.0.*
pytest-mock==3.11.*
requests-mock==1.10.*
ruff==0.6.*
types-requests==2.32.*
4 changes: 3 additions & 1 deletion vro-streamlit/src/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pandas==2.2.*
starlette>=0.40.0
requests==2.32.*
starlette==0.40.*
streamlit==1.39.*
validators==0.34.*
6 changes: 6 additions & 0 deletions vro-streamlit/src/vro_streamlit/auth/auth_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class ForbiddenException(Exception):
pass


class UnauthorizedException(Exception):
pass
64 changes: 64 additions & 0 deletions vro-streamlit/src/vro_streamlit/auth/auth_frontend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import streamlit as st

import vro_streamlit.auth.auth_service as auth_service
from vro_streamlit.auth.response_models import DeviceFlowResponse

AUTH_LOG_IN_BUTTON = 'auth_log_in_button'
AUTH_LOG_OUT_BUTTON = 'auth_log_out_button'


def log_out() -> None:
if st.button('Log Out', key=AUTH_LOG_OUT_BUTTON):
if st.session_state.user and auth_service.log_out(st.session_state.user.access_token):
st.success('Logged out successfully.')

st.session_state.user = None
st.rerun()


def initiate_login_flow() -> DeviceFlowResponse:
"""Initiate the login flow and return device flow data."""
return auth_service.initiate_device_flow()


def display_authorization_instructions(device_flow_data: DeviceFlowResponse) -> None:
"""Display instructions for the user to authorize the application."""
st.write('### Authorization Steps')
st.write("1. Visit [GitHub's verification page]({}).".format(device_flow_data.verification_uri))
st.write('2. Enter the code:')
st.code(device_flow_data.user_code, language='markdown')
st.write('3. Complete the authorization within **{} minutes**.'.format(device_flow_data.expires_in // 60))


def complete_login_flow(device_flow_data: DeviceFlowResponse) -> None:
"""Complete the login flow by polling for token and fetching user info."""
try:
access_token = auth_service.poll_for_token(device_flow_data.device_code, device_flow_data.interval)
# Step 3: Get Validated User
st.session_state.user = auth_service.get_validated_user(access_token)
st.session_state.login_success = True
except Exception as e:
st.session_state.user = None
st.session_state.login_failed = f'Login failed: ({e.__class__.__name__}) {e}'


def log_in() -> None:
if st.button('Log In with GitHub', key=AUTH_LOG_IN_BUTTON, help='Log in to access more features.'):
device_flow_data = initiate_login_flow()
display_authorization_instructions(device_flow_data)
# Step 2: Poll for Access Token
with st.spinner('Waiting for authorization...'):
complete_login_flow(device_flow_data)

st.rerun()


def show() -> None:
if not st.session_state.user:
log_in()
else:
log_out()


if __name__ == '__main__':
show()
114 changes: 110 additions & 4 deletions vro-streamlit/src/vro_streamlit/auth/auth_service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,115 @@
import logging
import time
from typing import Any

import requests

from vro_streamlit.auth.auth_exception import ForbiddenException, UnauthorizedException
from vro_streamlit.auth.response_models import DeviceFlowResponse, UserInfoResponse
from vro_streamlit.auth.user import User
from vro_streamlit.config import GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET

DEVICE_CODE_URL = 'https://github.com/login/device/code'
TOKEN_URL = 'https://github.com/login/oauth/access_token'
USER_URL = 'https://api.github.com/user'
USER_ORG_URL = 'https://api.github.com/user/orgs'
ORG = 'department-of-veterans-affairs'
dfitchett marked this conversation as resolved.
Show resolved Hide resolved


def initiate_device_flow() -> DeviceFlowResponse:
"""Initiate the OAuth Device Flow and return a DeviceFlowResponse object."""
try:
response = requests.post(
DEVICE_CODE_URL,
data={'client_id': GITHUB_CLIENT_ID, 'scope': 'read:user,read:org'},
headers={'Accept': 'application/json'},
)
response.raise_for_status()
data = response.json()
return DeviceFlowResponse(
data['device_code'],
data['user_code'],
data['verification_uri'],
int(data['expires_in']),
int(data['interval']),
)
except requests.RequestException as e:
logging.error(f'Failed to initiate device flow: {e}')
raise


def poll_for_token(device_code: str, interval: int) -> str:
"""Poll GitHub for an access token until the user authorizes the app."""
while True:
response = requests.post(
TOKEN_URL,
data={
'client_id': GITHUB_CLIENT_ID,
'device_code': device_code,
'grant_type': 'urn:ietf:params:oauth:grant-type:device_code',
},
headers={'Accept': 'application/json'},
)

data = response.json()
if 'access_token' in data:
return str(data['access_token'])
elif data.get('error') == 'authorization_pending' or data.get('error') == 'slow_down':
interval = data.get('interval', interval)
time.sleep(interval)
else:
logging.error(f"Failed to fetch access token info: {data.get('error_description', 'Unknown error')}")
raise UnauthorizedException(f"Error while requesting access token: {data.get('error_description', 'Unknown error')}")


def fetch_user_info(access_token: str) -> UserInfoResponse:
"""Fetch user info using the access token."""
try:
response = requests.get(USER_URL, headers={'Authorization': f'token {access_token}'})
response.raise_for_status()
data = response.json()
return UserInfoResponse(**data)
except requests.RequestException as e:
logging.error(f'Failed to fetch user info: {e}')
raise


def fetch_org_info(access_token: str) -> Any:
"""Fetch user's organization info using the access token."""
try:
response = requests.get(
USER_ORG_URL,
headers={'Authorization': f'token {access_token}'},
)
response.raise_for_status()
return response.json()
except requests.RequestException as e:
logging.error(f'Failed to fetch user organization info: {e}')
raise


def get_validated_user(access_token: str) -> User:
org_info = fetch_org_info(access_token)
orgs = [org['login'] for org in org_info]
if ORG not in orgs:
raise ForbiddenException(f'You must be an authorized member of the {ORG} organization to use this app!')

def log_in() -> User:
return User('test')
user_info = fetch_user_info(access_token)
return User(access_token, user_info.login, user_info.avatar_url)


def log_out() -> bool:
return True
def log_out(access_token: str) -> bool:
"""Revoke the access token."""
REVOKE_URL = f'https://api.github.com/applications/{GITHUB_CLIENT_ID}/token'
try:
response = requests.delete(
REVOKE_URL,
auth=(GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET), # type: ignore[arg-type]
json={'access_token': access_token},
headers={'Accept': 'application/vnd.github+json'},
)
response.raise_for_status()
return response.status_code == 204
except requests.RequestException as e:
logging.error(f'Failed to revoke token: {e}')
return False
18 changes: 18 additions & 0 deletions vro-streamlit/src/vro_streamlit/auth/response_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Any


class DeviceFlowResponse:
def __init__(self, device_code: str, user_code: str, verification_uri: str, expires_in: int, interval: int):
self.device_code: str = device_code
self.user_code: str = user_code
self.verification_uri: str = verification_uri
self.expires_in: int = expires_in
self.interval: int = interval


class UserInfoResponse:
def __init__(self, login: str, avatar_url: str, **kwargs: Any):
self.login = login
self.avatar_url = avatar_url
for key, value in kwargs.items():
setattr(self, key, value)
6 changes: 4 additions & 2 deletions vro-streamlit/src/vro_streamlit/auth/user.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
class User:
def __init__(self, username: str):
self.username: str = username
def __init__(self, access_token: str, username: str, avatar_url: str | None = None) -> None:
self.access_token = access_token
self.username = username
self.avatar_url = avatar_url
5 changes: 5 additions & 0 deletions vro-streamlit/src/vro_streamlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@

ENV = getenv('ENV', 'local')
DEBUG = bool(strtobool(getenv('DEBUG', 'False')))

GITHUB_CLIENT_ID = getenv('GITHUB_CLIENT_ID')
GITHUB_CLIENT_SECRET = getenv('GITHUB_CLIENT_SECRET')
if not GITHUB_CLIENT_ID or not GITHUB_CLIENT_SECRET:
raise EnvironmentError('GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET must be set as environment variables.')
22 changes: 8 additions & 14 deletions vro-streamlit/src/vro_streamlit/directory/home.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,28 @@
from importlib.resources import files

import streamlit as st
import validators

import vro_streamlit.auth.auth_service as auth
from vro_streamlit.auth import auth_frontend

LOGIN_BUTTON = 'home_login_button'
LOGO = files('vro_streamlit').joinpath('static/streamlit-logo.png').read_bytes()


def update_login_status() -> None:
if not st.session_state.user:
st.session_state.user = auth.log_in()
else:
if auth.log_out():
st.session_state.user = None


def show() -> None:
col1, col2 = st.columns([0.04, 0.96])
col1.image(LOGO, width=100)
col2.header('Home')
st.subheader('Welcome to the home page!')

user = st.session_state.get('user')
user = st.session_state.user
if user:
st.write(f'Hello, {user.username}!')
st.button('Log Out', key=LOGIN_BUTTON, on_click=update_login_status)
if validators.url(user.avatar_url):
st.image(user.avatar_url, width=50)
st.write(f'Hello, **{user.username}**!')
else:
st.write('Please Log In')
st.button('Log In', key=LOGIN_BUTTON, on_click=update_login_status)

auth_frontend.show()


if __name__ == '__main__':
Expand Down
Loading