From 8697311bc82d24752d5c273ecffec808a55f5c71 Mon Sep 17 00:00:00 2001 From: rongzhang Date: Wed, 3 Jul 2024 23:25:56 +0000 Subject: [PATCH] use execution metadata to store access token flag --- querybook/server/app/auth/utils.py | 9 ++------ .../server/datasources/query_execution.py | 21 +++++++++++-------- .../lib/query_executor/base_executor.py | 12 +++++++++-- .../lib/query_executor/executor_factory.py | 8 +------ querybook/server/tasks/run_query.py | 10 ++------- 5 files changed, 27 insertions(+), 33 deletions(-) diff --git a/querybook/server/app/auth/utils.py b/querybook/server/app/auth/utils.py index dbe778438..1a204c58c 100644 --- a/querybook/server/app/auth/utils.py +++ b/querybook/server/app/auth/utils.py @@ -19,13 +19,8 @@ class AuthenticationError(Exception): class AuthUser(UserMixin): - def __init__(self, user: User, api_access_token=False): + def __init__(self, user: User): self._user_dict = user.to_dict(with_roles=True) - self._api_access_token = api_access_token - - @property - def api_access_token(self): - return self._api_access_token @property def id(self): @@ -77,7 +72,7 @@ def load_user_with_api_access_token(request): if token_validation: if token_validation.enabled: user = get_user_by_id(token_validation.creator_uid, session=session) - return AuthUser(user, api_access_token=True) + return AuthUser(user) else: flask.abort( UNAUTHORIZED_STATUS_CODE, description="Token is disabled." diff --git a/querybook/server/datasources/query_execution.py b/querybook/server/datasources/query_execution.py index c6d800e31..57571d60a 100644 --- a/querybook/server/datasources/query_execution.py +++ b/querybook/server/datasources/query_execution.py @@ -1,6 +1,6 @@ from datetime import datetime -from flask import abort, Response, redirect +from flask import abort, Response, redirect, request from flask_login import current_user from app.flask_app import socketio @@ -72,10 +72,14 @@ def create_query_execution( query=query, engine_id=engine_id, uid=uid, session=session ) - if metadata: - logic.create_query_execution_metadata( - query_execution.id, metadata, session=session - ) + api_access_token = ( + True if request.headers.get("api-access-token", None) else False + ) + metadata = metadata or {} + metadata["api_access_token"] = api_access_token + logic.create_query_execution_metadata( + query_execution.id, metadata, session=session + ) data_doc = None if data_cell_id: @@ -87,10 +91,9 @@ def create_query_execution( try: run_query_task.apply_async( - kwargs={ - "query_execution_id": query_execution.id, - "api_access_token": current_user.api_access_token, - }, + args=[ + query_execution.id, + ] ) query_execution_dict = query_execution.to_dict() diff --git a/querybook/server/lib/query_executor/base_executor.py b/querybook/server/lib/query_executor/base_executor.py index a6db1fc72..805815e2c 100644 --- a/querybook/server/lib/query_executor/base_executor.py +++ b/querybook/server/lib/query_executor/base_executor.py @@ -502,12 +502,10 @@ def __init__( statement_ranges, client_setting, execution_type, - api_access_token=False, ): self._query = query self._query_execution_id = query_execution_id self._execution_type = execution_type - self._api_access_token = api_access_token if self.SINGLE_QUERY_QUERY_ENGINE(): self._statement_ranges = [[0, len(query)]] @@ -530,6 +528,16 @@ def __init__( self._client = None self._cursor = None + with DBSession() as session: + query_execution_metadata = ( + qe_logic.get_query_execution_metadata_by_execution_id( + self._query_execution_id, session=session + ).execution_metadata + ) + self._api_access_token = query_execution_metadata.get( + "api_access_token", False + ) + def __del__(self): del self._logger del self._cursor diff --git a/querybook/server/lib/query_executor/executor_factory.py b/querybook/server/lib/query_executor/executor_factory.py index 74ce527f0..39d277ec9 100644 --- a/querybook/server/lib/query_executor/executor_factory.py +++ b/querybook/server/lib/query_executor/executor_factory.py @@ -18,11 +18,7 @@ @with_session def create_executor_from_execution( - query_execution_id, - celery_task, - execution_type, - api_access_token=False, - session=None, + query_execution_id, celery_task, execution_type, session=None ): executor_params, engine = _get_executor_params_and_engine( query_execution_id, @@ -30,8 +26,6 @@ def create_executor_from_execution( execution_type=execution_type, session=session, ) - executor_params["api_access_token"] = api_access_token - executor = get_executor_class(engine.language, engine.executor)(**executor_params) return executor diff --git a/querybook/server/tasks/run_query.py b/querybook/server/tasks/run_query.py index da7094403..20f379169 100644 --- a/querybook/server/tasks/run_query.py +++ b/querybook/server/tasks/run_query.py @@ -29,10 +29,7 @@ acks_late=True, ) def run_query_task( - self, - query_execution_id, - execution_type=QueryExecutionType.ADHOC.value, - api_access_token=False, + self, query_execution_id, execution_type=QueryExecutionType.ADHOC.value ): stats_logger.incr(QUERY_EXECUTIONS, tags={"execution_type": execution_type}) @@ -42,10 +39,7 @@ def run_query_task( try: executor = create_executor_from_execution( - query_execution_id, - celery_task=self, - execution_type=execution_type, - api_access_token=api_access_token, + query_execution_id, celery_task=self, execution_type=execution_type ) run_executor_until_finish(self, executor) except SoftTimeLimitExceeded: