diff --git a/code_submitter/extract_archives.py b/code_submitter/extract_archives.py index 60ea3f9..4f77826 100644 --- a/code_submitter/extract_archives.py +++ b/code_submitter/extract_archives.py @@ -3,31 +3,43 @@ import asyncio import zipfile import argparse +from typing import cast from pathlib import Path import databases +from sqlalchemy.sql import select from . import utils, config +from .tables import Session -async def async_main(output_archive: Path) -> None: +async def async_main(output_archive: Path, session_name: str) -> None: output_archive.parent.mkdir(parents=True, exist_ok=True) database = databases.Database(config.DATABASE_URL) + session_id = cast(int, await database.fetch_one(select([ + Session.c.id, + ]).where( + Session.c.name == session_name, + ))) + with zipfile.ZipFile(output_archive) as zf: async with database.transaction(): - utils.collect_submissions(database, zf) + utils.collect_submissions(database, zf, session_id) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() + parser.add_argument('session_name', type=str) parser.add_argument('output_archive', type=Path) return parser.parse_args() def main(args: argparse.Namespace) -> None: - asyncio.get_event_loop().run_until_complete(async_main(args.output_archive)) + asyncio.get_event_loop().run_until_complete( + async_main(args.output_archive, args.session_name), + ) if __name__ == '__main__': diff --git a/code_submitter/server.py b/code_submitter/server.py index c4acee1..8ad46a8 100644 --- a/code_submitter/server.py +++ b/code_submitter/server.py @@ -1,6 +1,6 @@ import io import zipfile -import datetime +from typing import cast import databases from sqlalchemy.sql import select @@ -16,7 +16,7 @@ from . import auth, utils, config from .auth import User, BLUESHIRT_SCOPE -from .tables import Archive, ChoiceHistory +from .tables import Archive, Session, ChoiceHistory database = databases.Database(config.DATABASE_URL, force_rollback=config.TESTING) templates = Jinja2Templates(directory='templates') @@ -49,10 +49,14 @@ async def homepage(request: Request) -> Response: Archive.c.created.desc(), ), ) + sessions = await database.fetch_all( + Session.select().order_by(Session.c.created.desc()), + ) return templates.TemplateResponse('index.html', { 'request': request, 'chosen': chosen, 'uploads': uploads, + 'sessions': sessions, 'BLUESHIRT_SCOPE': BLUESHIRT_SCOPE, }) @@ -137,14 +141,25 @@ async def create_session(request: Request) -> Response: @requires(['authenticated', BLUESHIRT_SCOPE]) +@database.transaction() async def download_submissions(request: Request) -> Response: + session_id = cast(int, request.path_params['session_id']) + + session = await database.fetch_one( + Session.select().where(Session.c.id == session_id), + ) + + if session is None: + return Response( + f"{session_id!r} is not a valid session id", + status_code=404, + ) + buffer = io.BytesIO() with zipfile.ZipFile(buffer, mode='w') as zf: - await utils.collect_submissions(database, zf) + await utils.collect_submissions(database, zf, session_id) - filename = 'submissions-{now}.zip'.format( - now=datetime.datetime.now(datetime.timezone.utc), - ) + filename = f"submissions-{session['name']}.zip" return Response( buffer.getvalue(), @@ -157,7 +172,11 @@ async def download_submissions(request: Request) -> Response: Route('/', endpoint=homepage, methods=['GET']), Route('/upload', endpoint=upload, methods=['POST']), Route('/create-session', endpoint=create_session, methods=['POST']), - Route('/download-submissions', endpoint=download_submissions, methods=['GET']), + Route( + '/download-submissions/{session_id:int}', + endpoint=download_submissions, + methods=['GET'], + ), ] middleware = [ diff --git a/code_submitter/utils.py b/code_submitter/utils.py index a943a20..27e1cbc 100644 --- a/code_submitter/utils.py +++ b/code_submitter/utils.py @@ -9,29 +9,24 @@ async def get_chosen_submissions( database: databases.Database, + session_id: int, ) -> Dict[str, Tuple[int, bytes]]: """ Return a mapping of teams to their the chosen archive. """ - # Note: Ideally we'd group by team in SQL, however that doesn't seem to work - # properly -- we don't get the ordering applied before the grouping. - rows = await database.fetch_all( select([ Archive.c.id, Archive.c.team, Archive.c.content, - ChoiceHistory.c.created, ]).select_from( - Archive.join(ChoiceHistory), - ).order_by( - Archive.c.team, - ChoiceHistory.c.created.asc(), + Archive.join(ChoiceHistory).join(ChoiceForSession), + ).where( + Session.c.id == session_id, ), ) - # Rely on later keys replacing earlier occurrences of the same key. return {x['team']: (x['id'], x['content']) for x in rows} @@ -88,8 +83,9 @@ def summarise(submissions: Dict[str, Tuple[int, bytes]]) -> str: async def collect_submissions( database: databases.Database, zipfile: ZipFile, + session_id: int, ) -> None: - submissions = await get_chosen_submissions(database) + submissions = await get_chosen_submissions(database, session_id) for team, (_, content) in submissions.items(): zipfile.writestr(f'{team.upper()}.zip', content) diff --git a/templates/index.html b/templates/index.html index e63fdfc..16d8336 100644 --- a/templates/index.html +++ b/templates/index.html @@ -31,14 +31,40 @@
Name | +Created | +By | + {% if BLUESHIRT_SCOPE in request.auth.scopes %} +Download | + {% endif %} +
---|---|---|---|
{{ session.name }} | +{{ session.created }} | +{{ session.username }} | + + {% if BLUESHIRT_SCOPE in request.auth.scopes %} ++ + ▼ + + | + {% endif %} +