Skip to content

Commit

Permalink
Further refactor auth datatypes for safety, consistency.
Browse files Browse the repository at this point in the history
  • Loading branch information
liffiton committed Nov 28, 2024
1 parent 101650b commit 481acf5
Show file tree
Hide file tree
Showing 23 changed files with 155 additions and 153 deletions.
6 changes: 3 additions & 3 deletions src/codehelp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_available_contexts() -> list[ContextConfig]:
db = get_db()
auth = get_auth()

class_id = auth.class_id
class_id = auth.cur_class.class_id if auth.cur_class else None
# Only return contexts that are available:
# current date anywhere on earth (using UTC+12) is at or after the saved date
context_rows = db.execute("SELECT * FROM contexts WHERE class_id=? AND available <= date('now', '+12 hours') ORDER BY class_order ASC", [class_id]).fetchall()
Expand All @@ -135,7 +135,7 @@ def get_context_string_by_id(ctx_id: int) -> str | None:
context_row = db.execute("SELECT * FROM context_strings WHERE id=?", [ctx_id]).fetchone()
else:
# for non-admin users, double-check that the context is in the current class
class_id = auth.class_id
class_id = auth.cur_class.class_id if auth.cur_class else None
context_row = db.execute("SELECT * FROM context_strings WHERE class_id=? AND id=?", [class_id, ctx_id]).fetchone()

if not context_row:
Expand All @@ -151,7 +151,7 @@ def get_context_by_name(ctx_name: str) -> ContextConfig | None:
db = get_db()
auth = get_auth()

class_id = auth.class_id
class_id = auth.cur_class.class_id if auth.cur_class else None

context_row = db.execute("SELECT * FROM contexts WHERE class_id=? AND name=?", [class_id, ctx_name]).fetchone()

Expand Down
33 changes: 15 additions & 18 deletions src/codehelp/context_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from markupsafe import Markup
from werkzeug.wrappers.response import Response

from gened.auth import get_auth, instructor_required
from gened.auth import get_auth_class, instructor_required
from gened.class_config import register_extra_section
from gened.db import get_db

Expand Down Expand Up @@ -50,8 +50,8 @@ def register(app: Flask) -> None:

def config_section_render() -> Markup:
db = get_db()
auth = get_auth()
class_id = auth.class_id
cur_class = get_auth_class()
class_id = cur_class.class_id

contexts = db.execute("""
SELECT id, name, CAST(available AS TEXT) AS available
Expand Down Expand Up @@ -81,10 +81,10 @@ def check_valid_context(f: Callable[P, R]) -> Callable[P, Response | R]:
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs) -> Response | R:
db = get_db()
auth = get_auth()
cur_class = get_auth_class()

# verify the given context is in the user's current class
class_id = auth.class_id
class_id = cur_class.class_id
ctx_id = kwargs['ctx_id']
context_row = db.execute("SELECT * FROM contexts WHERE id=?", [ctx_id]).fetchone()
if context_row['class_id'] != class_id:
Expand Down Expand Up @@ -159,24 +159,22 @@ def _insert_context(class_id: int, name: str, config: str, available: str) -> in

@bp.route("/create", methods=["POST"])
def create_context() -> Response:
auth = get_auth()
assert auth.class_id
cur_class = get_auth_class()

context = ContextConfig.from_request_form(request.form)
_insert_context(auth.class_id, context.name, context.to_json(), "9999-12-31") # defaults to hidden
_insert_context(cur_class.class_id, context.name, context.to_json(), "9999-12-31") # defaults to hidden
return redirect(url_for("class_config.config_form"))


@bp.route("/copy/", methods=[]) # just for url_for() in js code
@bp.route("/copy/<int:ctx_id>", methods=["POST"])
@check_valid_context
def copy_context(ctx_row: Row, ctx_id: int) -> Response:
auth = get_auth()
assert auth.class_id
cur_class = get_auth_class()

# passing existing name, but _insert_context will take care of finding
# a new, unused name in the class.
_insert_context(auth.class_id, ctx_row['name'], ctx_row['config'], ctx_row['available'])
_insert_context(cur_class.class_id, ctx_row['name'], ctx_row['config'], ctx_row['available'])
return redirect(url_for("class_config.config_form"))


Expand All @@ -188,9 +186,8 @@ def update_context(ctx_id: int, ctx_row: Row) -> Response:
context = ContextConfig.from_request_form(request.form)

# names must be unique within a class: check/look for an unused name
auth = get_auth()
assert auth.class_id
name = _make_unique_context_name(auth.class_id, context.name, ctx_id)
cur_class = get_auth_class()
name = _make_unique_context_name(cur_class.class_id, context.name, ctx_id)

db.execute("UPDATE contexts SET name=?, config=? WHERE id=?", [name, context.to_json(), ctx_id])
db.commit()
Expand All @@ -215,9 +212,9 @@ def delete_context(ctx_id: int, ctx_row: Row) -> Response:
@bp.route("/update_order", methods=["POST"])
def update_order() -> str:
db = get_db()
auth = get_auth()
cur_class = get_auth_class()

class_id = auth.class_id # Get the current class to ensure we don't change another class.
class_id = cur_class.class_id # Get the current class to ensure we don't change another class.

ordered_ids = request.json
assert isinstance(ordered_ids, list)
Expand All @@ -233,9 +230,9 @@ def update_order() -> str:
@bp.route("/update_available", methods=["POST"])
def update_available() -> str:
db = get_db()
auth = get_auth()
cur_class = get_auth_class()

class_id = auth.class_id # Get the current class to ensure we don't change another class.
class_id = cur_class.class_id # Get the current class to ensure we don't change another class.

data = request.json
assert isinstance(data, dict)
Expand Down
4 changes: 2 additions & 2 deletions src/codehelp/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def run_query(llm: LLMConfig, context: ContextConfig | None, code: str, error: s
def record_query(context: ContextConfig | None, code: str, error: str, issue: str) -> int:
db = get_db()
auth = get_auth()
role_id = auth.role_id
role_id = auth.cur_class.role_id if auth.cur_class else None

if context is not None:
context_name = context.name
Expand Down Expand Up @@ -253,7 +253,7 @@ def help_request(llm: LLMConfig) -> Response:
def load_test(llm: LLMConfig) -> Response:
# Require that we're logged in as the load_test admin user
auth = get_auth()
if auth.display_name != 'load_test':
if auth.user is None or auth.user.display_name != 'load_test':
return abort(403)

context = ContextConfig(name="__LOADTEST_Context")
Expand Down
2 changes: 1 addition & 1 deletion src/codehelp/templates/context_config.html
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ <h3 class="title is-4">Schedule '<span x-text="ctx.name"></span>'</h3>
{% macro link_display_alpinejs(route) -%}
{# unholy combination of Jinja templating (route available in python) and Javascript string replacement (context available in JS)... #}
{
link_URL: '{{ url_for(route, class_id=auth.class_id, ctx_name='__replace__', _external=True) }}'.replace('__replace__', encodeURIComponent(ctx.name)),
link_URL: '{{ url_for(route, class_id=auth.cur_class.class_id, ctx_name='__replace__', _external=True) }}'.replace('__replace__', encodeURIComponent(ctx.name)),
copied: false,
copy_url() {
navigator.clipboard.writeText(this.link_URL);
Expand Down
4 changes: 2 additions & 2 deletions src/codehelp/templates/context_edit_form.html
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
<div class="container">
{% if context %}
{# We're editing an existing context. #}
<h1 class="title">Editing context '{{ context.name }}' in class {{ auth.class_name }}</h1>
<h1 class="title">Editing context '{{ context.name }}' in class {{ auth.cur_class.class_name }}</h1>
<form class="wide-labels" action="{{ url_for(".update_context", ctx_id=context.id) }}" method="post">
{% else %}
<h1 class="title">Create context in class {{ auth.class_name }}</h1>
<h1 class="title">Create context in class {{ auth.cur_class.class_name }}</h1>
<form class="wide-labels" action="{{ url_for(".create_context") }}" method="post">
{% endif %}

Expand Down
4 changes: 2 additions & 2 deletions src/codehelp/templates/help_form.html
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
{# debounce on the submit handler so that the form's actual submit fires *before* the form elements are disabled #}
<form class="wide-labels" action="{{url_for('helper.help_request')}}" method="post" x-data="{loading: false}" x-on:pageshow.window="loading = false" x-on:submit.debounce.10ms="loading = true">

{% if auth.class_name %}
{% if auth.cur_class %}
<div class="field is-horizontal">
<div class="field-label">
<label class="label">Class:</label>
</div>
<div class="field-body">
{{ auth.class_name }}
{{ auth.cur_class.class_name }}
</div>
</div>
{% elif llm.tokens_remaining != None %}
Expand Down
4 changes: 2 additions & 2 deletions src/codehelp/templates/landing.html
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ <h2 class="has-text-white">Ask it...</h2>
</div>
</div>
<p class="has-text-centered mt-5 mb-3">
{% if auth.user_id %}
{% if auth.user %}
<a class="button is-link is-light is-rounded is-size-4" href="{{ url_for('helper.help_form') }}">
Try it now!
</a>
Expand Down Expand Up @@ -66,7 +66,7 @@ <h2>For Instructors</h2>
<li>Everyone will sign in <b>automatically</b> (no separate login) via a link from your course page.</li>
<li>Takes some time to set up, and may require support from your LMS administrator.</li>
</ul>
{% if auth.user_id and auth.auth_provider != 'demo' %}
{% if auth.user and auth.user.auth_provider != 'demo' %}
<a class="button button-inline is-link is-size-5" href="{{ url_for("profile.main") }}">Go to your Profile page</a> to manually create a class.
{% else %}
<a class="button button-inline is-link is-size-5" href="{{ url_for("auth.login", next=url_for("profile.main")) }}">Sign in using Google, GitHub, or Microsoft</a> and manually create a class from your profile page.
Expand Down
2 changes: 1 addition & 1 deletion src/codehelp/templates/tutor_nav_item.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
SPDX-License-Identifier: AGPL-3.0-only
#}

{% if auth.user_id and "chats_experiment" in auth.class_experiments %}
{% if auth.user and "chats_experiment" in auth.class_experiments %}
<a class="navbar-item has-text-success" href="{{ url_for('tutor.tutor_form') }}">
<div class="icon-text">
<span class="icon">
Expand Down
4 changes: 2 additions & 2 deletions src/codehelp/tutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def create_chat(topic: str, context: ContextConfig | None) -> int:
db = get_db()
auth = get_auth()
user_id = auth.user_id
role_id = auth.role_id
role_id = auth.cur_class.role_id if auth.cur_class else None

if context is not None:
context_name = context.name
Expand Down Expand Up @@ -196,7 +196,7 @@ def get_chat(chat_id: int) -> tuple[list[ChatCompletionMessageParam], str, str,
access_allowed = \
(auth.user_id == chat_row['user_id']) \
or auth.is_admin \
or (auth.role == 'instructor' and auth.class_id == chat_row['class_id'])
or (auth.cur_class and auth.cur_class.role == 'instructor' and auth.cur_class.class_id == chat_row['class_id'])

if not access_allowed:
raise AccessDeniedError
Expand Down
Loading

0 comments on commit 481acf5

Please sign in to comment.