diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ad0aa6..ec6b96a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added -- Emit error message when an exception is handled by the built application +- Emit error message when an exception is handled by the built application. Add `X-Request-ID` header value to log message when present ## [0.6.2] - 2023-08-23 ### Fixed diff --git a/fastapi_mlflow/applications.py b/fastapi_mlflow/applications.py index 76f3b99..d7bb291 100644 --- a/fastapi_mlflow/applications.py +++ b/fastapi_mlflow/applications.py @@ -31,10 +31,12 @@ def build_app(pyfunc_model: PyFuncModel) -> FastAPI: @app.exception_handler(DictSerialisableException) def handle_serialisable_exception( - _: Request, exc: DictSerialisableException + req: Request, exc: DictSerialisableException ) -> ORJSONResponse: nonlocal logger - logger.exception(exc.message) + req_id = req.headers.get("x-request-id") + extra = {"x-request-id": req_id} if req_id is not None else {} + logger.exception(exc.message, extra=extra) return ORJSONResponse( status_code=500, content=exc.to_dict(), diff --git a/tests/test_application.py b/tests/test_application.py index 9d850f2..a39b657 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -105,4 +105,35 @@ def test_built_application_logs_exceptions( assert len(caplog.records) >= 1 log_record = caplog.records[-1] assert log_record.name == "fastapi_mlflow.applications" - assert log_record.message == python_model_value_error.ERROR_MESSAGE \ No newline at end of file + assert log_record.message == python_model_value_error.ERROR_MESSAGE + + +@pytest.mark.parametrize( + "req_id_header_name", + [ + "x-request-id", + "X-Request-Id", + "X-Request-ID", + "X-REQUEST-ID", + ], +) +def test_built_application_logs_exceptions_including_request_id_header_when_sent( + model_input: pd.DataFrame, + pyfunc_model_value_error: PyFuncModel, + python_model_value_error: PythonModel, + caplog: pytest.LogCaptureFixture, + req_id_header_name: str +): + app = build_app(pyfunc_model_value_error) + client = TestClient(app, raise_server_exceptions=False) + df_str = model_input.to_json(orient="records") + request_data = f'{{"data": {df_str}}}' + request_id = "abcdef" + + _ = client.post( + "/predictions", content=request_data, headers={req_id_header_name: request_id} + ) + + log_record = caplog.records[-1] + assert hasattr(log_record, "x-request-id") + assert getattr(log_record, "x-request-id") == request_id