diff --git a/modal/partial_function.py b/modal/partial_function.py index e49ba8487..13302026f 100644 --- a/modal/partial_function.py +++ b/modal/partial_function.py @@ -379,6 +379,11 @@ def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction: f"Modal will drop support for default parameters in a future release.", ) + if inspect.iscoroutinefunction(raw_f): + raise InvalidError( + f"ASGI app function {raw_f.__name__} is an async function. Only sync Python functions are supported." + ) + if not wait_for_response: deprecation_error( (2024, 5, 13), @@ -448,6 +453,11 @@ def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction: f"Modal will drop support for default parameters in a future release.", ) + if inspect.iscoroutinefunction(raw_f): + raise InvalidError( + f"WSGI app function {raw_f.__name__} is an async function. Only sync Python functions are supported." + ) + if not wait_for_response: deprecation_error( (2024, 5, 13), diff --git a/test/webhook_test.py b/test/webhook_test.py index 8bb123a56..dd343ec27 100644 --- a/test/webhook_test.py +++ b/test/webhook_test.py @@ -134,40 +134,54 @@ async def test_asgi_wsgi(servicer, client): @app.function(serialized=True) @asgi_app() - async def my_asgi(): + def my_asgi(): pass @app.function(serialized=True) @wsgi_app() - async def my_wsgi(): + def my_wsgi(): pass with pytest.raises(InvalidError, match="can't have parameters"): @app.function(serialized=True) @asgi_app() - async def my_invalid_asgi(x): + def my_invalid_asgi(x): pass with pytest.raises(InvalidError, match="can't have parameters"): @app.function(serialized=True) @wsgi_app() - async def my_invalid_wsgi(x): + def my_invalid_wsgi(x): pass with pytest.warns(DeprecationError, match="default parameters"): @app.function(serialized=True) @asgi_app() - async def my_deprecated_default_params_asgi(x=1): + def my_deprecated_default_params_asgi(x=1): pass with pytest.warns(DeprecationError, match="default parameters"): @app.function(serialized=True) @wsgi_app() - async def my_deprecated_default_params_wsgi(x=1): + def my_deprecated_default_params_wsgi(x=1): + pass + + with pytest.raises(InvalidError, match="async function"): + + @app.function(serialized=True) + @asgi_app() + async def my_async_asgi_function(): + pass + + with pytest.raises(InvalidError, match="async function"): + + @app.function(serialized=True) + @wsgi_app() + async def my_async_wsgi_function(): pass async with app.run(client=client):