diff --git a/code_submitter/extract_archives.py b/code_submitter/extract_archives.py index 60ea3f9..4f77826 100644 --- a/code_submitter/extract_archives.py +++ b/code_submitter/extract_archives.py @@ -3,31 +3,43 @@ import asyncio import zipfile import argparse +from typing import cast from pathlib import Path import databases +from sqlalchemy.sql import select from . import utils, config +from .tables import Session -async def async_main(output_archive: Path) -> None: +async def async_main(output_archive: Path, session_name: str) -> None: output_archive.parent.mkdir(parents=True, exist_ok=True) database = databases.Database(config.DATABASE_URL) + session_id = cast(int, await database.fetch_one(select([ + Session.c.id, + ]).where( + Session.c.name == session_name, + ))) + with zipfile.ZipFile(output_archive) as zf: async with database.transaction(): - utils.collect_submissions(database, zf) + utils.collect_submissions(database, zf, session_id) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() + parser.add_argument('session_name', type=str) parser.add_argument('output_archive', type=Path) return parser.parse_args() def main(args: argparse.Namespace) -> None: - asyncio.get_event_loop().run_until_complete(async_main(args.output_archive)) + asyncio.get_event_loop().run_until_complete( + async_main(args.output_archive, args.session_name), + ) if __name__ == '__main__': diff --git a/code_submitter/server.py b/code_submitter/server.py index c4acee1..8ad46a8 100644 --- a/code_submitter/server.py +++ b/code_submitter/server.py @@ -1,6 +1,6 @@ import io import zipfile -import datetime +from typing import cast import databases from sqlalchemy.sql import select @@ -16,7 +16,7 @@ from . import auth, utils, config from .auth import User, BLUESHIRT_SCOPE -from .tables import Archive, ChoiceHistory +from .tables import Archive, Session, ChoiceHistory database = databases.Database(config.DATABASE_URL, force_rollback=config.TESTING) templates = Jinja2Templates(directory='templates') @@ -49,10 +49,14 @@ async def homepage(request: Request) -> Response: Archive.c.created.desc(), ), ) + sessions = await database.fetch_all( + Session.select().order_by(Session.c.created.desc()), + ) return templates.TemplateResponse('index.html', { 'request': request, 'chosen': chosen, 'uploads': uploads, + 'sessions': sessions, 'BLUESHIRT_SCOPE': BLUESHIRT_SCOPE, }) @@ -137,14 +141,25 @@ async def create_session(request: Request) -> Response: @requires(['authenticated', BLUESHIRT_SCOPE]) +@database.transaction() async def download_submissions(request: Request) -> Response: + session_id = cast(int, request.path_params['session_id']) + + session = await database.fetch_one( + Session.select().where(Session.c.id == session_id), + ) + + if session is None: + return Response( + f"{session_id!r} is not a valid session id", + status_code=404, + ) + buffer = io.BytesIO() with zipfile.ZipFile(buffer, mode='w') as zf: - await utils.collect_submissions(database, zf) + await utils.collect_submissions(database, zf, session_id) - filename = 'submissions-{now}.zip'.format( - now=datetime.datetime.now(datetime.timezone.utc), - ) + filename = f"submissions-{session['name']}.zip" return Response( buffer.getvalue(), @@ -157,7 +172,11 @@ async def download_submissions(request: Request) -> Response: Route('/', endpoint=homepage, methods=['GET']), Route('/upload', endpoint=upload, methods=['POST']), Route('/create-session', endpoint=create_session, methods=['POST']), - Route('/download-submissions', endpoint=download_submissions, methods=['GET']), + Route( + '/download-submissions/{session_id:int}', + endpoint=download_submissions, + methods=['GET'], + ), ] middleware = [ diff --git a/code_submitter/utils.py b/code_submitter/utils.py index a943a20..27e1cbc 100644 --- a/code_submitter/utils.py +++ b/code_submitter/utils.py @@ -9,29 +9,24 @@ async def get_chosen_submissions( database: databases.Database, + session_id: int, ) -> Dict[str, Tuple[int, bytes]]: """ Return a mapping of teams to their the chosen archive. """ - # Note: Ideally we'd group by team in SQL, however that doesn't seem to work - # properly -- we don't get the ordering applied before the grouping. - rows = await database.fetch_all( select([ Archive.c.id, Archive.c.team, Archive.c.content, - ChoiceHistory.c.created, ]).select_from( - Archive.join(ChoiceHistory), - ).order_by( - Archive.c.team, - ChoiceHistory.c.created.asc(), + Archive.join(ChoiceHistory).join(ChoiceForSession), + ).where( + Session.c.id == session_id, ), ) - # Rely on later keys replacing earlier occurrences of the same key. return {x['team']: (x['id'], x['content']) for x in rows} @@ -88,8 +83,9 @@ def summarise(submissions: Dict[str, Tuple[int, bytes]]) -> str: async def collect_submissions( database: databases.Database, zipfile: ZipFile, + session_id: int, ) -> None: - submissions = await get_chosen_submissions(database) + submissions = await get_chosen_submissions(database, session_id) for team, (_, content) in submissions.items(): zipfile.writestr(f'{team.upper()}.zip', content) diff --git a/templates/index.html b/templates/index.html index e63fdfc..16d8336 100644 --- a/templates/index.html +++ b/templates/index.html @@ -31,14 +31,40 @@

Virtual Competition Code Submission

- {% if BLUESHIRT_SCOPE in request.auth.scopes %}
- - Download current chosen submissions - +

Sessions

+ + + + + + {% if BLUESHIRT_SCOPE in request.auth.scopes %} + + {% endif %} + + {% for session in sessions %} + + + + + + {% if BLUESHIRT_SCOPE in request.auth.scopes %} + + {% endif %} + + {% endfor %} +
NameCreatedByDownload
{{ session.name }}{{ session.created }}{{ session.username }} + + ▼ + +
+ {% if BLUESHIRT_SCOPE in request.auth.scopes %}
None: # App import must happen after TESTING environment setup from code_submitter.server import app - def url_for(name: str) -> str: + def url_for(name: str, **path_params: str) -> str: # While it makes for uglier tests, we do need to use more absolute # paths here so that the urls emitted contain the root_path from the # ASGI server and in turn work correctly under proxy. - return 'http://testserver{}'.format(app.url_path_for(name)) + return 'http://testserver{}'.format(app.url_path_for(name, **path_params)) test_client = TestClient(app) self.session = test_client.__enter__() @@ -315,7 +320,13 @@ def test_create_session(self) -> None: ) def test_no_download_link_for_non_blueshirt(self) -> None: - download_url = self.url_for('download_submissions') + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + download_url = self.url_for('download_submissions', session_id=session_id) response = self.session.get(self.url_for('homepage')) @@ -325,20 +336,50 @@ def test_no_download_link_for_non_blueshirt(self) -> None: def test_shows_download_link_for_blueshirt(self) -> None: self.session.auth = ('blueshirt', 'blueshirt') - download_url = self.url_for('download_submissions') + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + download_url = self.url_for('download_submissions', session_id=session_id) response = self.session.get(self.url_for('homepage')) html = response.text self.assertIn(download_url, html) def test_download_submissions_requires_blueshirt(self) -> None: - response = self.session.get(self.url_for('download_submissions')) + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + response = self.session.get( + self.url_for('download_submissions', session_id=session_id), + ) self.assertEqual(403, response.status_code) + def test_download_submissions_when_invalid_session(self) -> None: + self.session.auth = ('blueshirt', 'blueshirt') + response = self.session.get( + self.url_for('download_submissions', session_id='4'), + ) + self.assertEqual(404, response.status_code) + def test_download_submissions_when_none(self) -> None: self.session.auth = ('blueshirt', 'blueshirt') - response = self.session.get(self.url_for('download_submissions')) + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + + response = self.session.get( + self.url_for('download_submissions', session_id=session_id), + ) self.assertEqual(200, response.status_code) with zipfile.ZipFile(io.BytesIO(response.content)) as zf: @@ -359,7 +400,7 @@ def test_download_submissions(self) -> None: created=datetime.datetime(2020, 8, 8, 12, 0), ), )) - self.await_(self.database.execute( + choice_id = self.await_(self.database.execute( ChoiceHistory.insert().values( archive_id=8888888888, username='test_user', @@ -367,7 +408,22 @@ def test_download_submissions(self) -> None: ), )) - response = self.session.get(self.url_for('download_submissions')) + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + self.await_(self.database.execute( + ChoiceForSession.insert().values( + choice_id=choice_id, + session_id=session_id, + ), + )) + + response = self.session.get( + self.url_for('download_submissions', session_id=session_id), + ) self.assertEqual(200, response.status_code) with zipfile.ZipFile(io.BytesIO(response.content)) as zf: diff --git a/tests/tests_utils.py b/tests/tests_utils.py index b341912..ffac35a 100644 --- a/tests/tests_utils.py +++ b/tests/tests_utils.py @@ -5,7 +5,12 @@ import test_utils from code_submitter import utils -from code_submitter.tables import Archive, ChoiceHistory, ChoiceForSession +from code_submitter.tables import ( + Archive, + Session, + ChoiceHistory, + ChoiceForSession, +) class UtilsTests(test_utils.InTransactionTestCase): @@ -41,18 +46,20 @@ def setUp(self) -> None: )) def test_get_chosen_submissions_nothing_chosen(self) -> None: - result = self.await_(utils.get_chosen_submissions(self.database)) + result = self.await_( + utils.get_chosen_submissions(self.database, session_id=0), + ) self.assertEqual({}, result) def test_get_chosen_submissions_multiple_chosen(self) -> None: - self.await_(self.database.execute( + choice_id_1 = self.await_(self.database.execute( ChoiceHistory.insert().values( archive_id=8888888888, username='someone_else', created=datetime.datetime(2020, 8, 8, 12, 0), ), )) - self.await_(self.database.execute( + choice_id_2 = self.await_(self.database.execute( ChoiceHistory.insert().values( archive_id=1111111111, username='test_user', @@ -66,8 +73,28 @@ def test_get_chosen_submissions_multiple_chosen(self) -> None: created=datetime.datetime(2020, 2, 2, 12, 0), ), )) + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + self.await_(self.database.execute( + ChoiceForSession.insert().values( + choice_id=choice_id_1, + session_id=session_id, + ), + )) + self.await_(self.database.execute( + ChoiceForSession.insert().values( + choice_id=choice_id_2, + session_id=session_id, + ), + )) - result = self.await_(utils.get_chosen_submissions(self.database)) + result = self.await_( + utils.get_chosen_submissions(self.database, session_id), + ) self.assertEqual( { 'SRZ2': (1111111111, b'1111111111'), @@ -118,14 +145,14 @@ def test_create_session(self) -> None: ) def test_collect_submissions(self) -> None: - self.await_(self.database.execute( + choice_id_1 = self.await_(self.database.execute( ChoiceHistory.insert().values( archive_id=8888888888, username='someone_else', created=datetime.datetime(2020, 8, 8, 12, 0), ), )) - self.await_(self.database.execute( + choice_id_2 = self.await_(self.database.execute( ChoiceHistory.insert().values( archive_id=1111111111, username='test_user', @@ -139,9 +166,27 @@ def test_collect_submissions(self) -> None: created=datetime.datetime(2020, 2, 2, 12, 0), ), )) + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + self.await_(self.database.execute( + ChoiceForSession.insert().values( + choice_id=choice_id_1, + session_id=session_id, + ), + )) + self.await_(self.database.execute( + ChoiceForSession.insert().values( + choice_id=choice_id_2, + session_id=session_id, + ), + )) with zipfile.ZipFile(io.BytesIO(), mode='w') as zf: - self.await_(utils.collect_submissions(self.database, zf)) + self.await_(utils.collect_submissions(self.database, zf, session_id)) self.assertEqual( {