Skip to content

Commit

Permalink
Download archives by session
Browse files Browse the repository at this point in the history
This replaces the previous download of only the latest archives
with a means to download the archives which were chosen for a
given session.
  • Loading branch information
PeterJCLaw committed Jan 9, 2021
1 parent c9718ad commit 9932c15
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 41 deletions.
18 changes: 15 additions & 3 deletions code_submitter/extract_archives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,43 @@
import asyncio
import zipfile
import argparse
from typing import cast
from pathlib import Path

import databases
from sqlalchemy.sql import select

from . import utils, config
from .tables import Session


async def async_main(output_archive: Path) -> None:
async def async_main(output_archive: Path, session_name: str) -> None:
output_archive.parent.mkdir(parents=True, exist_ok=True)

database = databases.Database(config.DATABASE_URL)

session_id = cast(int, await database.fetch_one(select([
Session.c.id,
]).where(
Session.c.name == session_name,
)))

with zipfile.ZipFile(output_archive) as zf:
async with database.transaction():
utils.collect_submissions(database, zf)
utils.collect_submissions(database, zf, session_id)


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument('session_name', type=str)
parser.add_argument('output_archive', type=Path)
return parser.parse_args()


def main(args: argparse.Namespace) -> None:
asyncio.get_event_loop().run_until_complete(async_main(args.output_archive))
asyncio.get_event_loop().run_until_complete(
async_main(args.output_archive, args.session_name),
)


if __name__ == '__main__':
Expand Down
33 changes: 26 additions & 7 deletions code_submitter/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
import zipfile
import datetime
from typing import cast

import databases
from sqlalchemy.sql import select
Expand All @@ -16,7 +16,7 @@

from . import auth, utils, config
from .auth import User, BLUESHIRT_SCOPE
from .tables import Archive, ChoiceHistory
from .tables import Archive, Session, ChoiceHistory

database = databases.Database(config.DATABASE_URL, force_rollback=config.TESTING)
templates = Jinja2Templates(directory='templates')
Expand Down Expand Up @@ -49,10 +49,14 @@ async def homepage(request: Request) -> Response:
Archive.c.created.desc(),
),
)
sessions = await database.fetch_all(
Session.select().order_by(Session.c.created.desc()),
)
return templates.TemplateResponse('index.html', {
'request': request,
'chosen': chosen,
'uploads': uploads,
'sessions': sessions,
'BLUESHIRT_SCOPE': BLUESHIRT_SCOPE,
})

Expand Down Expand Up @@ -137,14 +141,25 @@ async def create_session(request: Request) -> Response:


@requires(['authenticated', BLUESHIRT_SCOPE])
@database.transaction()
async def download_submissions(request: Request) -> Response:
session_id = cast(int, request.path_params['session_id'])

session = await database.fetch_one(
Session.select().where(Session.c.id == session_id),
)

if session is None:
return Response(
f"{session_id!r} is not a valid session id",
status_code=404,
)

buffer = io.BytesIO()
with zipfile.ZipFile(buffer, mode='w') as zf:
await utils.collect_submissions(database, zf)
await utils.collect_submissions(database, zf, session_id)

filename = 'submissions-{now}.zip'.format(
now=datetime.datetime.now(datetime.timezone.utc),
)
filename = f"submissions-{session['name']}.zip"

return Response(
buffer.getvalue(),
Expand All @@ -157,7 +172,11 @@ async def download_submissions(request: Request) -> Response:
Route('/', endpoint=homepage, methods=['GET']),
Route('/upload', endpoint=upload, methods=['POST']),
Route('/create-session', endpoint=create_session, methods=['POST']),
Route('/download-submissions', endpoint=download_submissions, methods=['GET']),
Route(
'/download-submissions/{session_id:int}',
endpoint=download_submissions,
methods=['GET'],
),
]

middleware = [
Expand Down
16 changes: 6 additions & 10 deletions code_submitter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,24 @@

async def get_chosen_submissions(
database: databases.Database,
session_id: int,
) -> Dict[str, Tuple[int, bytes]]:
"""
Return a mapping of teams to their the chosen archive.
"""

# Note: Ideally we'd group by team in SQL, however that doesn't seem to work
# properly -- we don't get the ordering applied before the grouping.

rows = await database.fetch_all(
select([
Archive.c.id,
Archive.c.team,
Archive.c.content,
ChoiceHistory.c.created,
]).select_from(
Archive.join(ChoiceHistory),
).order_by(
Archive.c.team,
ChoiceHistory.c.created.asc(),
Archive.join(ChoiceHistory).join(ChoiceForSession),
).where(
Session.c.id == session_id,
),
)

# Rely on later keys replacing earlier occurrences of the same key.
return {x['team']: (x['id'], x['content']) for x in rows}


Expand Down Expand Up @@ -88,8 +83,9 @@ def summarise(submissions: Dict[str, Tuple[int, bytes]]) -> str:
async def collect_submissions(
database: databases.Database,
zipfile: ZipFile,
session_id: int,
) -> None:
submissions = await get_chosen_submissions(database)
submissions = await get_chosen_submissions(database, session_id)

for team, (_, content) in submissions.items():
zipfile.writestr(f'{team.upper()}.zip', content)
Expand Down
34 changes: 30 additions & 4 deletions templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,40 @@
<body>
<div class="container">
<h1>Virtual Competition Code Submission</h1>
{% if BLUESHIRT_SCOPE in request.auth.scopes %}
<div class="row">
<div class="col-sm-6">
<a download href="{{ url_for('download_submissions') }}">
Download current chosen submissions
</a>
<h3>Sessions</h3>
<table class="table table-striped">
<tr>
<th scope="col">Name</th>
<th scope="col">Created</th>
<th scope="col">By</th>
{% if BLUESHIRT_SCOPE in request.auth.scopes %}
<th scope="col">Download</th>
{% endif %}
</tr>
{% for session in sessions %}
<tr>
<td>{{ session.name }}</td>
<td>{{ session.created }}</td>
<td>{{ session.username }}</td>
<!-- TODO: teams in the session -->
{% if BLUESHIRT_SCOPE in request.auth.scopes %}
<td>
<a
download
href="{{ url_for('download_submissions', session_id=session.id) }}"
>
</a>
</td>
{% endif %}
</tr>
{% endfor %}
</table>
</div>
</div>
{% if BLUESHIRT_SCOPE in request.auth.scopes %}
<div class="row">
<div class="col-sm-6">
<form
Expand Down
74 changes: 65 additions & 9 deletions tests/tests_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from sqlalchemy.sql import select
from starlette.testclient import TestClient

from code_submitter.tables import Archive, Session, ChoiceHistory
from code_submitter.tables import (
Archive,
Session,
ChoiceHistory,
ChoiceForSession,
)


class AppTests(test_utils.DatabaseTestCase):
Expand All @@ -17,11 +22,11 @@ def setUp(self) -> None:
# App import must happen after TESTING environment setup
from code_submitter.server import app

def url_for(name: str) -> str:
def url_for(name: str, **path_params: str) -> str:
# While it makes for uglier tests, we do need to use more absolute
# paths here so that the urls emitted contain the root_path from the
# ASGI server and in turn work correctly under proxy.
return 'http://testserver{}'.format(app.url_path_for(name))
return 'http://testserver{}'.format(app.url_path_for(name, **path_params))

test_client = TestClient(app)
self.session = test_client.__enter__()
Expand Down Expand Up @@ -315,7 +320,13 @@ def test_create_session(self) -> None:
)

def test_no_download_link_for_non_blueshirt(self) -> None:
download_url = self.url_for('download_submissions')
session_id = self.await_(self.database.execute(
Session.insert().values(
name="Test session",
username='blueshirt',
),
))
download_url = self.url_for('download_submissions', session_id=session_id)

response = self.session.get(self.url_for('homepage'))

Expand All @@ -325,20 +336,50 @@ def test_no_download_link_for_non_blueshirt(self) -> None:
def test_shows_download_link_for_blueshirt(self) -> None:
self.session.auth = ('blueshirt', 'blueshirt')

download_url = self.url_for('download_submissions')
session_id = self.await_(self.database.execute(
Session.insert().values(
name="Test session",
username='blueshirt',
),
))
download_url = self.url_for('download_submissions', session_id=session_id)

response = self.session.get(self.url_for('homepage'))
html = response.text
self.assertIn(download_url, html)

def test_download_submissions_requires_blueshirt(self) -> None:
response = self.session.get(self.url_for('download_submissions'))
session_id = self.await_(self.database.execute(
Session.insert().values(
name="Test session",
username='blueshirt',
),
))
response = self.session.get(
self.url_for('download_submissions', session_id=session_id),
)
self.assertEqual(403, response.status_code)

def test_download_submissions_when_invalid_session(self) -> None:
self.session.auth = ('blueshirt', 'blueshirt')
response = self.session.get(
self.url_for('download_submissions', session_id='4'),
)
self.assertEqual(404, response.status_code)

def test_download_submissions_when_none(self) -> None:
self.session.auth = ('blueshirt', 'blueshirt')

response = self.session.get(self.url_for('download_submissions'))
session_id = self.await_(self.database.execute(
Session.insert().values(
name="Test session",
username='blueshirt',
),
))

response = self.session.get(
self.url_for('download_submissions', session_id=session_id),
)
self.assertEqual(200, response.status_code)

with zipfile.ZipFile(io.BytesIO(response.content)) as zf:
Expand All @@ -359,15 +400,30 @@ def test_download_submissions(self) -> None:
created=datetime.datetime(2020, 8, 8, 12, 0),
),
))
self.await_(self.database.execute(
choice_id = self.await_(self.database.execute(
ChoiceHistory.insert().values(
archive_id=8888888888,
username='test_user',
created=datetime.datetime(2020, 9, 9, 12, 0),
),
))

response = self.session.get(self.url_for('download_submissions'))
session_id = self.await_(self.database.execute(
Session.insert().values(
name="Test session",
username='blueshirt',
),
))
self.await_(self.database.execute(
ChoiceForSession.insert().values(
choice_id=choice_id,
session_id=session_id,
),
))

response = self.session.get(
self.url_for('download_submissions', session_id=session_id),
)
self.assertEqual(200, response.status_code)

with zipfile.ZipFile(io.BytesIO(response.content)) as zf:
Expand Down
Loading

0 comments on commit 9932c15

Please sign in to comment.