Skip to content

Commit

Permalink
use execution metadata to store access token flag
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangvi7 committed Jul 3, 2024
1 parent 90e60b2 commit 8697311
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 33 deletions.
9 changes: 2 additions & 7 deletions querybook/server/app/auth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,8 @@ class AuthenticationError(Exception):


class AuthUser(UserMixin):
def __init__(self, user: User, api_access_token=False):
def __init__(self, user: User):
self._user_dict = user.to_dict(with_roles=True)
self._api_access_token = api_access_token

@property
def api_access_token(self):
return self._api_access_token

@property
def id(self):
Expand Down Expand Up @@ -77,7 +72,7 @@ def load_user_with_api_access_token(request):
if token_validation:
if token_validation.enabled:
user = get_user_by_id(token_validation.creator_uid, session=session)
return AuthUser(user, api_access_token=True)
return AuthUser(user)
else:
flask.abort(
UNAUTHORIZED_STATUS_CODE, description="Token is disabled."
Expand Down
21 changes: 12 additions & 9 deletions querybook/server/datasources/query_execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime

from flask import abort, Response, redirect
from flask import abort, Response, redirect, request
from flask_login import current_user

from app.flask_app import socketio
Expand Down Expand Up @@ -72,10 +72,14 @@ def create_query_execution(
query=query, engine_id=engine_id, uid=uid, session=session
)

if metadata:
logic.create_query_execution_metadata(
query_execution.id, metadata, session=session
)
api_access_token = (
True if request.headers.get("api-access-token", None) else False
)
metadata = metadata or {}
metadata["api_access_token"] = api_access_token
logic.create_query_execution_metadata(
query_execution.id, metadata, session=session
)

data_doc = None
if data_cell_id:
Expand All @@ -87,10 +91,9 @@ def create_query_execution(

try:
run_query_task.apply_async(
kwargs={
"query_execution_id": query_execution.id,
"api_access_token": current_user.api_access_token,
},
args=[
query_execution.id,
]
)
query_execution_dict = query_execution.to_dict()

Expand Down
12 changes: 10 additions & 2 deletions querybook/server/lib/query_executor/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,10 @@ def __init__(
statement_ranges,
client_setting,
execution_type,
api_access_token=False,
):
self._query = query
self._query_execution_id = query_execution_id
self._execution_type = execution_type
self._api_access_token = api_access_token

if self.SINGLE_QUERY_QUERY_ENGINE():
self._statement_ranges = [[0, len(query)]]
Expand All @@ -530,6 +528,16 @@ def __init__(
self._client = None
self._cursor = None

with DBSession() as session:
query_execution_metadata = (
qe_logic.get_query_execution_metadata_by_execution_id(
self._query_execution_id, session=session
).execution_metadata
)
self._api_access_token = query_execution_metadata.get(
"api_access_token", False
)

def __del__(self):
del self._logger
del self._cursor
Expand Down
8 changes: 1 addition & 7 deletions querybook/server/lib/query_executor/executor_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,14 @@

@with_session
def create_executor_from_execution(
query_execution_id,
celery_task,
execution_type,
api_access_token=False,
session=None,
query_execution_id, celery_task, execution_type, session=None
):
executor_params, engine = _get_executor_params_and_engine(
query_execution_id,
celery_task=celery_task,
execution_type=execution_type,
session=session,
)
executor_params["api_access_token"] = api_access_token

executor = get_executor_class(engine.language, engine.executor)(**executor_params)
return executor

Expand Down
10 changes: 2 additions & 8 deletions querybook/server/tasks/run_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@
acks_late=True,
)
def run_query_task(
self,
query_execution_id,
execution_type=QueryExecutionType.ADHOC.value,
api_access_token=False,
self, query_execution_id, execution_type=QueryExecutionType.ADHOC.value
):
stats_logger.incr(QUERY_EXECUTIONS, tags={"execution_type": execution_type})

Expand All @@ -42,10 +39,7 @@ def run_query_task(

try:
executor = create_executor_from_execution(
query_execution_id,
celery_task=self,
execution_type=execution_type,
api_access_token=api_access_token,
query_execution_id, celery_task=self, execution_type=execution_type
)
run_executor_until_finish(self, executor)
except SoftTimeLimitExceeded:
Expand Down

0 comments on commit 8697311

Please sign in to comment.