diff --git a/vro-streamlit/pyproject.toml b/vro-streamlit/pyproject.toml index 7687007635..ca982aff67 100644 --- a/vro-streamlit/pyproject.toml +++ b/vro-streamlit/pyproject.toml @@ -23,7 +23,8 @@ testpaths = [ # Environment variables to use in pytests env = [ "ENV=test-environment", - "DEBUG=True" + "GITHUB_CLIENT_ID=github-client-id", + "GITHUB_CLIENT_SECRET=github-client-secret", ] [tool.coverage.run] diff --git a/vro-streamlit/src/dev-requirements.txt b/vro-streamlit/src/dev-requirements.txt index f3b019a268..80ee7f7ab4 100644 --- a/vro-streamlit/src/dev-requirements.txt +++ b/vro-streamlit/src/dev-requirements.txt @@ -1,8 +1,10 @@ -isort>=5.13 -mypy>=1.11 -pre-commit>=4.0 -pytest>=7.4 -pytest-cov>=5.0 -pytest-env>=1.0 -pytest-mock>=3.11 -ruff>=0.6 +isort==5.13.* +mypy==1.11.* +pre-commit==4.0.* +pytest==7.4.* +pytest-cov==5.0.* +pytest-env==1.0.* +pytest-mock==3.11.* +requests-mock==1.10.* +ruff==0.6.* +types-requests==2.32.* diff --git a/vro-streamlit/src/requirements.txt b/vro-streamlit/src/requirements.txt index 54451c3ee0..3d37e6ce6b 100644 --- a/vro-streamlit/src/requirements.txt +++ b/vro-streamlit/src/requirements.txt @@ -1,3 +1,5 @@ pandas==2.2.* -starlette>=0.40.0 +requests==2.32.* +starlette==0.40.* streamlit==1.39.* +validators==0.34.* diff --git a/vro-streamlit/src/vro_streamlit/auth/auth_exception.py b/vro-streamlit/src/vro_streamlit/auth/auth_exception.py new file mode 100644 index 0000000000..9de6558d47 --- /dev/null +++ b/vro-streamlit/src/vro_streamlit/auth/auth_exception.py @@ -0,0 +1,6 @@ +class ForbiddenException(Exception): + pass + + +class UnauthorizedException(Exception): + pass diff --git a/vro-streamlit/src/vro_streamlit/auth/auth_frontend.py b/vro-streamlit/src/vro_streamlit/auth/auth_frontend.py new file mode 100644 index 0000000000..8677d3e6bb --- /dev/null +++ b/vro-streamlit/src/vro_streamlit/auth/auth_frontend.py @@ -0,0 +1,64 @@ +import streamlit as st + +import vro_streamlit.auth.auth_service as auth_service +from vro_streamlit.auth.response_models import DeviceFlowResponse + +AUTH_LOG_IN_BUTTON = 'auth_log_in_button' +AUTH_LOG_OUT_BUTTON = 'auth_log_out_button' + + +def log_out() -> None: + if st.button('Log Out', key=AUTH_LOG_OUT_BUTTON): + if st.session_state.user and auth_service.log_out(st.session_state.user.access_token): + st.success('Logged out successfully.') + + st.session_state.user = None + st.rerun() + + +def initiate_login_flow() -> DeviceFlowResponse: + """Initiate the login flow and return device flow data.""" + return auth_service.initiate_device_flow() + + +def display_authorization_instructions(device_flow_data: DeviceFlowResponse) -> None: + """Display instructions for the user to authorize the application.""" + st.write('### Authorization Steps') + st.write("1. Visit [GitHub's verification page]({}).".format(device_flow_data.verification_uri)) + st.write('2. Enter the code:') + st.code(device_flow_data.user_code, language='markdown') + st.write('3. Complete the authorization within **{} minutes**.'.format(device_flow_data.expires_in // 60)) + + +def complete_login_flow(device_flow_data: DeviceFlowResponse) -> None: + """Complete the login flow by polling for token and fetching user info.""" + try: + access_token = auth_service.poll_for_token(device_flow_data.device_code, device_flow_data.interval) + # Step 3: Get Validated User + st.session_state.user = auth_service.get_validated_user(access_token) + st.session_state.login_success = True + except Exception as e: + st.session_state.user = None + st.session_state.login_failed = f'Login failed: ({e.__class__.__name__}) {e}' + + +def log_in() -> None: + if st.button('Log In with GitHub', key=AUTH_LOG_IN_BUTTON, help='Log in to access more features.'): + device_flow_data = initiate_login_flow() + display_authorization_instructions(device_flow_data) + # Step 2: Poll for Access Token + with st.spinner('Waiting for authorization...'): + complete_login_flow(device_flow_data) + + st.rerun() + + +def show() -> None: + if not st.session_state.user: + log_in() + else: + log_out() + + +if __name__ == '__main__': + show() diff --git a/vro-streamlit/src/vro_streamlit/auth/auth_service.py b/vro-streamlit/src/vro_streamlit/auth/auth_service.py index eba1817fea..ba57dc691a 100644 --- a/vro-streamlit/src/vro_streamlit/auth/auth_service.py +++ b/vro-streamlit/src/vro_streamlit/auth/auth_service.py @@ -1,9 +1,115 @@ +import logging +import time +from typing import Any + +import requests + +from vro_streamlit.auth.auth_exception import ForbiddenException, UnauthorizedException +from vro_streamlit.auth.response_models import DeviceFlowResponse, UserInfoResponse from vro_streamlit.auth.user import User +from vro_streamlit.config import GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET + +DEVICE_CODE_URL = 'https://github.com/login/device/code' +TOKEN_URL = 'https://github.com/login/oauth/access_token' +USER_URL = 'https://api.github.com/user' +USER_ORG_URL = 'https://api.github.com/user/orgs' +ORG = 'department-of-veterans-affairs' + + +def initiate_device_flow() -> DeviceFlowResponse: + """Initiate the OAuth Device Flow and return a DeviceFlowResponse object.""" + try: + response = requests.post( + DEVICE_CODE_URL, + data={'client_id': GITHUB_CLIENT_ID, 'scope': 'read:user,read:org'}, + headers={'Accept': 'application/json'}, + ) + response.raise_for_status() + data = response.json() + return DeviceFlowResponse( + data['device_code'], + data['user_code'], + data['verification_uri'], + int(data['expires_in']), + int(data['interval']), + ) + except requests.RequestException as e: + logging.error(f'Failed to initiate device flow: {e}') + raise + + +def poll_for_token(device_code: str, interval: int) -> str: + """Poll GitHub for an access token until the user authorizes the app.""" + while True: + response = requests.post( + TOKEN_URL, + data={ + 'client_id': GITHUB_CLIENT_ID, + 'device_code': device_code, + 'grant_type': 'urn:ietf:params:oauth:grant-type:device_code', + }, + headers={'Accept': 'application/json'}, + ) + + data = response.json() + if 'access_token' in data: + return str(data['access_token']) + elif data.get('error') == 'authorization_pending' or data.get('error') == 'slow_down': + interval = data.get('interval', interval) + time.sleep(interval) + else: + logging.error(f"Failed to fetch access token info: {data.get('error_description', 'Unknown error')}") + raise UnauthorizedException(f"Error while requesting access token: {data.get('error_description', 'Unknown error')}") + + +def fetch_user_info(access_token: str) -> UserInfoResponse: + """Fetch user info using the access token.""" + try: + response = requests.get(USER_URL, headers={'Authorization': f'token {access_token}'}) + response.raise_for_status() + data = response.json() + return UserInfoResponse(**data) + except requests.RequestException as e: + logging.error(f'Failed to fetch user info: {e}') + raise + + +def fetch_org_info(access_token: str) -> Any: + """Fetch user's organization info using the access token.""" + try: + response = requests.get( + USER_ORG_URL, + headers={'Authorization': f'token {access_token}'}, + ) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + logging.error(f'Failed to fetch user organization info: {e}') + raise + +def get_validated_user(access_token: str) -> User: + org_info = fetch_org_info(access_token) + orgs = [org['login'] for org in org_info] + if ORG not in orgs: + raise ForbiddenException(f'You must be an authorized member of the {ORG} organization to use this app!') -def log_in() -> User: - return User('test') + user_info = fetch_user_info(access_token) + return User(access_token, user_info.login, user_info.avatar_url) -def log_out() -> bool: - return True +def log_out(access_token: str) -> bool: + """Revoke the access token.""" + REVOKE_URL = f'https://api.github.com/applications/{GITHUB_CLIENT_ID}/token' + try: + response = requests.delete( + REVOKE_URL, + auth=(GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET), # type: ignore[arg-type] + json={'access_token': access_token}, + headers={'Accept': 'application/vnd.github+json'}, + ) + response.raise_for_status() + return response.status_code == 204 + except requests.RequestException as e: + logging.error(f'Failed to revoke token: {e}') + return False diff --git a/vro-streamlit/src/vro_streamlit/auth/response_models.py b/vro-streamlit/src/vro_streamlit/auth/response_models.py new file mode 100644 index 0000000000..214076b891 --- /dev/null +++ b/vro-streamlit/src/vro_streamlit/auth/response_models.py @@ -0,0 +1,18 @@ +from typing import Any + + +class DeviceFlowResponse: + def __init__(self, device_code: str, user_code: str, verification_uri: str, expires_in: int, interval: int): + self.device_code: str = device_code + self.user_code: str = user_code + self.verification_uri: str = verification_uri + self.expires_in: int = expires_in + self.interval: int = interval + + +class UserInfoResponse: + def __init__(self, login: str, avatar_url: str, **kwargs: Any): + self.login = login + self.avatar_url = avatar_url + for key, value in kwargs.items(): + setattr(self, key, value) diff --git a/vro-streamlit/src/vro_streamlit/auth/user.py b/vro-streamlit/src/vro_streamlit/auth/user.py index c4d2b3799b..5cf34931ce 100644 --- a/vro-streamlit/src/vro_streamlit/auth/user.py +++ b/vro-streamlit/src/vro_streamlit/auth/user.py @@ -1,3 +1,5 @@ class User: - def __init__(self, username: str): - self.username: str = username + def __init__(self, access_token: str, username: str, avatar_url: str | None = None) -> None: + self.access_token = access_token + self.username = username + self.avatar_url = avatar_url diff --git a/vro-streamlit/src/vro_streamlit/config.py b/vro-streamlit/src/vro_streamlit/config.py index 15bebe37a5..0a5f4f2d6f 100644 --- a/vro-streamlit/src/vro_streamlit/config.py +++ b/vro-streamlit/src/vro_streamlit/config.py @@ -3,3 +3,8 @@ ENV = getenv('ENV', 'local') DEBUG = bool(strtobool(getenv('DEBUG', 'False'))) + +GITHUB_CLIENT_ID = getenv('GITHUB_CLIENT_ID') +GITHUB_CLIENT_SECRET = getenv('GITHUB_CLIENT_SECRET') +if not GITHUB_CLIENT_ID or not GITHUB_CLIENT_SECRET: + raise EnvironmentError('GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET must be set as environment variables.') diff --git a/vro-streamlit/src/vro_streamlit/directory/home.py b/vro-streamlit/src/vro_streamlit/directory/home.py index 10c800a005..b540249d19 100644 --- a/vro-streamlit/src/vro_streamlit/directory/home.py +++ b/vro-streamlit/src/vro_streamlit/directory/home.py @@ -1,34 +1,28 @@ from importlib.resources import files import streamlit as st +import validators -import vro_streamlit.auth.auth_service as auth +from vro_streamlit.auth import auth_frontend -LOGIN_BUTTON = 'home_login_button' LOGO = files('vro_streamlit').joinpath('static/streamlit-logo.png').read_bytes() -def update_login_status() -> None: - if not st.session_state.user: - st.session_state.user = auth.log_in() - else: - if auth.log_out(): - st.session_state.user = None - - def show() -> None: col1, col2 = st.columns([0.04, 0.96]) col1.image(LOGO, width=100) col2.header('Home') st.subheader('Welcome to the home page!') - user = st.session_state.get('user') + user = st.session_state.user if user: - st.write(f'Hello, {user.username}!') - st.button('Log Out', key=LOGIN_BUTTON, on_click=update_login_status) + if validators.url(user.avatar_url): + st.image(user.avatar_url, width=50) + st.write(f'Hello, **{user.username}**!') else: st.write('Please Log In') - st.button('Log In', key=LOGIN_BUTTON, on_click=update_login_status) + + auth_frontend.show() if __name__ == '__main__': diff --git a/vro-streamlit/src/vro_streamlit/main.py b/vro-streamlit/src/vro_streamlit/main.py index 108d4a2e5b..9afbed2143 100644 --- a/vro-streamlit/src/vro_streamlit/main.py +++ b/vro-streamlit/src/vro_streamlit/main.py @@ -1,48 +1,56 @@ +import logging + import streamlit as st import vro_streamlit.auth.auth_service as auth import vro_streamlit.config as config import vro_streamlit.directory.home as home +from vro_streamlit.config import DEBUG from vro_streamlit.directory.bie_events import claim_events, contention_events -LOGIN_BUTTON = 'sidebar_login_button' -LOGOUT_BUTTON = 'sidebar_logout_button' +LOG_OUT_BUTTON = 'sidebar_log_out_button' st.set_page_config(page_title='VRO Streamlit', layout='wide') +home_page = st.Page(home.show, title='Home', default=True) +# BIE events +bie_events = [ + st.Page(claim_events.show, title='Claim Events', url_path='/claim_events'), + st.Page(contention_events.show, title='Contention Events', url_path='/contention_events'), +] +# examples +examples = [ + st.Page('directory/examples/text.py', title='Text'), + st.Page('directory/examples/dataframes.py', title='Dataframes'), + st.Page('directory/examples/water_quality.py', title='Water Quality'), +] def init_session_state() -> None: st.session_state.setdefault('database_connected', True) + st.session_state.setdefault('logged_in', False) st.session_state.setdefault('user', None) -def update_login_status() -> None: - if not st.session_state.user: - st.session_state.user = auth.log_in() +def create_navigation() -> None: + if st.session_state.user: + nav = st.navigation({'Main': [home_page], 'BIE Events': bie_events, 'Examples': examples}) else: - if auth.log_out(): - st.session_state.user = None + nav = st.navigation({'Main': [home_page], 'Examples': examples}) + nav.run() -def create_navigation() -> None: - home_page = st.Page(home.show, title='Home', default=True) - # BIE events - bie_events = [ - st.Page(claim_events.show, title='Claim Events', url_path='/claim_events'), - st.Page(contention_events.show, title='Contention Events', url_path='/contention_events'), - ] - # examples - examples = [ - st.Page('directory/examples/text.py', title='Text'), - st.Page('directory/examples/dataframes.py', title='Dataframes'), - st.Page('directory/examples/water_quality.py', title='Water Quality'), - ] - nav = st.navigation({'Main': [home_page], 'BIE Events': bie_events, 'Examples': examples}) - nav.run() +def log_out() -> None: + try: + auth.log_out(st.session_state.user.access_token) + st.success('Logged out successfully.') + except Exception as e: + logging.error(f'Failed to revoke token, but logged out anyways: {e}') + st.session_state.user = None def create_sidebar() -> None: with st.sidebar: + user = st.session_state.user with st.container(border=True): col1, col2 = st.columns(2) with col1: @@ -51,15 +59,36 @@ def create_sidebar() -> None: st.markdown('Authorized', help='User authorization status') with col2: st.markdown(f'`{config.ENV}`') - st.markdown(':large_green_circle:' if st.session_state.database_connected else ':red_circle:', unsafe_allow_html=True) - st.markdown(':large_green_circle:' if st.session_state.user else ':red_circle:', unsafe_allow_html=True) + st.markdown(':large_green_circle:' if st.session_state.database_connected else ':red_circle:') + st.markdown(':large_green_circle:' if user is not None else ':red_circle:') + + if user is not None: + st.button('Log Out', use_container_width=True, on_click=log_out, key=LOG_OUT_BUTTON) + + +def show_login_status() -> None: + if 'login_success' in st.session_state: + st.success('Login successful!') + del st.session_state['login_success'] + + if 'login_failed' in st.session_state: + st.error(st.session_state.login_failed) + del st.session_state['login_failed'] + - button_text = 'Log Out' if st.session_state.user else 'Log In' - button_key = LOGOUT_BUTTON if st.session_state.user else LOGIN_BUTTON - st.button(button_text, use_container_width=True, on_click=update_login_status, key=button_key) +def show_debug() -> None: # pragma: no cover + if DEBUG: + user = st.session_state.user + with st.sidebar.container(border=True): + st.write('Session State') + ss_dict = st.session_state.to_dict() + ss_dict['user'] = user.__dict__ if user else None + st.json(ss_dict) if __name__ == '__main__': init_session_state() create_sidebar() create_navigation() + show_login_status() + show_debug() diff --git a/vro-streamlit/src/vro_streamlit/util/__init__.py b/vro-streamlit/src/vro_streamlit/util/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vro-streamlit/test/auth/__init__.py b/vro-streamlit/test/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vro-streamlit/test/auth/test_auth_frontend.py b/vro-streamlit/test/auth/test_auth_frontend.py new file mode 100644 index 0000000000..42df58a62b --- /dev/null +++ b/vro-streamlit/test/auth/test_auth_frontend.py @@ -0,0 +1,55 @@ +from test.conftest import ACCESS_TOKEN, APP_TEST_TIMEOUT, AVATAR_URL, USERNAME +from test.util import assert_markdown_contains_values + +import pytest +from streamlit.testing.v1 import AppTest + +from vro_streamlit.auth import auth_frontend +from vro_streamlit.auth.response_models import DeviceFlowResponse, UserInfoResponse +from vro_streamlit.auth.user import User + + +@pytest.fixture() +def app_test(): + app_test = AppTest.from_file('src/vro_streamlit/auth/auth_frontend.py', default_timeout=APP_TEST_TIMEOUT) + app_test.session_state.user = None + return app_test + + +@pytest.fixture() +def auth_service(mocker): + auth_service = mocker.patch('vro_streamlit.auth.auth_service') + auth_service.initiate_device_flow.return_value = DeviceFlowResponse('device_code', 'user_code', 'verification_uri', 600, 5) + auth_service.poll_for_token.return_value = ACCESS_TOKEN + auth_service.fetch_user_info.return_value = UserInfoResponse(login=USERNAME, avatar_url=AVATAR_URL) + auth_service.get_validated_user.return_value = User(ACCESS_TOKEN, USERNAME, AVATAR_URL) + auth_service.log_out.return_value = True + return auth_service + + +def test_show_user_not_logged_in(app_test, auth_service) -> None: + app_test.run() + assert not app_test.exception + + # Click log in button + app_test.button(key=auth_frontend.AUTH_LOG_IN_BUTTON).click().run() + assert not app_test.exception + assert_markdown_contains_values(app_test.code, 'user_code') + + auth_service.initiate_device_flow.assert_called_once() + auth_service.poll_for_token.assert_called_once() + auth_service.get_validated_user.assert_called_once() + app_test.session_state.user = User(ACCESS_TOKEN, USERNAME, AVATAR_URL) + + +def test_show_user_logged_in(app_test, auth_service) -> None: + app_test.session_state.user = User(ACCESS_TOKEN, USERNAME, AVATAR_URL) + app_test.run() + assert not app_test.exception + + # Click log out button + app_test.button(key=auth_frontend.AUTH_LOG_OUT_BUTTON).click().run() + assert not app_test.exception + + auth_service.log_out.assert_called_once() + app_test.session_state.user = User(ACCESS_TOKEN, USERNAME, AVATAR_URL) diff --git a/vro-streamlit/test/auth/test_auth_service.py b/vro-streamlit/test/auth/test_auth_service.py new file mode 100644 index 0000000000..b56d07fed4 --- /dev/null +++ b/vro-streamlit/test/auth/test_auth_service.py @@ -0,0 +1,118 @@ +from test.conftest import ACCESS_TOKEN, AVATAR_URL, USERNAME +from unittest.mock import patch + +import pytest +import requests + +from vro_streamlit.auth.auth_exception import UnauthorizedException +from vro_streamlit.auth.auth_service import ( + ORG, + fetch_org_info, + fetch_user_info, + initiate_device_flow, + log_out, + poll_for_token, +) + + +def test_initiate_device_flow_returns_correct_data(): + with patch('requests.post') as mock_post: + mock_post.return_value.json.return_value = { + 'device_code': 'device_code', + 'user_code': 'user_code', + 'verification_uri': 'verification_uri', + 'expires_in': 900, + 'interval': 5, + } + device_flow = initiate_device_flow() + assert device_flow.device_code == 'device_code' + assert device_flow.user_code == 'user_code' + assert device_flow.verification_uri == 'verification_uri' + assert device_flow.expires_in == 900 + assert device_flow.interval == 5 + + +def test_initiate_device_flow_raises_exception_on_failure(): + with patch('requests.post') as mock_post: + response = mock_post.return_value + response.status_code = 400 + response.raise_for_status.side_effect = requests.HTTPError + with pytest.raises(requests.HTTPError): + initiate_device_flow() + + +def test_poll_for_token_returns_access_token(): + with patch('requests.post') as mock_post: + mock_post.return_value.json.return_value = {'access_token': 'access_token'} + access_token = poll_for_token('device_code', 1) + assert access_token == 'access_token' + + +def test_poll_for_token_returns_authorization_pending(): + with patch('requests.post') as mock_post: + mock_post.return_value.json.side_effect = [{'error': 'authorization_pending'}, {'access_token': ACCESS_TOKEN}] + poll_for_token('device_code', 1) + assert mock_post.call_count == 2 + + +def test_poll_for_token_returns_slowdown(): + with patch('time.sleep') as mock_sleep, patch('requests.post') as mock_post: + mock_post.return_value.json.side_effect = [{'error': 'slow_down', 'interval': 2}, {'access_token': ACCESS_TOKEN}] + poll_for_token('device_code', 1) + assert mock_post.call_count == 2 + assert mock_sleep.call_count == 1 + assert mock_sleep.call_args[0][0] == 2 + + +def test_poll_for_token_raises_exception_on_error(): + with patch('requests.post') as mock_post: + mock_post.return_value.json.return_value = {'error': 'invalid_request', 'error_description': 'Invalid request'} + with pytest.raises(UnauthorizedException, match='Error while requesting access token: Invalid request'): + poll_for_token('device_code', 1) + + +def test_fetch_user_info_returns_data(): + with patch('requests.get') as mock_get: + mock_get.return_value.json.return_value = {'login': USERNAME, 'avatar_url': AVATAR_URL} + user_info = fetch_user_info('access_token') + assert user_info.login == USERNAME + assert user_info.avatar_url == AVATAR_URL + + +def test_fetch_user_info_raises_exception_on_failure(): + with patch('requests.get') as mock_get: + response = mock_get.return_value + response.status_code = 400 + response.raise_for_status.side_effect = requests.HTTPError + with pytest.raises(requests.HTTPError): + fetch_user_info('access_token') + + +def test_fetch_org_info_returns_data(): + with patch('requests.get') as mock_get: + mock_get.return_value.json.return_value = {'login': ORG} + org_info = fetch_org_info('access_token') + assert org_info['login'] == ORG + + +def test_fetch_org_info_raises_exception_on_failure(): + with patch('requests.get') as mock_get: + response = mock_get.return_value + response.status_code = 400 + response.raise_for_status.side_effect = requests.HTTPError + with pytest.raises(requests.HTTPError): + fetch_org_info('access_token') + + +def test_log_out_returns_true_on_success(): + with patch('requests.delete') as mock_delete: + mock_delete.return_value.status_code = 204 + assert log_out('access_token') is True + + +def test_log_out_raises_exception_on_failure(): + with patch('requests.delete') as mock_delete: + response = mock_delete.return_value + response.return_value.status_code = 400 + response.raise_for_status.side_effect = requests.HTTPError + assert log_out('access_token') is False diff --git a/vro-streamlit/test/conftest.py b/vro-streamlit/test/conftest.py index 550223ff9f..e9fda1bf26 100644 --- a/vro-streamlit/test/conftest.py +++ b/vro-streamlit/test/conftest.py @@ -1,18 +1,7 @@ -from unittest.mock import Mock - -import pytest - -from vro_streamlit.auth.user import User - """Pytest configuration. This file is automatically loaded by pytest before any test.""" + APP_TEST_TIMEOUT = 5 USERNAME = 'test' - - -@pytest.fixture(autouse=True) -def auth_service(mocker): - auth_service = Mock() - auth_service.log_in.return_value = User(USERNAME) - auth_service.log_out.return_value = True - return mocker.patch('vro_streamlit.auth.auth_service', auth_service) +AVATAR_URL = 'http://test.com/avatar.png' +ACCESS_TOKEN = 'test_access_token' diff --git a/vro-streamlit/test/directory/test_home.py b/vro-streamlit/test/directory/test_home.py index 792b50e877..6714643a8d 100644 --- a/vro-streamlit/test/directory/test_home.py +++ b/vro-streamlit/test/directory/test_home.py @@ -2,10 +2,15 @@ from streamlit.testing.v1 import AppTest from vro_streamlit.auth.user import User -from vro_streamlit.directory.home import LOGIN_BUTTON +from vro_streamlit.directory import home # noqa: F401 -from ..conftest import APP_TEST_TIMEOUT, USERNAME -from ..util import assert_button_contains_label, assert_markdown_contains_values +from ..conftest import ACCESS_TOKEN, APP_TEST_TIMEOUT, USERNAME +from ..util import assert_markdown_contains_values + + +@pytest.fixture() +def auth_frontend(mocker): + return mocker.patch('vro_streamlit.auth.auth_frontend') @pytest.fixture() @@ -15,35 +20,24 @@ def app_test(): return app_test -def test_home(app_test) -> None: +def test_home(app_test, auth_frontend) -> None: app_test.run() assert not app_test.exception assert app_test.header[0].value == 'Home' assert app_test.subheader[0].value == 'Welcome to the home page!' + auth_frontend.show.assert_called_once() -def test_home_user_is_none(app_test) -> None: - # Initial page load +def test_home_user_is_none(app_test, auth_frontend) -> None: app_test.run() assert not app_test.exception assert_markdown_contains_values(app_test.markdown, 'Please Log In') - assert_button_contains_label(app_test.button[0], 'Log In') + auth_frontend.show.assert_called_once() - # Click page which reloads page - app_test.button(key=LOGIN_BUTTON).click().run() - assert_markdown_contains_values(app_test.markdown, f'Hello, {USERNAME}!') - assert_button_contains_label(app_test.button[0], 'Log Out') - -def test_home_user_is_not_none(app_test) -> None: - # Initial page load - app_test.session_state.user = User(USERNAME) +def test_home_user_is_not_none(app_test, auth_frontend) -> None: + app_test.session_state.user = User(ACCESS_TOKEN, USERNAME) app_test.run() assert not app_test.exception - assert_markdown_contains_values(app_test.markdown, f'Hello, {USERNAME}!') - assert_button_contains_label(app_test.button[0], 'Log Out') - - # Click page which reloads page - app_test.button(key=LOGIN_BUTTON).click().run() - assert_markdown_contains_values(app_test.markdown, 'Please Log In') - assert_button_contains_label(app_test.button[0], 'Log In') + assert_markdown_contains_values(app_test.markdown, f'Hello, **{USERNAME}**!') + auth_frontend.show.assert_called_once() diff --git a/vro-streamlit/test/test_main.py b/vro-streamlit/test/test_main.py index 94b330018f..2bbfdb5e49 100644 --- a/vro-streamlit/test/test_main.py +++ b/vro-streamlit/test/test_main.py @@ -1,10 +1,12 @@ -from test.conftest import APP_TEST_TIMEOUT, USERNAME +from test.conftest import ACCESS_TOKEN, APP_TEST_TIMEOUT, USERNAME +from unittest.mock import patch import pytest from streamlit.testing.v1 import AppTest from util import assert_markdown_contains_all_values from vro_streamlit.auth.user import User +from vro_streamlit.main import LOG_OUT_BUTTON @pytest.fixture() @@ -20,8 +22,8 @@ def app_test(): [ pytest.param(False, ':red_circle:', None, ':red_circle:', id='not connected, not authorized'), pytest.param(True, ':large_green_circle:', None, ':red_circle:', id='connected, not authorized'), - pytest.param(False, ':red_circle:', User(USERNAME), ':large_green_circle:', id='not connected, authorized'), - pytest.param(True, ':large_green_circle:', User(USERNAME), ':large_green_circle:', id='connected, authorized'), + pytest.param(False, ':red_circle:', User(ACCESS_TOKEN, USERNAME), ':large_green_circle:', id='not connected, authorized'), + pytest.param(True, ':large_green_circle:', User(ACCESS_TOKEN, USERNAME), ':large_green_circle:', id='connected, authorized'), ], ) def test_main_not_logged_in(app_test, db_connected, db_connected_icon, user, authorized_icon) -> None: @@ -57,3 +59,40 @@ def test_main_defaults(app_test) -> None: app_test.sidebar.markdown, ['Environment', 'Database', 'Authorized', '`test-environment`', ':red_circle:', ':red_circle:'], ) + + +def test_show_login_success_status(app_test): + app_test.session_state.login_success = True + app_test.run() + assert not app_test.exception + assert app_test.success[0].value == 'Login successful!' + + +def test_show_login_failed_status(app_test): + app_test.session_state.login_failed = 'Login Failed with message' + app_test.run() + assert not app_test.exception + assert app_test.error[0].value == 'Login Failed with message' + + +def test_main_log_out_success(app_test): + with patch('vro_streamlit.auth.auth_service') as auth_service: + auth_service.log_out.return_value = True + app_test.session_state.user = User(ACCESS_TOKEN, USERNAME) + app_test.run() + app_test.button(key=LOG_OUT_BUTTON).click().run() + assert not app_test.exception + auth_service.log_out.assert_called_once() + assert app_test.success[0].value == 'Logged out successfully.' + assert app_test.session_state.user is None + + +def test_main_log_out_error(app_test): + with patch('vro_streamlit.auth.auth_service') as auth_service: + auth_service.log_out.side_effect = Exception('Nope') + app_test.session_state.user = User(ACCESS_TOKEN, USERNAME) + app_test.run() + app_test.button(key=LOG_OUT_BUTTON).click().run() + assert not app_test.exception + auth_service.log_out.assert_called_once() + assert app_test.session_state.user is None