From 9dc92ae4bb4139fb9cdd88deef40fdb1e80be44f Mon Sep 17 00:00:00 2001 From: rongzhang Date: Wed, 3 Jul 2024 19:50:45 +0000 Subject: [PATCH] feat: pass api_access_token to executor client --- querybook/server/app/auth/utils.py | 9 +++++++-- querybook/server/datasources/query_execution.py | 7 ++++--- querybook/server/lib/query_executor/base_executor.py | 2 ++ .../server/lib/query_executor/executor_factory.py | 8 +++++++- querybook/server/tasks/run_query.py | 10 ++++++++-- 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/querybook/server/app/auth/utils.py b/querybook/server/app/auth/utils.py index 1a204c58c..dbe778438 100644 --- a/querybook/server/app/auth/utils.py +++ b/querybook/server/app/auth/utils.py @@ -19,8 +19,13 @@ class AuthenticationError(Exception): class AuthUser(UserMixin): - def __init__(self, user: User): + def __init__(self, user: User, api_access_token=False): self._user_dict = user.to_dict(with_roles=True) + self._api_access_token = api_access_token + + @property + def api_access_token(self): + return self._api_access_token @property def id(self): @@ -72,7 +77,7 @@ def load_user_with_api_access_token(request): if token_validation: if token_validation.enabled: user = get_user_by_id(token_validation.creator_uid, session=session) - return AuthUser(user) + return AuthUser(user, api_access_token=True) else: flask.abort( UNAUTHORIZED_STATUS_CODE, description="Token is disabled." diff --git a/querybook/server/datasources/query_execution.py b/querybook/server/datasources/query_execution.py index 1c0c72876..c6d800e31 100644 --- a/querybook/server/datasources/query_execution.py +++ b/querybook/server/datasources/query_execution.py @@ -87,9 +87,10 @@ def create_query_execution( try: run_query_task.apply_async( - args=[ - query_execution.id, - ] + kwargs={ + "query_execution_id": query_execution.id, + "api_access_token": current_user.api_access_token, + }, ) query_execution_dict = query_execution.to_dict() diff --git a/querybook/server/lib/query_executor/base_executor.py b/querybook/server/lib/query_executor/base_executor.py index 0bfb45ca4..a6db1fc72 100644 --- a/querybook/server/lib/query_executor/base_executor.py +++ b/querybook/server/lib/query_executor/base_executor.py @@ -502,10 +502,12 @@ def __init__( statement_ranges, client_setting, execution_type, + api_access_token=False, ): self._query = query self._query_execution_id = query_execution_id self._execution_type = execution_type + self._api_access_token = api_access_token if self.SINGLE_QUERY_QUERY_ENGINE(): self._statement_ranges = [[0, len(query)]] diff --git a/querybook/server/lib/query_executor/executor_factory.py b/querybook/server/lib/query_executor/executor_factory.py index 39d277ec9..74ce527f0 100644 --- a/querybook/server/lib/query_executor/executor_factory.py +++ b/querybook/server/lib/query_executor/executor_factory.py @@ -18,7 +18,11 @@ @with_session def create_executor_from_execution( - query_execution_id, celery_task, execution_type, session=None + query_execution_id, + celery_task, + execution_type, + api_access_token=False, + session=None, ): executor_params, engine = _get_executor_params_and_engine( query_execution_id, @@ -26,6 +30,8 @@ def create_executor_from_execution( execution_type=execution_type, session=session, ) + executor_params["api_access_token"] = api_access_token + executor = get_executor_class(engine.language, engine.executor)(**executor_params) return executor diff --git a/querybook/server/tasks/run_query.py b/querybook/server/tasks/run_query.py index 20f379169..da7094403 100644 --- a/querybook/server/tasks/run_query.py +++ b/querybook/server/tasks/run_query.py @@ -29,7 +29,10 @@ acks_late=True, ) def run_query_task( - self, query_execution_id, execution_type=QueryExecutionType.ADHOC.value + self, + query_execution_id, + execution_type=QueryExecutionType.ADHOC.value, + api_access_token=False, ): stats_logger.incr(QUERY_EXECUTIONS, tags={"execution_type": execution_type}) @@ -39,7 +42,10 @@ def run_query_task( try: executor = create_executor_from_execution( - query_execution_id, celery_task=self, execution_type=execution_type + query_execution_id, + celery_task=self, + execution_type=execution_type, + api_access_token=api_access_token, ) run_executor_until_finish(self, executor) except SoftTimeLimitExceeded: