diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index 8e0a54edb..5164e49e0 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -14,7 +14,7 @@ class _MiddlewareFactory(Protocol[P]): - def __call__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover + def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover class Middleware: diff --git a/tests/test_applications.py b/tests/test_applications.py index 29c011a29..db2a6050b 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -3,7 +3,7 @@ import os from contextlib import asynccontextmanager from pathlib import Path -from typing import AsyncGenerator, AsyncIterator, Generator +from typing import AsyncGenerator, AsyncIterator, Callable, Generator import anyio.from_thread import pytest @@ -567,9 +567,12 @@ async def _app(scope: Scope, receive: Receive, send: Send) -> None: return _app + def get_middleware_factory() -> Callable[[ASGIApp, str], ASGIApp]: + return _middleware_factory + app = Starlette() app.add_middleware(_middleware_factory, arg="foo") - app.add_middleware(_middleware_factory, arg="bar") + app.add_middleware(get_middleware_factory(), "bar") with test_client_factory(app): pass