From 4367b83254a178969016c41ac0d424582528c625 Mon Sep 17 00:00:00 2001 From: William Song Date: Thu, 12 Oct 2023 20:19:41 -0700 Subject: [PATCH] add user_id and team_id as log facets (#321) * add user_id and team_id as log facets, refactor a little * fix lint, remove draft comments --- model-engine/model_engine_server/api/app.py | 8 ++-- .../model_engine_server/api/dependencies.py | 11 ++++- .../model_engine_server/api/llms_v1.py | 11 +++-- .../model_engine_server/core/loggers.py | 48 +++++++++++++------ 4 files changed, 55 insertions(+), 23 deletions(-) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index f87fcf769..1593b951f 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -22,10 +22,10 @@ from model_engine_server.api.tasks_v1 import inference_task_router_v1 from model_engine_server.api.triggers_v1 import trigger_router_v1 from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, filename_wo_ext, - get_request_id, make_logger, - set_request_id, ) logger = make_logger(filename_wo_ext(__name__)) @@ -47,11 +47,11 @@ @app.middleware("http") async def dispatch(request: Request, call_next): try: - set_request_id(str(uuid.uuid4())) + LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) return await call_next(request) except Exception as e: tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) - request_id = get_request_id() + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") structured_log = { "error": str(e), diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index bdd158db0..89854841e 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -13,7 +13,12 @@ from model_engine_server.core.auth.fake_authentication_repository import ( FakeAuthenticationRepository, ) -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + filename_wo_ext, + make_logger, +) from model_engine_server.db.base import SessionAsync, SessionReadOnlyAsync from model_engine_server.domain.gateways import ( CronJobGateway, @@ -330,6 +335,10 @@ async def verify_authentication( headers={"WWW-Authenticate": "Basic"}, ) + # set logger context with identity data + LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id) + LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id) + return auth diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 92ddad0e7..4917ee32b 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -36,7 +36,12 @@ ) from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, get_request_id, make_logger +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + filename_wo_ext, + make_logger, +) from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, EndpointLabelsException, @@ -82,7 +87,7 @@ def handle_streaming_exception( message: str, ): tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) - request_id = get_request_id() + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") structured_log = { "error": message, @@ -223,7 +228,7 @@ async def create_completion_sync_task( user=auth, model_endpoint_name=model_endpoint_name, request=request ) except UpstreamServiceError: - request_id = get_request_id() + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) logger.exception(f"Upstream service error for request {request_id}") raise HTTPException( status_code=500, diff --git a/model-engine/model_engine_server/core/loggers.py b/model-engine/model_engine_server/core/loggers.py index 94c96998e..ce0ee847a 100644 --- a/model-engine/model_engine_server/core/loggers.py +++ b/model-engine/model_engine_server/core/loggers.py @@ -5,7 +5,8 @@ import sys import warnings from contextlib import contextmanager -from typing import Optional, Sequence +from enum import Enum +from typing import Dict, Optional, Sequence import ddtrace import json_log_formatter @@ -16,8 +17,6 @@ LOG_FORMAT: str = "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] - %(message)s" # REQUIRED FOR DATADOG COMPATIBILITY -ctx_var_request_id = contextvars.ContextVar("ctx_var_request_id", default=None) - __all__: Sequence[str] = ( # most common imports "make_logger", @@ -35,19 +34,37 @@ "loggers_at_level", # utils "filename_wo_ext", - "get_request_id", - "set_request_id", + "LoggerTagKey", + "LoggerTagManager", ) -def get_request_id() -> Optional[str]: - """Get the request id from the context variable.""" - return ctx_var_request_id.get() +class LoggerTagKey(str, Enum): + REQUEST_ID = "request_id" + TEAM_ID = "team_id" + USER_ID = "user_id" + + +class LoggerTagManager: + _context_vars: Dict[LoggerTagKey, contextvars.ContextVar] = {} + @classmethod + def get(cls, key: LoggerTagKey) -> Optional[str]: + """Get the value from the context variable.""" + ctx_var = cls._context_vars.get(key) + if ctx_var is not None: + return ctx_var.get() + return None -def set_request_id(request_id: str) -> None: - """Set the request id in the context variable.""" - ctx_var_request_id.set(request_id) # type: ignore + @classmethod + def set(cls, key: LoggerTagKey, value: Optional[str]) -> None: + """Set the value in the context variable.""" + if value is not None: + ctx_var = cls._context_vars.get(key) + if ctx_var is None: + ctx_var = contextvars.ContextVar(f"ctx_var_{key.name.lower()}", default=None) + cls._context_vars[key] = ctx_var + ctx_var.set(value) def make_standard_logger(name: str, log_level: int = logging.INFO) -> logging.Logger: @@ -77,10 +94,11 @@ def json_record(self, message: str, extra: dict, record: logging.LogRecord) -> d extra["lineno"] = record.lineno extra["pathname"] = record.pathname - # add the http request id if it exists - request_id = ctx_var_request_id.get() - if request_id: - extra["request_id"] = request_id + # add additional logger tags + for tag_key in LoggerTagKey: + tag_value = LoggerTagManager.get(tag_key) + if tag_value: + extra[tag_key.value] = tag_value current_span = tracer.current_span() extra["dd.trace_id"] = current_span.trace_id if current_span else 0