diff --git a/modal/cls.py b/modal/cls.py index 8b37a27fc..e11c7e5a2 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -524,6 +524,7 @@ async def lookup( await resolver.load(obj) return obj + @synchronizer.no_input_translation def __call__(self, *args, **kwargs) -> _Obj: """This acts as the class constructor.""" return _Obj( diff --git a/test/cls_test.py b/test/cls_test.py index aedc6d11f..64fc39239 100644 --- a/test/cls_test.py +++ b/test/cls_test.py @@ -1007,3 +1007,20 @@ class D: @app.function(serialized=True) def f(self): pass + + +def test_modal_object_param_uses_wrapped_type(servicer, set_env_client, client): + with servicer.intercept() as ctx: + with modal.Dict.ephemeral() as dct: + with baz_app.run(): + # create bound instance: + typing.cast(modal.Cls, Baz(x=dct)).keep_warm(1) + + req: api_pb2.FunctionBindParamsRequest = ctx.pop_request("FunctionBindParams") + function_def: api_pb2.Function = servicer.app_functions[req.function_id] + from modal._container_entrypoint import deserialize_params + + _client = typing.cast(modal.client._Client, synchronizer._translate_in(client)) + container_params = deserialize_params(req.serialized_params, function_def, _client) + args, kwargs = container_params + assert type(kwargs["x"]) == type(dct)