Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

wip: sandbox snapshots #2675

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
37 changes: 35 additions & 2 deletions modal/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ._utils.deprecation import deprecation_error
from ._utils.grpc_utils import retry_transient_errors
from ._utils.mount_utils import validate_network_file_systems, validate_volumes
from .app import _App
from .client import _Client
from .config import config
from .container_process import _ContainerProcess
Expand Down Expand Up @@ -58,6 +59,7 @@ class _Sandbox(_Object, type_prefix="sb"):
_stdin: _StreamWriter
_task_id: Optional[str] = None
_tunnels: Optional[dict[int, Tunnel]] = None
_enable_snapshot: bool = False

@staticmethod
def _new(
Expand All @@ -81,6 +83,7 @@ def _new(
unencrypted_ports: Sequence[int] = [],
proxy: Optional[_Proxy] = None,
_experimental_scheduler_placement: Optional[SchedulerPlacement] = None,
enable_snapshot: bool = False,
) -> "_Sandbox":
"""mdmd:hidden"""

Expand Down Expand Up @@ -177,6 +180,7 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona
open_ports=api_pb2.PortSpecs(ports=open_ports),
network_access=network_access,
proxy_id=(proxy.object_id if proxy else None),
enable_snapshot=enable_snapshot,
)

# Note - `resolver.app_id` will be `None` for app-less sandboxes
Expand Down Expand Up @@ -224,13 +228,13 @@ async def create(
unencrypted_ports: Sequence[int] = [],
# Reference to a Modal Proxy to use in front of this Sandbox.
proxy: Optional[_Proxy] = None,
# Enable memory snapshots.
enable_snapshot: bool = False,
_experimental_scheduler_placement: Optional[
SchedulerPlacement
] = None, # Experimental controls over fine-grained scheduling (alpha).
client: Optional[_Client] = None,
) -> "_Sandbox":
from .app import _App

environment_name = _get_environment_name(environment_name)

# If there are no entrypoint args, we'll sleep forever so that the sandbox will stay
Expand Down Expand Up @@ -261,7 +265,9 @@ async def create(
unencrypted_ports=unencrypted_ports,
proxy=proxy,
_experimental_scheduler_placement=_experimental_scheduler_placement,
enable_snapshot=enable_snapshot,
)
obj._enable_snapshot = enable_snapshot

app_id: Optional[str] = None
app_client: Optional[_Client] = None
Expand Down Expand Up @@ -534,6 +540,33 @@ async def exec(
by_line = bufsize == 1
return _ContainerProcess(resp.exec_id, self._client, stdout=stdout, stderr=stderr, text=text, by_line=by_line)

async def snapshot(self) -> str:
    """Request a memory snapshot of this sandbox and return the snapshot ID.

    Sends a `SandboxSnapshot` request for this sandbox's object ID, then makes
    a single `SandboxSnapshotWait` call (55-second timeout) on the snapshot ID
    that came back.

    Returns:
        The `snapshot_id` reported by the `SandboxSnapshot` response.

    Raises:
        ValueError: if `_enable_snapshot` is false on this object.
        ExecutionError: if the wait response's result status is anything
            other than `GENERIC_STATUS_SUCCESS`.
    """
    # Guard: refuse before any RPC is issued when snapshots were not enabled.
    if not self._enable_snapshot:
        raise ValueError(
            "Memory snapshots are not supported for this sandbox. To enable memory snapshots, "
            "set `enable_snapshot=True` when creating the sandbox."
        )

    snapshot_req = api_pb2.SandboxSnapshotRequest(sandbox_id=self.object_id)
    snapshot_resp = await retry_transient_errors(self._client.stub.SandboxSnapshot, snapshot_req)
    snapshot_id = snapshot_resp.snapshot_id

    wait_req = api_pb2.SandboxSnapshotWaitRequest(snapshot_id=snapshot_id, timeout=55.0)
    wait_resp = await retry_transient_errors(self._client.stub.SandboxSnapshotWait, wait_req)

    if wait_resp.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
        return snapshot_id
    # Any non-success status is surfaced with the server-provided exception text.
    raise ExecutionError(wait_resp.result.exception)

@staticmethod
async def from_snapshot(snapshot_id: str, client: Optional[_Client] = None) -> "_Sandbox":
    """Restore a sandbox from a snapshot and return a handle to it.

    Sends a `SandboxRestore` request for `snapshot_id`, looks up the restored
    sandbox with `_Sandbox.from_id`, and then issues one `SandboxWait` call
    with `timeout=0` on it before returning.

    Args:
        snapshot_id: ID of the snapshot to restore from.
        client: Client to issue the RPCs with. When omitted (or falsy), one is
            obtained from `_Client.from_env()`.

    Returns:
        The object returned by `_Sandbox.from_id` for the restored sandbox ID.
    """
    client = client or await _Client.from_env()

    restore_req = api_pb2.SandboxRestoreRequest(snapshot_id=snapshot_id)
    restore_resp: api_pb2.SandboxRestoreResponse = await retry_transient_errors(
        client.stub.SandboxRestore, restore_req
    )
    sandbox = await _Sandbox.from_id(restore_resp.sandbox_id, client)

    # The wait response is intentionally not inspected; the call is kept so that
    # RPC-level failures still propagate to the caller as before.
    # NOTE(review): presumably timeout=0 makes this return immediately rather than
    # block until the sandbox exits -- confirm against the server contract.
    wait_req = api_pb2.SandboxWaitRequest(sandbox_id=restore_resp.sandbox_id, timeout=0)
    await retry_transient_errors(client.stub.SandboxWait, wait_req)
    return sandbox

@overload
async def open(
self,
Expand Down
4 changes: 1 addition & 3 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2318,9 +2318,7 @@ message SandboxListResponse {
}

message SandboxRestoreRequest {
string app_id = 1 [ (modal.options.audit_target_attr) = true ];
string snapshot_id = 2;
string environment_name = 3;
string snapshot_id = 1;
}

message SandboxRestoreResponse {
Expand Down
Loading