diff --git a/src/codehelp/context.py b/src/codehelp/context.py index e23db66..5e118c3 100644 --- a/src/codehelp/context.py +++ b/src/codehelp/context.py @@ -115,7 +115,7 @@ def get_available_contexts() -> list[ContextConfig]: db = get_db() auth = get_auth() - class_id = auth.class_id + class_id = auth.cur_class.class_id if auth.cur_class else None # Only return contexts that are available: # current date anywhere on earth (using UTC+12) is at or after the saved date context_rows = db.execute("SELECT * FROM contexts WHERE class_id=? AND available <= date('now', '+12 hours') ORDER BY class_order ASC", [class_id]).fetchall() @@ -135,7 +135,7 @@ def get_context_string_by_id(ctx_id: int) -> str | None: context_row = db.execute("SELECT * FROM context_strings WHERE id=?", [ctx_id]).fetchone() else: # for non-admin users, double-check that the context is in the current class - class_id = auth.class_id + class_id = auth.cur_class.class_id if auth.cur_class else None context_row = db.execute("SELECT * FROM context_strings WHERE class_id=? AND id=?", [class_id, ctx_id]).fetchone() if not context_row: @@ -151,7 +151,7 @@ def get_context_by_name(ctx_name: str) -> ContextConfig | None: db = get_db() auth = get_auth() - class_id = auth.class_id + class_id = auth.cur_class.class_id if auth.cur_class else None context_row = db.execute("SELECT * FROM contexts WHERE class_id=? AND name=?", [class_id, ctx_name]).fetchone() diff --git a/src/codehelp/context_config.py b/src/codehelp/context_config.py index 017ef68..abb4df2 100644 --- a/src/codehelp/context_config.py +++ b/src/codehelp/context_config.py @@ -20,7 +20,7 @@ from markupsafe import Markup from werkzeug.wrappers.response import Response -from gened.auth import get_auth, instructor_required +from gened.auth import get_auth_class, instructor_required from gened.class_config import register_extra_section from gened.db import get_db @@ -50,8 +50,8 @@ def register(app: Flask) -> None: def config_section_render() -> Markup: db = get_db() - auth = get_auth() - class_id = auth.class_id + cur_class = get_auth_class() + class_id = cur_class.class_id contexts = db.execute(""" SELECT id, name, CAST(available AS TEXT) AS available @@ -81,10 +81,10 @@ def check_valid_context(f: Callable[P, R]) -> Callable[P, Response | R]: @wraps(f) def decorated_function(*args: P.args, **kwargs: P.kwargs) -> Response | R: db = get_db() - auth = get_auth() + cur_class = get_auth_class() # verify the given context is in the user's current class - class_id = auth.class_id + class_id = cur_class.class_id ctx_id = kwargs['ctx_id'] context_row = db.execute("SELECT * FROM contexts WHERE id=?", [ctx_id]).fetchone() if context_row['class_id'] != class_id: @@ -159,11 +159,10 @@ def _insert_context(class_id: int, name: str, config: str, available: str) -> in @bp.route("/create", methods=["POST"]) def create_context() -> Response: - auth = get_auth() - assert auth.class_id + cur_class = get_auth_class() context = ContextConfig.from_request_form(request.form) - _insert_context(auth.class_id, context.name, context.to_json(), "9999-12-31") # defaults to hidden + _insert_context(cur_class.class_id, context.name, context.to_json(), "9999-12-31") # defaults to hidden return redirect(url_for("class_config.config_form")) @@ -171,12 +170,11 @@ def create_context() -> Response: @bp.route("/copy/", methods=["POST"]) @check_valid_context def copy_context(ctx_row: Row, ctx_id: int) -> Response: - auth = get_auth() - assert auth.class_id + cur_class = get_auth_class() # passing existing name, but _insert_context will take care of finding # a new, unused name in the class. - _insert_context(auth.class_id, ctx_row['name'], ctx_row['config'], ctx_row['available']) + _insert_context(cur_class.class_id, ctx_row['name'], ctx_row['config'], ctx_row['available']) return redirect(url_for("class_config.config_form")) @@ -188,9 +186,8 @@ def update_context(ctx_id: int, ctx_row: Row) -> Response: context = ContextConfig.from_request_form(request.form) # names must be unique within a class: check/look for an unused name - auth = get_auth() - assert auth.class_id - name = _make_unique_context_name(auth.class_id, context.name, ctx_id) + cur_class = get_auth_class() + name = _make_unique_context_name(cur_class.class_id, context.name, ctx_id) db.execute("UPDATE contexts SET name=?, config=? WHERE id=?", [name, context.to_json(), ctx_id]) db.commit() @@ -215,9 +212,9 @@ def delete_context(ctx_id: int, ctx_row: Row) -> Response: @bp.route("/update_order", methods=["POST"]) def update_order() -> str: db = get_db() - auth = get_auth() + cur_class = get_auth_class() - class_id = auth.class_id # Get the current class to ensure we don't change another class. + class_id = cur_class.class_id # Get the current class to ensure we don't change another class. ordered_ids = request.json assert isinstance(ordered_ids, list) @@ -233,9 +230,9 @@ def update_order() -> str: @bp.route("/update_available", methods=["POST"]) def update_available() -> str: db = get_db() - auth = get_auth() + cur_class = get_auth_class() - class_id = auth.class_id # Get the current class to ensure we don't change another class. + class_id = cur_class.class_id # Get the current class to ensure we don't change another class. data = request.json assert isinstance(data, dict) diff --git a/src/codehelp/helper.py b/src/codehelp/helper.py index 7218fae..5b37ce9 100644 --- a/src/codehelp/helper.py +++ b/src/codehelp/helper.py @@ -192,7 +192,7 @@ def run_query(llm: LLMConfig, context: ContextConfig | None, code: str, error: s def record_query(context: ContextConfig | None, code: str, error: str, issue: str) -> int: db = get_db() auth = get_auth() - role_id = auth.role_id + role_id = auth.cur_class.role_id if auth.cur_class else None if context is not None: context_name = context.name @@ -253,7 +253,7 @@ def help_request(llm: LLMConfig) -> Response: def load_test(llm: LLMConfig) -> Response: # Require that we're logged in as the load_test admin user auth = get_auth() - if auth.display_name != 'load_test': + if auth.user is None or auth.user.display_name != 'load_test': return abort(403) context = ContextConfig(name="__LOADTEST_Context") diff --git a/src/codehelp/templates/context_config.html b/src/codehelp/templates/context_config.html index 6985001..94a7123 100644 --- a/src/codehelp/templates/context_config.html +++ b/src/codehelp/templates/context_config.html @@ -172,7 +172,7 @@

Schedule ''

{% macro link_display_alpinejs(route) -%} {# unholy combination of Jinja templating (route available in python) and Javascript string replacement (context available in JS)... #} { - link_URL: '{{ url_for(route, class_id=auth.class_id, ctx_name='__replace__', _external=True) }}'.replace('__replace__', encodeURIComponent(ctx.name)), + link_URL: '{{ url_for(route, class_id=auth.cur_class.class_id, ctx_name='__replace__', _external=True) }}'.replace('__replace__', encodeURIComponent(ctx.name)), copied: false, copy_url() { navigator.clipboard.writeText(this.link_URL); diff --git a/src/codehelp/templates/context_edit_form.html b/src/codehelp/templates/context_edit_form.html index d41eed6..74b9db6 100644 --- a/src/codehelp/templates/context_edit_form.html +++ b/src/codehelp/templates/context_edit_form.html @@ -11,10 +11,10 @@
{% if context %} {# We're editing an existing context. #} -

Editing context '{{ context.name }}' in class {{ auth.class_name }}

+

Editing context '{{ context.name }}' in class {{ auth.cur_class.class_name }}

{% else %} -

Create context in class {{ auth.class_name }}

+

Create context in class {{ auth.cur_class.class_name }}

{% endif %} diff --git a/src/codehelp/templates/help_form.html b/src/codehelp/templates/help_form.html index 1e62a0c..b4d5170 100644 --- a/src/codehelp/templates/help_form.html +++ b/src/codehelp/templates/help_form.html @@ -16,13 +16,13 @@ {# debounce on the submit handler so that the form's actual submit fires *before* the form elements are disabled #} - {% if auth.class_name %} + {% if auth.cur_class %}
- {{ auth.class_name }} + {{ auth.cur_class.class_name }}
{% elif llm.tokens_remaining != None %} diff --git a/src/codehelp/templates/landing.html b/src/codehelp/templates/landing.html index 24170f5..a3ce4a5 100644 --- a/src/codehelp/templates/landing.html +++ b/src/codehelp/templates/landing.html @@ -33,7 +33,7 @@

Ask it...

- {% if auth.user_id %} + {% if auth.user %} Try it now! @@ -66,7 +66,7 @@

For Instructors

  • Everyone will sign in automatically (no separate login) via a link from your course page.
  • Takes some time to set up, and may require support from your LMS administrator.
  • - {% if auth.user_id and auth.auth_provider != 'demo' %} + {% if auth.user and auth.user.auth_provider != 'demo' %} Go to your Profile page to manually create a class. {% else %} Sign in using Google, GitHub, or Microsoft and manually create a class from your profile page. diff --git a/src/codehelp/templates/tutor_nav_item.html b/src/codehelp/templates/tutor_nav_item.html index 743a315..bb23e10 100644 --- a/src/codehelp/templates/tutor_nav_item.html +++ b/src/codehelp/templates/tutor_nav_item.html @@ -4,7 +4,7 @@ SPDX-License-Identifier: AGPL-3.0-only #} -{% if auth.user_id and "chats_experiment" in auth.class_experiments %} +{% if auth.user and "chats_experiment" in auth.class_experiments %}
    diff --git a/src/codehelp/tutor.py b/src/codehelp/tutor.py index eb3e0bc..0029b27 100644 --- a/src/codehelp/tutor.py +++ b/src/codehelp/tutor.py @@ -146,7 +146,7 @@ def create_chat(topic: str, context: ContextConfig | None) -> int: db = get_db() auth = get_auth() user_id = auth.user_id - role_id = auth.role_id + role_id = auth.cur_class.role_id if auth.cur_class else None if context is not None: context_name = context.name @@ -196,7 +196,7 @@ def get_chat(chat_id: int) -> tuple[list[ChatCompletionMessageParam], str, str, access_allowed = \ (auth.user_id == chat_row['user_id']) \ or auth.is_admin \ - or (auth.role == 'instructor' and auth.class_id == chat_row['class_id']) + or (auth.cur_class and auth.cur_class.role == 'instructor' and auth.cur_class.class_id == chat_row['class_id']) if not access_allowed: raise AccessDeniedError diff --git a/src/gened/auth.py b/src/gened/auth.py index 0b97726..db9229a 100644 --- a/src/gened/auth.py +++ b/src/gened/auth.py @@ -32,26 +32,40 @@ ProviderType = Literal['local', 'lti', 'demo', 'google', 'github', 'microsoft'] RoleType = Literal['instructor', 'student'] +@dataclass(frozen=True) +class UserData: + id: int + display_name: str + auth_provider: ProviderType + is_admin: bool = False + is_tester: bool = False + @dataclass(frozen=True) class ClassData: class_id: int class_name: str + role_id: int role: RoleType @dataclass(frozen=True) class AuthData: - user_id: int | None = None - is_admin: bool = False - is_tester: bool = False - auth_provider: ProviderType | None = None - display_name: str | None = None - class_id: int | None = None # current class ID - class_name: str | None = None # current class name - role_id: int | None = None # current role - role: RoleType | None = None # current role name (e.g., 'instructor') + user: UserData | None = None + cur_class: ClassData | None = None class_experiments: list[str] = field(default_factory=list) # any experiments the current class is registered in other_classes: list[ClassData] = field(default_factory=list) # for storing active classes that are not the user's current class + @property + def user_id(self) -> int | None: + return self.user.id if self.user else None + + @property + def is_admin(self) -> bool: + return bool(self.user and self.user.is_admin) + + @property + def is_tester(self) -> bool: + return bool(self.user and self.user.is_tester) + def _invalidate_g_auth() -> None: """ Ensure no auth data is cached in the g object. @@ -92,7 +106,6 @@ def _get_auth_from_session() -> AuthData: # current user_id (if any). sess_auth = session.get(AUTH_SESSION_KEY, {}) user_id = sess_auth.get('user_id', None) - class_id = sess_auth.get('class_id', None) if not user_id: # No logged in user; return the default/empty auth data @@ -117,10 +130,13 @@ def _get_auth_from_session() -> AuthData: return AuthData() # Collect auth data values - display_name=user_row['display_name'] - is_admin=user_row['is_admin'] - is_tester=user_row['is_tester'] - auth_provider=user_row['auth_provider'] + user = UserData( + id=user_id, + display_name=user_row['display_name'], + auth_provider=user_row['auth_provider'], + is_admin=user_row['is_admin'], + is_tester=user_row['is_tester'], + ) # Check the database for any active roles (may be changed by another user) # and populate class/role information. @@ -138,54 +154,46 @@ def _get_auth_from_session() -> AuthData: ORDER BY roles.id DESC """, [user_id]).fetchall() - role_id = None - role = None + cur_class = None class_experiments = [] other_classes = [] - class_name = None + + sess_class_id = sess_auth.get('class_id', None) for row in role_rows: - if row['class_id'] == class_id: + class_data = ClassData( + class_id=row['class_id'], + class_name=row['name'], + role_id=row['role_id'], + role=row['role'], + ) + if row['class_id'] == sess_class_id: + assert cur_class is None # sanity check: should only ever match one role/class # capture class/role info - role_id = row['role_id'] - role = row['role'] + cur_class = class_data # check for any registered experiments in the current class - experiment_class_rows = db.execute("SELECT experiments.name FROM experiments JOIN experiment_class ON experiment_class.experiment_id=experiments.id WHERE experiment_class.class_id=?", [class_id]).fetchall() + experiment_class_rows = db.execute("SELECT experiments.name FROM experiments JOIN experiment_class ON experiment_class.experiment_id=experiments.id WHERE experiment_class.class_id=?", [sess_class_id]).fetchall() class_experiments = [row['name'] for row in experiment_class_rows] elif row['enabled']: # store a list of any other classes that are enabled (for navbar switching UI) - class_data = ClassData( - class_id=row['class_id'], - class_name=row['name'], - role=row['role'] - ) other_classes.append(class_data) - if not role_id and not is_admin: - # ensure we don't keep a class_id in auth if it's not a valid/active one - class_id = None - - if class_id is not None: - # get the class name (after all the above has shaken out) - class_row = db.execute("SELECT name FROM classes WHERE id=?", [class_id]).fetchone() - class_name = class_row['name'] - # admin gets instructor role in all classes automatically - if is_admin: - role = 'instructor' + # admin gets instructor role in all classes automatically + if user.is_admin and cur_class is None and sess_class_id is not None: + class_row = db.execute("SELECT name FROM classes WHERE id=?", [sess_class_id]).fetchone() + cur_class = ClassData( + class_id=sess_class_id, + class_name=class_row['name'], + role_id=-1, + role='instructor', + ) # return an AuthData with all collected values return AuthData( - user_id=user_id, - is_admin=is_admin, - is_tester=is_tester, - auth_provider=auth_provider, - display_name=display_name, - class_id=class_id, - class_name=class_name, - role_id=role_id, - role=role, + user=user, + cur_class=cur_class, class_experiments=class_experiments, - other_classes=other_classes + other_classes=other_classes, ) @@ -196,6 +204,12 @@ def get_auth() -> AuthData: return g.auth # type: ignore[no-any-return] +def get_auth_class() -> ClassData: + auth = get_auth() + assert auth.cur_class + return auth.cur_class + + def get_last_class(user_id: int) -> int | None: """ Find and return the last class (as a class ID) for the given user, as long as the user still has an active role in that class. @@ -318,7 +332,7 @@ def login_required(f: Callable[P, R]) -> Callable[P, Response | R]: @wraps(f) def decorated_function(*args: P.args, **kwargs: P.kwargs) -> Response | R: auth = get_auth() - if not auth.user_id: + if not auth.user: flash("Login required.", "warning") return redirect(url_for('auth.login', next=request.full_path)) return f(*args, **kwargs) @@ -329,7 +343,7 @@ def instructor_required(f: Callable[P, R]) -> Callable[P, Response | R]: @wraps(f) def decorated_function(*args: P.args, **kwargs: P.kwargs) -> Response | R: auth = get_auth() - if auth.role != "instructor": + if auth.cur_class is None or auth.cur_class.role != "instructor": flash("Instructor login required.", "warning") return redirect(url_for('auth.login', next=request.full_path)) return f(*args, **kwargs) @@ -340,13 +354,13 @@ def class_enabled_required(f: Callable[P, R]) -> Callable[P, str | R]: @wraps(f) def decorated_function(*args: P.args, **kwargs: P.kwargs) -> str | R: auth = get_auth() - class_id = auth.class_id - if class_id is None: + if auth.cur_class is None: # No active class, no problem return f(*args, **kwargs) # Otherwise, there's an active class, so we require it to be enabled. + class_id = auth.cur_class.class_id db = get_db() class_row = db.execute("SELECT * FROM classes WHERE id=?", [class_id]).fetchone() if not class_row['enabled']: diff --git a/src/gened/class_config.py b/src/gened/class_config.py index 90056ad..4077623 100644 --- a/src/gened/class_config.py +++ b/src/gened/class_config.py @@ -12,7 +12,7 @@ render_template, ) -from .auth import get_auth, instructor_required +from .auth import get_auth_class, instructor_required from .db import get_db from .openai import LLMConfig, get_completion, get_models, with_llm from .tz import date_is_past @@ -41,9 +41,9 @@ def register_extra_section(render_func: Callable[[], str]) -> None: @bp.route("/") def config_form() -> str: db = get_db() - auth = get_auth() - class_id = auth.class_id + cur_class = get_auth_class() + class_id = cur_class.class_id class_row = db.execute(""" SELECT classes.id, classes.enabled, classes_user.link_ident, classes_user.link_reg_expires, classes_user.openai_key, classes_user.model_id diff --git a/src/gened/experiments.py b/src/gened/experiments.py index 8260833..cde7036 100644 --- a/src/gened/experiments.py +++ b/src/gened/experiments.py @@ -23,15 +23,16 @@ # Functions for controlling access to experiments based on the current class -def current_class_in_experiment(experiment_name: str) -> bool: +def _current_class_in_experiment(experiment_name: str) -> bool: """ Return True if the current active class is registered in the specified experiment, False otherwise. """ db = get_db() experiment_class_rows = db.execute("SELECT experiment_class.class_id FROM experiments JOIN experiment_class ON experiment_class.experiment_id=experiments.id WHERE experiments.name=?", [experiment_name]).fetchall() experiment_class_ids = [row['class_id'] for row in experiment_class_rows] + auth = get_auth() - return auth.class_id in experiment_class_ids + return auth.cur_class is not None and auth.cur_class.class_id in experiment_class_ids # Decorator for routes designated as part of an experiment # For decorator type hints @@ -42,7 +43,7 @@ def experiment_required(experiment_name: str) -> Callable[[Callable[P, R]], Call def decorator(f: Callable[P, R]) -> Callable[P, Response | R]: @wraps(f) def decorated_function(*args: P.args, **kwargs: P.kwargs) -> Response | R: - if not current_class_in_experiment(experiment_name): + if not _current_class_in_experiment(experiment_name): return abort(404) else: return f(*args, **kwargs) diff --git a/src/gened/instructor.py b/src/gened/instructor.py index b5c3f8d..f4ae8c2 100644 --- a/src/gened/instructor.py +++ b/src/gened/instructor.py @@ -30,7 +30,7 @@ ) from werkzeug.wrappers.response import Response -from .auth import get_auth, instructor_required +from .auth import get_auth, get_auth_class, instructor_required from .classes import switch_class from .csv import csv_response from .db import get_db @@ -119,9 +119,8 @@ def get_users(class_id: int, for_export: bool = False) -> list[Row]: @bp.route("/") def main() -> str | Response: - auth = get_auth() - class_id = auth.class_id - assert class_id is not None + cur_class = get_auth_class() + class_id = cur_class.class_id users = get_users(class_id) @@ -142,11 +141,9 @@ def get_csv(kind: str) -> str | Response: if kind not in ('queries', 'users'): return abort(404) - auth = get_auth() - class_id = auth.class_id - class_name = auth.class_name - assert class_id is not None - assert class_name is not None + cur_class = get_auth_class() + class_id = cur_class.class_id + class_name = cur_class.class_name if kind == "queries": table = get_queries(class_id) @@ -159,10 +156,10 @@ def get_csv(kind: str) -> str | Response: @bp.route("/user_class/set", methods=["POST"]) def set_user_class_setting() -> Response: db = get_db() - auth = get_auth() # only trust class_id from auth, not from user - class_id = auth.class_id + cur_class = get_auth_class() + class_id = cur_class.class_id if 'clear_openai_key' in request.form: db.execute("UPDATE classes_user SET openai_key='' WHERE class_id=?", [class_id]) @@ -199,16 +196,16 @@ def set_user_class_setting() -> Response: @bp.route("/role/set_active//", methods=["POST"]) def set_role_active(role_id: int, bool_active: int) -> str: db = get_db() - auth = get_auth() + cur_class = get_auth_class() # prevent instructors from mistakenly making themselves not active and locking themselves out - if role_id == auth.role_id: + if role_id == cur_class.role_id: return "You cannot make yourself inactive." # class_id should be redundant w/ role_id, but without it, an instructor # could potentially deactivate a role in someone else's class. # only trust class_id from auth, not from user - class_id = auth.class_id + class_id = cur_class.class_id db.execute("UPDATE roles SET active=? WHERE id=? AND class_id=?", [bool_active, role_id, class_id]) db.commit() @@ -220,16 +217,16 @@ def set_role_active(role_id: int, bool_active: int) -> str: @bp.route("/role/set_instructor//", methods=["POST"]) def set_role_instructor(role_id: int, bool_instructor: int) -> str: db = get_db() - auth = get_auth() + cur_class = get_auth_class() # prevent instructors from mistakenly making themselves not instructors and locking themselves out - if role_id == auth.role_id: + if role_id == cur_class.role_id: return "You cannot change your own role." # class_id should be redundant w/ role_id, but without it, an instructor # could potentially deactivate a role in someone else's class. # only trust class_id from auth, not from user - class_id = auth.class_id + class_id = cur_class.class_id new_role = 'instructor' if bool_instructor else 'student' @@ -242,10 +239,9 @@ def set_role_instructor(role_id: int, bool_instructor: int) -> str: @bp.route("/class/delete", methods=["POST"]) def delete_class() -> Response: db = get_db() - auth = get_auth() - class_id = auth.class_id - assert class_id is not None - assert str(auth.class_id) == str(request.form.get('class_id')) + cur_class = get_auth_class() + class_id = cur_class.class_id + assert str(class_id) == str(request.form.get('class_id')) # Require explicit confirmation if request.form.get('confirm_delete') != 'DELETE': diff --git a/src/gened/openai.py b/src/gened/openai.py index 2310b29..d422a05 100644 --- a/src/gened/openai.py +++ b/src/gened/openai.py @@ -77,7 +77,7 @@ def make_system_client(tokens_remaining: int | None = None) -> LLMConfig: auth = get_auth() # Get class data, if there is an active class - if auth.class_id is not None: + if auth.cur_class is not None: class_row = db.execute(""" SELECT classes.enabled, @@ -94,7 +94,7 @@ def make_system_client(tokens_remaining: int | None = None) -> LLMConfig: LEFT JOIN models ON models.id = _model_id WHERE classes.id = ? - """, [auth.class_id]).fetchone() + """, [auth.cur_class.class_id]).fetchone() if not class_row['enabled']: raise ClassDisabledError diff --git a/src/gened/profile.py b/src/gened/profile.py index cdcedaa..782873f 100644 --- a/src/gened/profile.py +++ b/src/gened/profile.py @@ -28,7 +28,7 @@ def main() -> str: WHERE users.id=? """, [user_id]).fetchone() - class_id = auth.class_id or -1 # can't do a != to None/null, so convert that to -1 to match all classes in that case + cur_class_id = auth.cur_class.class_id if auth.cur_class else -1 # can't do a != to None/null in SQL, so convert that to -1 to match all classes in that case other_classes = db.execute(""" SELECT classes.id, @@ -41,7 +41,7 @@ def main() -> str: AND classes.id != ? AND classes.enabled=1 ORDER BY classes.id DESC - """, [user_id, class_id]).fetchall() + """, [user_id, cur_class_id]).fetchall() archived_classes = db.execute(""" SELECT @@ -55,6 +55,6 @@ def main() -> str: AND classes.id != ? AND classes.enabled=0 ORDER BY classes.id DESC - """, [user_id, class_id]).fetchall() + """, [user_id, cur_class_id]).fetchall() return render_template("profile_view.html", user=user, other_classes=other_classes, archived_classes=archived_classes) diff --git a/src/gened/queries.py b/src/gened/queries.py index 0204dfa..0a01cc2 100644 --- a/src/gened/queries.py +++ b/src/gened/queries.py @@ -19,8 +19,8 @@ def get_query(query_id: int) -> tuple[Row, dict[str, str]] | tuple[None, None]: if auth.is_admin: cur = db.execute("SELECT queries.*, users.display_name FROM queries JOIN users ON queries.user_id=users.id WHERE queries.id=?", [query_id]) - elif auth.role == 'instructor': - cur = db.execute("SELECT queries.*, users.display_name FROM queries JOIN users ON queries.user_id=users.id JOIN roles ON queries.role_id=roles.id WHERE (roles.class_id=? OR queries.user_id=?) AND queries.id=?", [auth.class_id, auth.user_id, query_id]) + elif auth.cur_class and auth.cur_class.role == 'instructor': + cur = db.execute("SELECT queries.*, users.display_name FROM queries JOIN users ON queries.user_id=users.id JOIN roles ON queries.role_id=roles.id WHERE (roles.class_id=? OR queries.user_id=?) AND queries.id=?", [auth.cur_class.class_id, auth.user_id, query_id]) else: cur = db.execute("SELECT queries.*, users.display_name FROM queries JOIN users ON queries.user_id=users.id WHERE queries.user_id=? AND queries.id=?", [auth.user_id, query_id]) query_row = cur.fetchone() diff --git a/src/gened/templates/base.html b/src/gened/templates/base.html index 59499a3..e32fb01 100644 --- a/src/gened/templates/base.html +++ b/src/gened/templates/base.html @@ -64,7 +64,7 @@