From 101650bac193360a5452ca7eeea05146b634d59c Mon Sep 17 00:00:00 2001 From: Mark Liffiton Date: Thu, 28 Nov 2024 00:23:53 -0600 Subject: [PATCH] Refactor auth dicts to dataclasses for better type checking and safety. --- src/codehelp/context.py | 8 +- src/codehelp/context_config.py | 20 +-- src/codehelp/helper.py | 10 +- src/codehelp/templates/context_config.html | 4 +- src/codehelp/templates/context_edit_form.html | 4 +- src/codehelp/templates/help_form.html | 4 +- src/codehelp/templates/help_view.html | 6 +- src/codehelp/templates/landing.html | 4 +- src/codehelp/templates/tutor_nav_item.html | 2 +- src/codehelp/tutor.py | 12 +- src/gened/auth.py | 164 +++++++++--------- src/gened/class_config.py | 2 +- src/gened/classes.py | 8 +- src/gened/demo.py | 2 +- src/gened/experiments.py | 2 +- src/gened/instructor.py | 20 +-- src/gened/openai.py | 8 +- src/gened/profile.py | 4 +- src/gened/queries.py | 10 +- src/gened/templates/base.html | 24 +-- src/gened/templates/instructor_base.html | 2 +- .../templates/instructor_class_config.html | 4 +- src/gened/templates/profile_view.html | 8 +- src/starburst/helper.py | 6 +- src/starburst/templates/help_view.html | 4 +- src/starburst/templates/landing.html | 2 +- tests/test_auth.py | 41 +++-- 27 files changed, 193 insertions(+), 192 deletions(-) diff --git a/src/codehelp/context.py b/src/codehelp/context.py index 6a01410..e23db66 100644 --- a/src/codehelp/context.py +++ b/src/codehelp/context.py @@ -115,7 +115,7 @@ def get_available_contexts() -> list[ContextConfig]: db = get_db() auth = get_auth() - class_id = auth['class_id'] + class_id = auth.class_id # Only return contexts that are available: # current date anywhere on earth (using UTC+12) is at or after the saved date context_rows = db.execute("SELECT * FROM contexts WHERE class_id=? AND available <= date('now', '+12 hours') ORDER BY class_order ASC", [class_id]).fetchall() @@ -130,12 +130,12 @@ def get_context_string_by_id(ctx_id: int) -> str | None: db = get_db() auth = get_auth() - if auth['is_admin']: + if auth.is_admin: # admin can grab any context context_row = db.execute("SELECT * FROM context_strings WHERE id=?", [ctx_id]).fetchone() else: # for non-admin users, double-check that the context is in the current class - class_id = auth['class_id'] + class_id = auth.class_id context_row = db.execute("SELECT * FROM context_strings WHERE class_id=? AND id=?", [class_id, ctx_id]).fetchone() if not context_row: @@ -151,7 +151,7 @@ def get_context_by_name(ctx_name: str) -> ContextConfig | None: db = get_db() auth = get_auth() - class_id = auth['class_id'] + class_id = auth.class_id context_row = db.execute("SELECT * FROM contexts WHERE class_id=? AND name=?", [class_id, ctx_name]).fetchone() diff --git a/src/codehelp/context_config.py b/src/codehelp/context_config.py index 71d293b..017ef68 100644 --- a/src/codehelp/context_config.py +++ b/src/codehelp/context_config.py @@ -51,7 +51,7 @@ def register(app: Flask) -> None: def config_section_render() -> Markup: db = get_db() auth = get_auth() - class_id = auth['class_id'] + class_id = auth.class_id contexts = db.execute(""" SELECT id, name, CAST(available AS TEXT) AS available @@ -84,7 +84,7 @@ def decorated_function(*args: P.args, **kwargs: P.kwargs) -> Response | R: auth = get_auth() # verify the given context is in the user's current class - class_id = auth['class_id'] + class_id = auth.class_id ctx_id = kwargs['ctx_id'] context_row = db.execute("SELECT * FROM contexts WHERE id=?", [ctx_id]).fetchone() if context_row['class_id'] != class_id: @@ -160,10 +160,10 @@ def _insert_context(class_id: int, name: str, config: str, available: str) -> in @bp.route("/create", methods=["POST"]) def create_context() -> Response: auth = get_auth() - assert auth['class_id'] + assert auth.class_id context = ContextConfig.from_request_form(request.form) - _insert_context(auth['class_id'], context.name, context.to_json(), "9999-12-31") # defaults to hidden + _insert_context(auth.class_id, context.name, context.to_json(), "9999-12-31") # defaults to hidden return redirect(url_for("class_config.config_form")) @@ -172,11 +172,11 @@ def create_context() -> Response: @check_valid_context def copy_context(ctx_row: Row, ctx_id: int) -> Response: auth = get_auth() - assert auth['class_id'] + assert auth.class_id # passing existing name, but _insert_context will take care of finding # a new, unused name in the class. - _insert_context(auth['class_id'], ctx_row['name'], ctx_row['config'], ctx_row['available']) + _insert_context(auth.class_id, ctx_row['name'], ctx_row['config'], ctx_row['available']) return redirect(url_for("class_config.config_form")) @@ -189,8 +189,8 @@ def update_context(ctx_id: int, ctx_row: Row) -> Response: # names must be unique within a class: check/look for an unused name auth = get_auth() - assert auth['class_id'] - name = _make_unique_context_name(auth['class_id'], context.name, ctx_id) + assert auth.class_id + name = _make_unique_context_name(auth.class_id, context.name, ctx_id) db.execute("UPDATE contexts SET name=?, config=? WHERE id=?", [name, context.to_json(), ctx_id]) db.commit() @@ -217,7 +217,7 @@ def update_order() -> str: db = get_db() auth = get_auth() - class_id = auth['class_id'] # Get the current class to ensure we don't change another class. + class_id = auth.class_id # Get the current class to ensure we don't change another class. ordered_ids = request.json assert isinstance(ordered_ids, list) @@ -235,7 +235,7 @@ def update_available() -> str: db = get_db() auth = get_auth() - class_id = auth['class_id'] # Get the current class to ensure we don't change another class. + class_id = auth.class_id # Get the current class to ensure we don't change another class. data = request.json assert isinstance(data, dict) diff --git a/src/codehelp/helper.py b/src/codehelp/helper.py index 3c1aa08..7218fae 100644 --- a/src/codehelp/helper.py +++ b/src/codehelp/helper.py @@ -81,7 +81,7 @@ def help_form(llm: LLMConfig, query_id: int | None = None, class_id: int | None else: # no query specified, # but we can pre-select the most recently used context, if available - recent_row = db.execute("SELECT context_name FROM queries WHERE queries.user_id=? ORDER BY id DESC LIMIT 1", [auth['user_id']]).fetchone() + recent_row = db.execute("SELECT context_name FROM queries WHERE queries.user_id=? ORDER BY id DESC LIMIT 1", [auth.user_id]).fetchone() if recent_row: selected_context_name = recent_row['context_name'] @@ -192,7 +192,7 @@ def run_query(llm: LLMConfig, context: ContextConfig | None, code: str, error: s def record_query(context: ContextConfig | None, code: str, error: str, issue: str) -> int: db = get_db() auth = get_auth() - role_id = auth['role_id'] + role_id = auth.role_id if context is not None: context_name = context.name @@ -204,7 +204,7 @@ def record_query(context: ContextConfig | None, code: str, error: str, issue: st cur = db.execute( "INSERT INTO queries (context_name, context_string_id, code, error, issue, user_id, role_id) VALUES (?, ?, ?, ?, ?, ?, ?)", - [context_name, context_string_id, code, error, issue, auth['user_id'], role_id] + [context_name, context_string_id, code, error, issue, auth.user_id, role_id] ) new_row_id = cur.lastrowid db.commit() @@ -253,7 +253,7 @@ def help_request(llm: LLMConfig) -> Response: def load_test(llm: LLMConfig) -> Response: # Require that we're logged in as the load_test admin user auth = get_auth() - if auth['display_name'] != 'load_test': + if auth.display_name != 'load_test': return abort(403) context = ContextConfig(name="__LOADTEST_Context") @@ -279,7 +279,7 @@ def post_helpful() -> str: query_id = int(request.form['id']) value = int(request.form['value']) - db.execute("UPDATE queries SET helpful=? WHERE id=? AND user_id=?", [value, query_id, auth['user_id']]) + db.execute("UPDATE queries SET helpful=? WHERE id=? AND user_id=?", [value, query_id, auth.user_id]) db.commit() return "" diff --git a/src/codehelp/templates/context_config.html b/src/codehelp/templates/context_config.html index dce1673..6985001 100644 --- a/src/codehelp/templates/context_config.html +++ b/src/codehelp/templates/context_config.html @@ -172,7 +172,7 @@

Schedule ''

{% macro link_display_alpinejs(route) -%} {# unholy combination of Jinja templating (route available in python) and Javascript string replacement (context available in JS)... #} { - link_URL: '{{ url_for(route, class_id=auth['class_id'], ctx_name='__replace__', _external=True) }}'.replace('__replace__', encodeURIComponent(ctx.name)), + link_URL: '{{ url_for(route, class_id=auth.class_id, ctx_name='__replace__', _external=True) }}'.replace('__replace__', encodeURIComponent(ctx.name)), copied: false, copy_url() { navigator.clipboard.writeText(this.link_URL); @@ -207,7 +207,7 @@

- {% if "chats_experiment" in auth['class_experiments'] %} + {% if "chats_experiment" in auth.class_experiments %}
Tutor Chat:
diff --git a/src/codehelp/templates/context_edit_form.html b/src/codehelp/templates/context_edit_form.html index 15fffcf..d41eed6 100644 --- a/src/codehelp/templates/context_edit_form.html +++ b/src/codehelp/templates/context_edit_form.html @@ -11,10 +11,10 @@
{% if context %} {# We're editing an existing context. #} -

Editing context '{{ context.name }}' in class {{ auth['class_name'] }}

+

Editing context '{{ context.name }}' in class {{ auth.class_name }}

{% else %} -

Create context in class {{ auth['class_name'] }}

+

Create context in class {{ auth.class_name }}

{% endif %} diff --git a/src/codehelp/templates/help_form.html b/src/codehelp/templates/help_form.html index af76d1c..1e62a0c 100644 --- a/src/codehelp/templates/help_form.html +++ b/src/codehelp/templates/help_form.html @@ -16,13 +16,13 @@ {# debounce on the submit handler so that the form's actual submit fires *before* the form elements are disabled #} - {% if auth['class_name'] %} + {% if auth.class_name %}
- {{ auth['class_name'] }} + {{ auth.class_name }}
{% elif llm.tokens_remaining != None %} diff --git a/src/codehelp/templates/help_view.html b/src/codehelp/templates/help_view.html index 307ea3f..d703120 100644 --- a/src/codehelp/templates/help_view.html +++ b/src/codehelp/templates/help_view.html @@ -21,7 +21,7 @@ {% if query %}
- {% if auth['user_id'] != query.user_id %} + {% if auth.user_id != query.user_id %}
@@ -124,7 +124,7 @@

Response