Skip to content

Commit

Permalink
add user_id and team_id as log facets (#321)
Browse files Browse the repository at this point in the history
* add user_id and team_id as log facets, refactor a little

* fix lint, remove draft comments
  • Loading branch information
song-william authored Oct 13, 2023
1 parent d30a1a5 commit 4367b83
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 23 deletions.
8 changes: 4 additions & 4 deletions model-engine/model_engine_server/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from model_engine_server.api.tasks_v1 import inference_task_router_v1
from model_engine_server.api.triggers_v1 import trigger_router_v1
from model_engine_server.core.loggers import (
LoggerTagKey,
LoggerTagManager,
filename_wo_ext,
get_request_id,
make_logger,
set_request_id,
)

logger = make_logger(filename_wo_ext(__name__))
Expand All @@ -47,11 +47,11 @@
@app.middleware("http")
async def dispatch(request: Request, call_next):
try:
set_request_id(str(uuid.uuid4()))
LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4()))
return await call_next(request)
except Exception as e:
tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
request_id = get_request_id()
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
structured_log = {
"error": str(e),
Expand Down
11 changes: 10 additions & 1 deletion model-engine/model_engine_server/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from model_engine_server.core.auth.fake_authentication_repository import (
FakeAuthenticationRepository,
)
from model_engine_server.core.loggers import filename_wo_ext, make_logger
from model_engine_server.core.loggers import (
LoggerTagKey,
LoggerTagManager,
filename_wo_ext,
make_logger,
)
from model_engine_server.db.base import SessionAsync, SessionReadOnlyAsync
from model_engine_server.domain.gateways import (
CronJobGateway,
Expand Down Expand Up @@ -330,6 +335,10 @@ async def verify_authentication(
headers={"WWW-Authenticate": "Basic"},
)

# set logger context with identity data
LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id)
LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id)

return auth


Expand Down
11 changes: 8 additions & 3 deletions model-engine/model_engine_server/api/llms_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@
)
from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
from model_engine_server.core.auth.authentication_repository import User
from model_engine_server.core.loggers import filename_wo_ext, get_request_id, make_logger
from model_engine_server.core.loggers import (
LoggerTagKey,
LoggerTagManager,
filename_wo_ext,
make_logger,
)
from model_engine_server.domain.exceptions import (
EndpointDeleteFailedException,
EndpointLabelsException,
Expand Down Expand Up @@ -82,7 +87,7 @@ def handle_streaming_exception(
message: str,
):
tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
request_id = get_request_id()
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
structured_log = {
"error": message,
Expand Down Expand Up @@ -223,7 +228,7 @@ async def create_completion_sync_task(
user=auth, model_endpoint_name=model_endpoint_name, request=request
)
except UpstreamServiceError:
request_id = get_request_id()
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
logger.exception(f"Upstream service error for request {request_id}")
raise HTTPException(
status_code=500,
Expand Down
48 changes: 33 additions & 15 deletions model-engine/model_engine_server/core/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import sys
import warnings
from contextlib import contextmanager
from typing import Optional, Sequence
from enum import Enum
from typing import Dict, Optional, Sequence

import ddtrace
import json_log_formatter
Expand All @@ -16,8 +17,6 @@
LOG_FORMAT: str = "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] - %(message)s"
# REQUIRED FOR DATADOG COMPATIBILITY

ctx_var_request_id = contextvars.ContextVar("ctx_var_request_id", default=None)

__all__: Sequence[str] = (
# most common imports
"make_logger",
Expand All @@ -35,19 +34,37 @@
"loggers_at_level",
# utils
"filename_wo_ext",
"get_request_id",
"set_request_id",
"LoggerTagKey",
"LoggerTagManager",
)


def get_request_id() -> Optional[str]:
"""Get the request id from the context variable."""
return ctx_var_request_id.get()
class LoggerTagKey(str, Enum):
REQUEST_ID = "request_id"
TEAM_ID = "team_id"
USER_ID = "user_id"


class LoggerTagManager:
_context_vars: Dict[LoggerTagKey, contextvars.ContextVar] = {}

@classmethod
def get(cls, key: LoggerTagKey) -> Optional[str]:
"""Get the value from the context variable."""
ctx_var = cls._context_vars.get(key)
if ctx_var is not None:
return ctx_var.get()
return None

def set_request_id(request_id: str) -> None:
"""Set the request id in the context variable."""
ctx_var_request_id.set(request_id) # type: ignore
@classmethod
def set(cls, key: LoggerTagKey, value: Optional[str]) -> None:
"""Set the value in the context variable."""
if value is not None:
ctx_var = cls._context_vars.get(key)
if ctx_var is None:
ctx_var = contextvars.ContextVar(f"ctx_var_{key.name.lower()}", default=None)
cls._context_vars[key] = ctx_var
ctx_var.set(value)


def make_standard_logger(name: str, log_level: int = logging.INFO) -> logging.Logger:
Expand Down Expand Up @@ -77,10 +94,11 @@ def json_record(self, message: str, extra: dict, record: logging.LogRecord) -> d
extra["lineno"] = record.lineno
extra["pathname"] = record.pathname

# add the http request id if it exists
request_id = ctx_var_request_id.get()
if request_id:
extra["request_id"] = request_id
# add additional logger tags
for tag_key in LoggerTagKey:
tag_value = LoggerTagManager.get(tag_key)
if tag_value:
extra[tag_key.value] = tag_value

current_span = tracer.current_span()
extra["dd.trace_id"] = current_span.trace_id if current_span else 0
Expand Down

0 comments on commit 4367b83

Please sign in to comment.