Skip to content

Commit

Permalink
Refactor auth dicts to dataclasses for better type checking and safety.
Browse files Browse the repository at this point in the history
  • Loading branch information
liffiton committed Nov 28, 2024
1 parent 739d683 commit 101650b
Show file tree
Hide file tree
Showing 27 changed files with 193 additions and 192 deletions.
8 changes: 4 additions & 4 deletions src/codehelp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_available_contexts() -> list[ContextConfig]:
db = get_db()
auth = get_auth()

class_id = auth['class_id']
class_id = auth.class_id
# Only return contexts that are available:
# current date anywhere on earth (using UTC+12) is at or after the saved date
context_rows = db.execute("SELECT * FROM contexts WHERE class_id=? AND available <= date('now', '+12 hours') ORDER BY class_order ASC", [class_id]).fetchall()
Expand All @@ -130,12 +130,12 @@ def get_context_string_by_id(ctx_id: int) -> str | None:
db = get_db()
auth = get_auth()

if auth['is_admin']:
if auth.is_admin:
# admin can grab any context
context_row = db.execute("SELECT * FROM context_strings WHERE id=?", [ctx_id]).fetchone()
else:
# for non-admin users, double-check that the context is in the current class
class_id = auth['class_id']
class_id = auth.class_id
context_row = db.execute("SELECT * FROM context_strings WHERE class_id=? AND id=?", [class_id, ctx_id]).fetchone()

if not context_row:
Expand All @@ -151,7 +151,7 @@ def get_context_by_name(ctx_name: str) -> ContextConfig | None:
db = get_db()
auth = get_auth()

class_id = auth['class_id']
class_id = auth.class_id

context_row = db.execute("SELECT * FROM contexts WHERE class_id=? AND name=?", [class_id, ctx_name]).fetchone()

Expand Down
20 changes: 10 additions & 10 deletions src/codehelp/context_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def register(app: Flask) -> None:
def config_section_render() -> Markup:
db = get_db()
auth = get_auth()
class_id = auth['class_id']
class_id = auth.class_id

contexts = db.execute("""
SELECT id, name, CAST(available AS TEXT) AS available
Expand Down Expand Up @@ -84,7 +84,7 @@ def decorated_function(*args: P.args, **kwargs: P.kwargs) -> Response | R:
auth = get_auth()

# verify the given context is in the user's current class
class_id = auth['class_id']
class_id = auth.class_id
ctx_id = kwargs['ctx_id']
context_row = db.execute("SELECT * FROM contexts WHERE id=?", [ctx_id]).fetchone()
if context_row['class_id'] != class_id:
Expand Down Expand Up @@ -160,10 +160,10 @@ def _insert_context(class_id: int, name: str, config: str, available: str) -> in
@bp.route("/create", methods=["POST"])
def create_context() -> Response:
auth = get_auth()
assert auth['class_id']
assert auth.class_id

context = ContextConfig.from_request_form(request.form)
_insert_context(auth['class_id'], context.name, context.to_json(), "9999-12-31") # defaults to hidden
_insert_context(auth.class_id, context.name, context.to_json(), "9999-12-31") # defaults to hidden
return redirect(url_for("class_config.config_form"))


Expand All @@ -172,11 +172,11 @@ def create_context() -> Response:
@check_valid_context
def copy_context(ctx_row: Row, ctx_id: int) -> Response:
auth = get_auth()
assert auth['class_id']
assert auth.class_id

# passing existing name, but _insert_context will take care of finding
# a new, unused name in the class.
_insert_context(auth['class_id'], ctx_row['name'], ctx_row['config'], ctx_row['available'])
_insert_context(auth.class_id, ctx_row['name'], ctx_row['config'], ctx_row['available'])
return redirect(url_for("class_config.config_form"))


Expand All @@ -189,8 +189,8 @@ def update_context(ctx_id: int, ctx_row: Row) -> Response:

# names must be unique within a class: check/look for an unused name
auth = get_auth()
assert auth['class_id']
name = _make_unique_context_name(auth['class_id'], context.name, ctx_id)
assert auth.class_id
name = _make_unique_context_name(auth.class_id, context.name, ctx_id)

db.execute("UPDATE contexts SET name=?, config=? WHERE id=?", [name, context.to_json(), ctx_id])
db.commit()
Expand All @@ -217,7 +217,7 @@ def update_order() -> str:
db = get_db()
auth = get_auth()

class_id = auth['class_id'] # Get the current class to ensure we don't change another class.
class_id = auth.class_id # Get the current class to ensure we don't change another class.

ordered_ids = request.json
assert isinstance(ordered_ids, list)
Expand All @@ -235,7 +235,7 @@ def update_available() -> str:
db = get_db()
auth = get_auth()

class_id = auth['class_id'] # Get the current class to ensure we don't change another class.
class_id = auth.class_id # Get the current class to ensure we don't change another class.

data = request.json
assert isinstance(data, dict)
Expand Down
10 changes: 5 additions & 5 deletions src/codehelp/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def help_form(llm: LLMConfig, query_id: int | None = None, class_id: int | None
else:
# no query specified,
# but we can pre-select the most recently used context, if available
recent_row = db.execute("SELECT context_name FROM queries WHERE queries.user_id=? ORDER BY id DESC LIMIT 1", [auth['user_id']]).fetchone()
recent_row = db.execute("SELECT context_name FROM queries WHERE queries.user_id=? ORDER BY id DESC LIMIT 1", [auth.user_id]).fetchone()
if recent_row:
selected_context_name = recent_row['context_name']

Expand Down Expand Up @@ -192,7 +192,7 @@ def run_query(llm: LLMConfig, context: ContextConfig | None, code: str, error: s
def record_query(context: ContextConfig | None, code: str, error: str, issue: str) -> int:
db = get_db()
auth = get_auth()
role_id = auth['role_id']
role_id = auth.role_id

if context is not None:
context_name = context.name
Expand All @@ -204,7 +204,7 @@ def record_query(context: ContextConfig | None, code: str, error: str, issue: st

cur = db.execute(
"INSERT INTO queries (context_name, context_string_id, code, error, issue, user_id, role_id) VALUES (?, ?, ?, ?, ?, ?, ?)",
[context_name, context_string_id, code, error, issue, auth['user_id'], role_id]
[context_name, context_string_id, code, error, issue, auth.user_id, role_id]
)
new_row_id = cur.lastrowid
db.commit()
Expand Down Expand Up @@ -253,7 +253,7 @@ def help_request(llm: LLMConfig) -> Response:
def load_test(llm: LLMConfig) -> Response:
# Require that we're logged in as the load_test admin user
auth = get_auth()
if auth['display_name'] != 'load_test':
if auth.display_name != 'load_test':
return abort(403)

context = ContextConfig(name="__LOADTEST_Context")
Expand All @@ -279,7 +279,7 @@ def post_helpful() -> str:

query_id = int(request.form['id'])
value = int(request.form['value'])
db.execute("UPDATE queries SET helpful=? WHERE id=? AND user_id=?", [value, query_id, auth['user_id']])
db.execute("UPDATE queries SET helpful=? WHERE id=? AND user_id=?", [value, query_id, auth.user_id])
db.commit()
return ""

Expand Down
4 changes: 2 additions & 2 deletions src/codehelp/templates/context_config.html
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ <h3 class="title is-4">Schedule '<span x-text="ctx.name"></span>'</h3>
{% macro link_display_alpinejs(route) -%}
{# unholy combination of Jinja templating (route available in python) and Javascript string replacement (context available in JS)... #}
{
link_URL: '{{ url_for(route, class_id=auth['class_id'], ctx_name='__replace__', _external=True) }}'.replace('__replace__', encodeURIComponent(ctx.name)),
link_URL: '{{ url_for(route, class_id=auth.class_id, ctx_name='__replace__', _external=True) }}'.replace('__replace__', encodeURIComponent(ctx.name)),
copied: false,
copy_url() {
navigator.clipboard.writeText(this.link_URL);
Expand Down Expand Up @@ -207,7 +207,7 @@ <h3 class="title is-4">
</div>
</div>
</div>
{% if "chats_experiment" in auth['class_experiments'] %}
{% if "chats_experiment" in auth.class_experiments %}
<div class="field has-addons is-horizontal" x-data="{{ link_display_alpinejs('tutor.tutor_form') }}">
<div class="field-label label is-normal">Tutor Chat:</div>
<div class="field-body" style="flex-grow: 5;">
Expand Down
4 changes: 2 additions & 2 deletions src/codehelp/templates/context_edit_form.html
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
<div class="container">
{% if context %}
{# We're editing an existing context. #}
<h1 class="title">Editing context '{{ context.name }}' in class {{ auth['class_name'] }}</h1>
<h1 class="title">Editing context '{{ context.name }}' in class {{ auth.class_name }}</h1>
<form class="wide-labels" action="{{ url_for(".update_context", ctx_id=context.id) }}" method="post">
{% else %}
<h1 class="title">Create context in class {{ auth['class_name'] }}</h1>
<h1 class="title">Create context in class {{ auth.class_name }}</h1>
<form class="wide-labels" action="{{ url_for(".create_context") }}" method="post">
{% endif %}

Expand Down
4 changes: 2 additions & 2 deletions src/codehelp/templates/help_form.html
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
{# debounce on the submit handler so that the form's actual submit fires *before* the form elements are disabled #}
<form class="wide-labels" action="{{url_for('helper.help_request')}}" method="post" x-data="{loading: false}" x-on:pageshow.window="loading = false" x-on:submit.debounce.10ms="loading = true">

{% if auth['class_name'] %}
{% if auth.class_name %}
<div class="field is-horizontal">
<div class="field-label">
<label class="label">Class:</label>
</div>
<div class="field-body">
{{ auth['class_name'] }}
{{ auth.class_name }}
</div>
</div>
{% elif llm.tokens_remaining != None %}
Expand Down
6 changes: 3 additions & 3 deletions src/codehelp/templates/help_view.html
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
{% if query %}

<div class="container">
{% if auth['user_id'] != query.user_id %}
{% if auth.user_id != query.user_id %}
<div class="field is-horizontal">
<div class="field-label">
<label class="label">User:</label>
Expand Down Expand Up @@ -124,7 +124,7 @@ <h1><span class="title is-size-4">Response</span> <span class="subtitle ml-5 is-
</div>
</div>

{% if auth['user_id'] == query.user_id and 'main' in responses %}
{% if auth.user_id == query.user_id and 'main' in responses %}
<div class="card-content p-2 pl-5" style="background: #e8e8e8;" x-data="{helpful: {{"null" if query.helpful == None else query.helpful}}}">
<script type="text/javascript">
function post_helpful(value) {
Expand Down Expand Up @@ -163,7 +163,7 @@ <h1><span class="title is-size-4">Response</span> <span class="subtitle ml-5 is-
</div>
{% endif %}

{% if auth['is_tester'] and 'main' in responses %}
{% if auth.is_tester and 'main' in responses %}
<div class="card-content content p-2 pl-5">
<h2 class="is-size-5">Related Topics</h2>
{% if topics %}
Expand Down
4 changes: 2 additions & 2 deletions src/codehelp/templates/landing.html
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ <h2 class="has-text-white">Ask it...</h2>
</div>
</div>
<p class="has-text-centered mt-5 mb-3">
{% if auth['user_id'] %}
{% if auth.user_id %}
<a class="button is-link is-light is-rounded is-size-4" href="{{ url_for('helper.help_form') }}">
Try it now!
</a>
Expand Down Expand Up @@ -66,7 +66,7 @@ <h2>For Instructors</h2>
<li>Everyone will sign in <b>automatically</b> (no separate login) via a link from your course page.</li>
<li>Takes some time to set up, and may require support from your LMS administrator.</li>
</ul>
{% if auth['user_id'] and auth['auth_provider'] != 'demo' %}
{% if auth.user_id and auth.auth_provider != 'demo' %}
<a class="button button-inline is-link is-size-5" href="{{ url_for("profile.main") }}">Go to your Profile page</a> to manually create a class.
{% else %}
<a class="button button-inline is-link is-size-5" href="{{ url_for("auth.login", next=url_for("profile.main")) }}">Sign in using Google, GitHub, or Microsoft</a> and manually create a class from your profile page.
Expand Down
2 changes: 1 addition & 1 deletion src/codehelp/templates/tutor_nav_item.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
SPDX-License-Identifier: AGPL-3.0-only
#}

{% if auth['user_id'] and "chats_experiment" in auth['class_experiments'] %}
{% if auth.user_id and "chats_experiment" in auth.class_experiments %}
<a class="navbar-item has-text-success" href="{{ url_for('tutor.tutor_form') }}">
<div class="icon-text">
<span class="icon">
Expand Down
12 changes: 6 additions & 6 deletions src/codehelp/tutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def chat_interface(chat_id: int) -> str | Response:
def create_chat(topic: str, context: ContextConfig | None) -> int:
db = get_db()
auth = get_auth()
user_id = auth['user_id']
role_id = auth['role_id']
user_id = auth.user_id
role_id = auth.role_id

if context is not None:
context_name = context.name
Expand All @@ -172,7 +172,7 @@ def get_chat_history(limit: int = 10) -> list[Row]:
db = get_db()
auth = get_auth()

history = db.execute("SELECT * FROM chats WHERE user_id=? ORDER BY id DESC LIMIT ?", [auth['user_id'], limit]).fetchall()
history = db.execute("SELECT * FROM chats WHERE user_id=? ORDER BY id DESC LIMIT ?", [auth.user_id, limit]).fetchall()
return history


Expand All @@ -194,9 +194,9 @@ def get_chat(chat_id: int) -> tuple[list[ChatCompletionMessageParam], str, str,
raise ChatNotFoundError

access_allowed = \
(auth['user_id'] == chat_row['user_id']) \
or auth['is_admin'] \
or (auth['role'] == 'instructor' and auth['class_id'] == chat_row['class_id'])
(auth.user_id == chat_row['user_id']) \
or auth.is_admin \
or (auth.role == 'instructor' and auth.class_id == chat_row['class_id'])

if not access_allowed:
raise AccessDeniedError
Expand Down
Loading

0 comments on commit 101650b

Please sign in to comment.