Skip to content

Commit

Permalink
feat: Instrument latency without streaming duration (#290)
Browse files Browse the repository at this point in the history
* Track response start duration

This commit adds a feature to track the latency excluding
streaming duration.

* ci(pre-commit): Apply hook auto fixes

* fix: Add default to Info constructor and adjust test

* fix: Make mypy happy

* docs: Add parameter to docstring

* fix: Add start time stuff to body handler

* test: Add test

* feat: Add duration stuff to default and add tests

* docs: Add entry to changelog

---------

Co-authored-by: Tim Schwenke <[email protected]>
  • Loading branch information
dosuken123 and trallnag authored Mar 11, 2024
1 parent c608c4e commit 4530ba4
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 10 deletions.
18 changes: 18 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,24 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0).
and implementing it in
[#288](https://github.com/trallnag/prometheus-fastapi-instrumentator/pull/288).

- **Middleware also records duration without streaming** in addition to the
already existing total latency (i.e. the time consumed for streaming is not
included in the duration value). The differentiation can be valuable as it
shows the time to first byte.

This mode is opt-in and can be enabled / used in several ways: The
`Instrumentator()` constructor, the `metrics.default()` closure, and the
`metrics.latency()` closure now come with the flag
`should_exclude_streaming_duration`. The attribute
`modified_duration_without_streaming` has been added to the `metrics.Info`
class. Instances of `metrics.Info` are passed to instrumentation functions,
where the added value can be used to set metrics.

Thanks to [@dosuken123](https://github.com/dosuken123) for proposing this in
[#291](https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/291)
and implementing it in
[#290](https://github.com/trallnag/prometheus-fastapi-instrumentator/pull/290).

- Relaxed type of `get_route_name` argument to `HTTPConnection`. This allows
developers to use the `get_route_name` function for getting the name of
websocket routes as well. Thanks to [@pajowu](https://github.com/pajowu) for
Expand Down
7 changes: 7 additions & 0 deletions src/prometheus_fastapi_instrumentator/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
should_round_latency_decimals: bool = False,
should_respect_env_var: bool = False,
should_instrument_requests_inprogress: bool = False,
should_exclude_streaming_duration: bool = False,
excluded_handlers: List[str] = [],
body_handlers: List[str] = [],
round_latency_decimals: int = 4,
Expand Down Expand Up @@ -69,6 +70,10 @@ def __init__(
the inprogress requests. See also the related args starting
with `inprogress`. Defaults to `False`.
should_exclude_streaming_duration: Should the streaming duration be
excluded? Only relevant if default metrics are used. Defaults
to `False`.
excluded_handlers (List[str]): List of strings that will be compiled
to regex patterns. All matches will be skipped and not
instrumented. Defaults to `[]`.
Expand Down Expand Up @@ -112,6 +117,7 @@ def __init__(
self.should_round_latency_decimals = should_round_latency_decimals
self.should_respect_env_var = should_respect_env_var
self.should_instrument_requests_inprogress = should_instrument_requests_inprogress
self.should_exclude_streaming_duration = should_exclude_streaming_duration

self.round_latency_decimals = round_latency_decimals
self.env_var_name = env_var_name
Expand Down Expand Up @@ -205,6 +211,7 @@ def instrument(
should_round_latency_decimals=self.should_round_latency_decimals,
should_respect_env_var=self.should_respect_env_var,
should_instrument_requests_inprogress=self.should_instrument_requests_inprogress,
should_exclude_streaming_duration=self.should_exclude_streaming_duration,
round_latency_decimals=self.round_latency_decimals,
env_var_name=self.env_var_name,
inprogress_name=self.inprogress_name,
Expand Down
34 changes: 29 additions & 5 deletions src/prometheus_fastapi_instrumentator/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
modified_handler: str,
modified_status: str,
modified_duration: float,
modified_duration_without_streaming: float = 0.0,
):
"""Creates Info object that is used for instrumentation functions.
Expand All @@ -42,6 +43,8 @@ def __init__(
by instrumentator. For example grouping into `2xx`, `3xx` and so on.
modified_duration (float): Latency representation after processing
by instrumentator. For example rounding of decimals. Seconds.
modified_duration_without_streaming (float): Latency between request arrival and response starts (i.e. first chunk duration).
Excluding the streaming duration. Defaults to 0.
"""

self.request = request
Expand All @@ -50,6 +53,7 @@ def __init__(
self.modified_handler = modified_handler
self.modified_status = modified_status
self.modified_duration = modified_duration
self.modified_duration_without_streaming = modified_duration_without_streaming


def _build_label_attribute_names(
Expand Down Expand Up @@ -114,6 +118,7 @@ def latency(
should_include_handler: bool = True,
should_include_method: bool = True,
should_include_status: bool = True,
should_exclude_streaming_duration: bool = False,
buckets: Sequence[Union[float, str]] = Histogram.DEFAULT_BUCKETS,
registry: CollectorRegistry = REGISTRY,
) -> Optional[Callable[[Info], None]]:
Expand Down Expand Up @@ -141,6 +146,9 @@ def latency(
should_include_status: Should the `status` label be part of the
metric? Defaults to `True`.
should_exclude_streaming_duration: Should the streaming duration be
excluded? Defaults to `False`.
buckets: Buckets for the histogram. Defaults to Prometheus default.
Defaults to default buckets from Prometheus client library.
Expand Down Expand Up @@ -184,15 +192,21 @@ def latency(
)

def instrumentation(info: Info) -> None:
duration = info.modified_duration
if should_exclude_streaming_duration:
duration = info.modified_duration_without_streaming
else:
duration = info.modified_duration

if label_names:
label_values = [
getattr(info, attribute_name)
for attribute_name in info_attribute_names
]

METRIC.labels(*label_values).observe(info.modified_duration)
METRIC.labels(*label_values).observe(duration)
else:
METRIC.observe(info.modified_duration)
METRIC.observe(duration)

return instrumentation
except ValueError as e:
Expand Down Expand Up @@ -569,6 +583,7 @@ def default(
metric_namespace: str = "",
metric_subsystem: str = "",
should_only_respect_2xx_for_highr: bool = False,
should_exclude_streaming_duration: bool = False,
latency_highr_buckets: Sequence[Union[float, str]] = (
0.01,
0.025,
Expand Down Expand Up @@ -610,7 +625,7 @@ def default(
content length bytes by handler.
* `http_request_duration_highr_seconds` (no labels): High number of buckets
leading to more accurate calculation of percentiles.
* `http_request_duration_seconds` (`handler`):
* `http_request_duration_seconds` (`handler`, `method`):
Kepp the bucket count very low. Only put in SLIs.
Args:
Expand All @@ -625,6 +640,9 @@ def default(
requests / responses that have a status code starting with `2`?
Defaults to `False`.
should_exclude_streaming_duration: Should the streaming duration be
excluded? Defaults to `False`.
latency_highr_buckets (tuple[float], optional): Buckets tuple for high
res histogram. Can be large because no labels are used. Defaults to
(0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1, 1.5, 2, 2.5,
Expand Down Expand Up @@ -719,6 +737,12 @@ def default(
)

def instrumentation(info: Info) -> None:
duration = info.modified_duration
if should_exclude_streaming_duration:
duration = info.modified_duration_without_streaming
else:
duration = info.modified_duration

TOTAL.labels(info.method, info.modified_status, info.modified_handler).inc()

IN_SIZE.labels(info.modified_handler).observe(
Expand All @@ -735,11 +759,11 @@ def instrumentation(info: Info) -> None:
if not should_only_respect_2xx_for_highr or info.modified_status.startswith(
"2"
):
LATENCY_HIGHR.observe(info.modified_duration)
LATENCY_HIGHR.observe(duration)

LATENCY_LOWR.labels(
handler=info.modified_handler, method=info.method
).observe(info.modified_duration)
).observe(duration)

return instrumentation

Expand Down
21 changes: 18 additions & 3 deletions src/prometheus_fastapi_instrumentator/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
should_round_latency_decimals: bool = False,
should_respect_env_var: bool = False,
should_instrument_requests_inprogress: bool = False,
should_exclude_streaming_duration: bool = False,
excluded_handlers: Sequence[str] = (),
body_handlers: Sequence[str] = (),
round_latency_decimals: int = 4,
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(
metric_namespace=metric_namespace,
metric_subsystem=metric_subsystem,
should_only_respect_2xx_for_highr=should_only_respect_2xx_for_highr,
should_exclude_streaming_duration=should_exclude_streaming_duration,
latency_highr_buckets=latency_highr_buckets,
latency_lowr_buckets=latency_lowr_buckets,
registry=self.registry,
Expand Down Expand Up @@ -140,15 +142,17 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
status_code = 500
headers = []
body = b""
response_start_time = None

# Message body collected for handlers matching body_handlers patterns.
if any(pattern.search(handler) for pattern in self.body_handlers):

async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
nonlocal status_code, headers
nonlocal status_code, headers, response_start_time
headers = message["headers"]
status_code = message["status"]
response_start_time = default_timer()
elif message["type"] == "http.response.body" and message["body"]:
nonlocal body
body += message["body"]
Expand All @@ -158,9 +162,10 @@ async def send_wrapper(message: Message) -> None:

async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
nonlocal status_code, headers
nonlocal status_code, headers, response_start_time
headers = message["headers"]
status_code = message["status"]
response_start_time = default_timer()
await send(message)

try:
Expand All @@ -175,13 +180,22 @@ async def send_wrapper(message: Message) -> None:
)

if not is_excluded:
duration = max(default_timer() - start_time, 0)
duration = max(default_timer() - start_time, 0.0)
duration_without_streaming = 0.0

if response_start_time:
duration_without_streaming = max(
response_start_time - start_time, 0.0
)

if self.should_instrument_requests_inprogress:
inprogress.dec()

if self.should_round_latency_decimals:
duration = round(duration, self.round_latency_decimals)
duration_without_streaming = round(
duration_without_streaming, self.round_latency_decimals
)

if self.should_group_status_codes:
status = status[0] + "xx"
Expand All @@ -197,6 +211,7 @@ async def send_wrapper(message: Message) -> None:
modified_handler=handler,
modified_status=status,
modified_duration=duration,
modified_duration_without_streaming=duration_without_streaming,
)

for instrumentation in self.instrumentations:
Expand Down
79 changes: 77 additions & 2 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Any, Dict, Optional

import pytest
from fastapi import FastAPI, HTTPException
from prometheus_client import REGISTRY
from fastapi import FastAPI, HTTPException, responses
from prometheus_client import REGISTRY, Histogram
from requests import Response as TestClientResponse
from starlette.testclient import TestClient

Expand Down Expand Up @@ -106,6 +106,7 @@ def test_existence_of_attributes():
assert info.modified_duration is None
assert info.modified_status is None
assert info.modified_handler is None
assert info.modified_duration_without_streaming == 0.0


def test_build_label_attribute_names_all_false():
Expand Down Expand Up @@ -422,6 +423,47 @@ def test_latency_with_bucket_no_inf():
)


def test_latency_duration_without_streaming():
_ = create_app()
app = FastAPI()
client = TestClient(app)

@app.get("/")
def root():
return responses.StreamingResponse(("x" * 1_000 for _ in range(5)))

METRIC = Histogram(
"http_request_duration_with_streaming_seconds",
"x",
)

def instrumentation(info: metrics.Info) -> None:
METRIC.observe(info.modified_duration)

Instrumentator().add(
metrics.latency(
should_include_handler=False,
should_include_method=False,
should_include_status=False,
should_exclude_streaming_duration=True,
),
instrumentation,
).instrument(app).expose(app)
client = TestClient(app)

client.get("/")

_ = get_response(client, "/metrics")

assert REGISTRY.get_sample_value(
"http_request_duration_seconds_sum",
{},
) < REGISTRY.get_sample_value(
"http_request_duration_with_streaming_seconds_sum",
{},
)


# ------------------------------------------------------------------------------
# default

Expand Down Expand Up @@ -521,6 +563,39 @@ def test_default_with_runtime_error():
)


def test_default_duration_without_streaming():
_ = create_app()
app = FastAPI()

@app.get("/")
def root():
return responses.StreamingResponse(("x" * 1_000 for _ in range(5)))

METRIC = Histogram(
"http_request_duration_with_streaming_seconds", "x", labelnames=["handler"]
)

def instrumentation(info: metrics.Info) -> None:
METRIC.labels(info.modified_handler).observe(info.modified_duration)

Instrumentator().add(
metrics.default(should_exclude_streaming_duration=True), instrumentation
).instrument(app).expose(app)
client = TestClient(app)

client.get("/")

_ = get_response(client, "/metrics")

assert REGISTRY.get_sample_value(
"http_request_duration_with_streaming_seconds_sum",
{"handler": "/"},
) > REGISTRY.get_sample_value(
"http_request_duration_seconds_sum",
{"handler": "/", "method": "GET"},
)


# ------------------------------------------------------------------------------
# requests

Expand Down
21 changes: 21 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,24 @@ def instrumentation(info: metrics.Info) -> None:
response = client.get("/")
assert instrumentation_executed
assert len(response.content) == 5_000_000


def test_info_body_duration_without_streaming():
app = FastAPI()
client = TestClient(app)

@app.get("/")
def root():
return responses.StreamingResponse(("x" * 1_000 for _ in range(5)))

instrumentation_executed = False

def instrumentation(info: metrics.Info) -> None:
nonlocal instrumentation_executed
instrumentation_executed = True
assert info.modified_duration_without_streaming < info.modified_duration

Instrumentator(body_handlers=[r".*"]).instrument(app).add(instrumentation)

client.get("/")
assert instrumentation_executed

0 comments on commit 4530ba4

Please sign in to comment.