Skip to content

Commit

Permalink
fix merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
devennavani committed Nov 25, 2024
2 parents 720aa55 + 8a95aae commit 1b4d2f0
Show file tree
Hide file tree
Showing 19 changed files with 333 additions and 125 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ We appreciate your patience while we speedily work towards a stable release of t

<!-- NEW CONTENT GENERATED BELOW. PLEASE PRESERVE THIS COMMENT. -->

### 0.66.40 (2024-11-23)

* Adds `Image.add_local_file(..., copy=False)` and `Image.add_local_dir(..., copy=False)` as a unified replacement for the old `Image.copy_local_*()` and `Mount.add_local_*` methods.



### 0.66.30 (2024-11-21)

- Removed the `aiostream` package from the modal client library dependencies.
Expand Down
2 changes: 1 addition & 1 deletion modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
call_lifecycle_functions(event_loop, container_io_manager, list(pre_snapshot_methods.values()))

# If this container is being used to create a checkpoint, checkpoint the container after
# global imports and innitialization. Checkpointed containers run from this point onwards.
# global imports and initialization. Checkpointed containers run from this point onwards.
if is_snapshotting_function:
container_io_manager.memory_snapshot()

Expand Down
2 changes: 1 addition & 1 deletion modal/_runtime/user_code_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def get_user_class_instance(
modal_obj: modal.cls.Obj = cls(*args, **kwargs)
modal_obj.entered = True # ugly but prevents .local() from triggering additional enter-logic
# TODO: unify lifecycle logic between .local() and container_entrypoint
user_cls_instance = modal_obj._get_user_cls_instance()
user_cls_instance = modal_obj._cached_user_cls_instance()
else:
# undecorated class (non-global decoration or serialized)
user_cls_instance = cls(*args, **kwargs)
Expand Down
59 changes: 33 additions & 26 deletions modal/_utils/grpc_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def intercept(servicer):
ctx = InterceptionContext()
servicer.interception_context = ctx
yield ctx
ctx.assert_responses_consumed()
ctx._assert_responses_consumed()
servicer.interception_context = None

cls.intercept = intercept
Expand All @@ -64,7 +64,7 @@ async def patched_method(servicer_self, stream):
ctx = servicer_self.interception_context
if ctx:
intercepted_stream = await InterceptedStream(ctx, method_name, stream).initialize()
custom_responder = ctx.next_custom_responder(method_name, intercepted_stream.request_message)
custom_responder = ctx._next_custom_responder(method_name, intercepted_stream.request_message)
if custom_responder:
return await custom_responder(servicer_self, intercepted_stream)
else:
Expand Down Expand Up @@ -105,19 +105,23 @@ def __init__(self):
self.custom_responses: Dict[str, List[Tuple[Callable[[Any], bool], List[Any]]]] = defaultdict(list)
self.custom_defaults: Dict[str, Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]] = {}

def add_recv(self, method_name: str, msg):
self.calls.append((method_name, msg))

def add_response(
self, method_name: str, first_payload, *, request_filter: Callable[[Any], bool] = lambda req: True
):
# adds one response to a queue of responses for requests of the specified type
"""Adds one response payload to an expected queue of responses for a method.
These responses will be used once each instead of calling the MockServicer's
implementation of the method.
The interception context will throw an exception on exit if not all of the added
responses have been consumed.
"""
self.custom_responses[method_name].append((request_filter, [first_payload]))

def set_responder(
self, method_name: str, responder: Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]
):
"""Replace the default responder method. E.g.
"""Replace the default responder from the MockClientServicer with a custom implementation
```python notest
def custom_responder(servicer, stream):
Expand All @@ -128,11 +132,28 @@ def custom_responder(servicer, stream):
ctx.set_responder("SomeMethod", custom_responder)
```
Responses added via `.add_response()` take precedence.
Responses added via `.add_response()` take precedence over the use of this replacement
"""
self.custom_defaults[method_name] = responder

def next_custom_responder(self, method_name, request):
def pop_request(self, method_name):
# fast forward to the next request of type method_name
# dropping any preceding requests if there is a match
# returns the payload of the request
for i, (_method_name, msg) in enumerate(self.calls):
if _method_name == method_name:
self.calls = self.calls[i + 1 :]
return msg

raise KeyError(f"No message of that type in call list: {self.calls}")

def get_requests(self, method_name: str) -> List[Any]:
return [msg for _method_name, msg in self.calls if _method_name == method_name]

def _add_recv(self, method_name: str, msg):
self.calls.append((method_name, msg))

def _next_custom_responder(self, method_name, request):
method_responses = self.custom_responses[method_name]
for i, (request_filter, response_messages) in enumerate(method_responses):
try:
Expand All @@ -159,31 +180,17 @@ async def responder(servicer_self, stream):

return responder

def assert_responses_consumed(self):
def _assert_responses_consumed(self):
unconsumed = []
for method_name, queued_responses in self.custom_responses.items():
unconsumed += [method_name] * len(queued_responses)

if unconsumed:
raise ResponseNotConsumed(unconsumed)

def pop_request(self, method_name):
# fast forward to the next request of type method_name
# dropping any preceding requests if there is a match
# returns the payload of the request
for i, (_method_name, msg) in enumerate(self.calls):
if _method_name == method_name:
self.calls = self.calls[i + 1 :]
return msg

raise KeyError(f"No message of that type in call list: {self.calls}")

def get_requests(self, method_name: str) -> List[Any]:
return [msg for _method_name, msg in self.calls if _method_name == method_name]


class InterceptedStream:
def __init__(self, interception_context, method_name, stream):
def __init__(self, interception_context: InterceptionContext, method_name: str, stream):
self.interception_context = interception_context
self.method_name = method_name
self.stream = stream
Expand All @@ -200,7 +207,7 @@ async def recv_message(self):
return ret

msg = await self.stream.recv_message()
self.interception_context.add_recv(self.method_name, msg)
self.interception_context._add_recv(self.method_name, msg)
return msg

async def send_message(self, msg):
Expand Down
45 changes: 26 additions & 19 deletions modal/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
method = self._instance_service_function._bind_instance_method(class_bound_method)
self._method_functions[method_name] = method
else:
# <v0.63 classes - bind each individual method to the new parameters
# looked up <v0.63 classes - bind each individual method to the new parameters
self._instance_service_function = None
for method_name, class_bound_method in classbound_methods.items():
method = class_bound_method._bind_parameters(self, from_other_workspace, options, args, kwargs)
Expand All @@ -125,12 +125,14 @@ def __init__(
self._user_cls = user_cls
self._construction_args = (args, kwargs) # used for lazy construction in case of explicit constructors

def _user_cls_instance_constr(self):
def _new_user_cls_instance(self):
args, kwargs = self._construction_args
if not _use_annotation_parameters(self._user_cls):
# TODO(elias): deprecate this code path eventually
user_cls_instance = self._user_cls(*args, **kwargs)
else:
# ignore constructor (assumes there is no custom constructor,
# which is guaranteed by _use_annotation_parameters)
# set the attributes on the class corresponding to annotations
# with = parameter() specifications
sig = _get_class_constructor_signature(self._user_cls)
Expand All @@ -139,6 +141,7 @@ def _user_cls_instance_constr(self):
user_cls_instance = self._user_cls.__new__(self._user_cls) # new instance without running __init__
user_cls_instance.__dict__.update(bound_vars.arguments)

# TODO: always use Obj instances instead of making modifications to user cls
user_cls_instance._modal_functions = self._method_functions # Needed for PartialFunction.__get__
return user_cls_instance

Expand All @@ -163,10 +166,12 @@ async def keep_warm(self, warm_pool_size: int) -> None:
)
await self._instance_service_function.keep_warm(warm_pool_size)

def _get_user_cls_instance(self):
"""Construct local object lazily. Used for .local() calls."""
def _cached_user_cls_instance(self):
"""Get or construct the local object
Used for .local() calls and getting attributes of classes"""
if not self._user_cls_instance:
self._user_cls_instance = self._user_cls_instance_constr() # Instantiate object
self._user_cls_instance = self._new_user_cls_instance() # Instantiate object

return self._user_cls_instance

Expand Down Expand Up @@ -196,7 +201,7 @@ def entered(self, val):
@synchronizer.nowrap
async def aenter(self):
if not self.entered:
user_cls_instance = self._get_user_cls_instance()
user_cls_instance = self._cached_user_cls_instance()
if hasattr(user_cls_instance, "__aenter__"):
await user_cls_instance.__aenter__()
elif hasattr(user_cls_instance, "__enter__"):
Expand All @@ -205,20 +210,22 @@ async def aenter(self):

def __getattr__(self, k):
if k in self._method_functions:
# if we know the user is accessing a method, we don't have to create an instance
# yet, since the user might just call `.remote()` on it which doesn't require
# a local instance (in case __init__ does stuff that can't locally)
# If we know the user is accessing a *method* and not another attribute,
# we don't have to create an instance of the user class yet.
# This is because it might just be a call to `.remote()` on it which
# doesn't require a local instance.
# As long as we have the service function or params, we can do remote calls
# without calling the constructor of the class in the calling context.
return self._method_functions[k]
elif self._user_cls_instance_constr:
# if it's *not* a method
# TODO: To get lazy loading (from_name) of classes to work, we need to avoid
# this path, otherwise local initialization will happen regardless if user
# only runs .remote(), since we don't know methods for the class until we
# load it
user_cls_instance = self._get_user_cls_instance()
return getattr(user_cls_instance, k)
else:
raise AttributeError(k)

# if it's *not* a method, it *might* be an attribute of the class,
# so we construct it and proxy the attribute
# TODO: To get lazy loading (from_name) of classes to work, we need to avoid
# this path, otherwise local initialization will happen regardless if user
# only runs .remote(), since we don't know methods for the class until we
# load it
user_cls_instance = self._cached_user_cls_instance()
return getattr(user_cls_instance, k)


Obj = synchronize_api(_Obj)
Expand Down
61 changes: 32 additions & 29 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,18 +314,14 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
_tag: str
_raw_f: Callable[..., Any]
_build_args: dict
_can_use_base_function: bool = False # whether we need to call FunctionBindParams

_is_generator: Optional[bool] = None
_cluster_size: Optional[int] = None

# when this is the method of a class/object function, invocation of this function
# should supply the method name in the FunctionInput:
_use_method_name: str = ""

# TODO (elias): remove _parent. In case of instance functions, and methods bound on those,
# this references the parent class-function and is used to infer the client for lazy-loaded methods
_parent: Optional["_Function"] = None

_class_parameter_info: Optional["api_pb2.ClassParameterInfo"] = None
_method_handle_metadata: Optional[Dict[str, "api_pb2.FunctionHandleMetadata"]] = None

Expand Down Expand Up @@ -421,7 +417,6 @@ def _deps():
fun._info = class_bound_method._info
fun._obj = instance_service_function._obj
fun._is_method = True
fun._parent = instance_service_function._parent
fun._app = class_bound_method._app
fun._spec = class_bound_method._spec
return fun
Expand Down Expand Up @@ -931,62 +926,68 @@ def _bind_parameters(
Binds a class-function to a specific instance of (init params, options) or a new workspace
"""

async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
if self._parent is None:
# In some cases, reuse the base function, i.e. not create new clones of each method or the "service function"
can_use_parent = len(args) + len(kwargs) == 0 and not from_other_workspace and options is None
parent = self

async def _load(param_bound_func: _Function, resolver: Resolver, existing_object_id: Optional[str]):
if parent is None:
raise ExecutionError("Can't find the parent class' service function")
try:
identity = f"{self._parent.info.function_name} class service function"
identity = f"{parent.info.function_name} class service function"
except Exception:
# Can't always look up the function name that way, so fall back to generic message
identity = "class service function for a parameterized class"
if not self._parent.is_hydrated:
if self._parent.app._running_app is None:
if not parent.is_hydrated:
if parent.app._running_app is None:
reason = ", because the App it is defined on is not running"
else:
reason = ""
raise ExecutionError(
f"The {identity} has not been hydrated with the metadata it needs to run on Modal{reason}."
)
assert self._parent._client.stub

assert parent._client.stub

if can_use_parent:
# We can end up here if parent wasn't hydrated when class was instantiated, but has been since.
param_bound_func._hydrate_from_other(parent)
return

if (
self._parent._class_parameter_info
and self._parent._class_parameter_info.format
== api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO
parent._class_parameter_info
and parent._class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO
):
if args:
# TODO(elias) - We could potentially support positional args as well, if we want to?
raise InvalidError(
"Can't use positional arguments with modal.parameter-based synthetic constructors.\n"
"Use (<parameter_name>=value) keyword arguments when constructing classes instead."
)
serialized_params = serialize_proto_params(kwargs, self._parent._class_parameter_info.schema)
serialized_params = serialize_proto_params(kwargs, parent._class_parameter_info.schema)
else:
serialized_params = serialize((args, kwargs))
environment_name = _get_environment_name(None, resolver)
assert self._parent is not None
assert parent is not None
req = api_pb2.FunctionBindParamsRequest(
function_id=self._parent._object_id,
function_id=parent._object_id,
serialized_params=serialized_params,
function_options=options,
environment_name=environment_name
or "", # TODO: investigate shouldn't environment name always be specified here?
)

response = await retry_transient_errors(self._parent._client.stub.FunctionBindParams, req)
self._hydrate(response.bound_function_id, self._parent._client, response.handle_metadata)
response = await retry_transient_errors(parent._client.stub.FunctionBindParams, req)
param_bound_func._hydrate(response.bound_function_id, parent._client, response.handle_metadata)

fun: _Function = _Function._from_loader(_load, "Function(parametrized)", hydrate_lazily=True)

# In some cases, reuse the base function, i.e. not create new clones of each method or the "service function"
fun._can_use_base_function = len(args) + len(kwargs) == 0 and not from_other_workspace and options is None
if fun._can_use_base_function and self.is_hydrated:
# Edge case that lets us hydrate all objects right away
# if the instance didn't use explicit constructor arguments
fun._hydrate_from_other(self)
if can_use_parent and parent.is_hydrated:
# skip the resolver altogether:
fun._hydrate_from_other(parent)

fun._info = self._info
fun._obj = obj
fun._parent = self
return fun

@live_method
Expand Down Expand Up @@ -1171,8 +1172,10 @@ def _check_no_web_url(self, fn_name: str):
+ f"or call it locally: {self._function_name}.local()"
)

# TODO (live_method on properties is not great, since it could be blocking the event loop from async contexts)
@property
def web_url(self) -> str:
@live_method
async def web_url(self) -> str:
"""URL of a Function running as a web endpoint."""
if not self._web_url:
raise ValueError(
Expand Down Expand Up @@ -1345,7 +1348,7 @@ def local(self, *args: P.args, **kwargs: P.kwargs) -> OriginalReturnType:
return fun(*args, **kwargs)
else:
# This is a method on a class, so bind the self to the function
user_cls_instance = obj._get_user_cls_instance()
user_cls_instance = obj._cached_user_cls_instance()

fun = info.raw_f.__get__(user_cls_instance)

Expand Down
Loading

0 comments on commit 1b4d2f0

Please sign in to comment.