Skip to content

Commit

Permalink
Raise EOF when writing to a closed sandbox over the network (#2330)
Browse files Browse the repository at this point in the history
Re-enable checkpointing with parameterized functions (#2339)

This feature works again.
  • Loading branch information
pawalt authored Oct 16, 2024
1 parent c992a7b commit 6fed77c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 19 deletions.
41 changes: 24 additions & 17 deletions modal/io_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import asyncio
from typing import TYPE_CHECKING, AsyncIterator, Literal, Optional, Tuple, Union

from grpclib import Status
from grpclib.exceptions import GRPCError, StreamTerminatedError

from modal_proto import api_pb2
Expand Down Expand Up @@ -249,7 +250,7 @@ def write(self, data: Union[bytes, bytearray, memoryview, str]):
```
"""
if self._is_closed:
raise EOFError("Stdin is closed. Cannot write to it.")
raise ValueError("Stdin is closed. Cannot write to it.")
if isinstance(data, (bytes, bytearray, memoryview, str)):
if isinstance(data, str):
data = data.encode("utf-8")
Expand All @@ -270,27 +271,33 @@ def write_eof(self):

async def drain(self):
"""
Flushes the write buffer and EOF to the running process.
Flushes the write buffer to the running process. Flushes the EOF if the writer is closed.
"""
data = bytes(self._buffer)
self._buffer.clear()
index = self.get_next_index()

if self._object_type == "sandbox":
await retry_transient_errors(
self._client.stub.SandboxStdinWrite,
api_pb2.SandboxStdinWriteRequest(
sandbox_id=self._object_id, index=index, eof=self._is_closed, input=data
),
)
else:
await retry_transient_errors(
self._client.stub.ContainerExecPutInput,
api_pb2.ContainerExecPutInputRequest(
exec_id=self._object_id,
input=api_pb2.RuntimeInputMessage(message=data, message_index=index, eof=self._is_closed),
),
)
try:
if self._object_type == "sandbox":
await retry_transient_errors(
self._client.stub.SandboxStdinWrite,
api_pb2.SandboxStdinWriteRequest(
sandbox_id=self._object_id, index=index, eof=self._is_closed, input=data
),
)
else:
await retry_transient_errors(
self._client.stub.ContainerExecPutInput,
api_pb2.ContainerExecPutInputRequest(
exec_id=self._object_id,
input=api_pb2.RuntimeInputMessage(message=data, message_index=index, eof=self._is_closed),
),
)
except GRPCError as exc:
if exc.status == Status.FAILED_PRECONDITION:
raise ValueError(exc.message)
else:
raise exc


StreamReader = synchronize_api(_StreamReader)
Expand Down
3 changes: 3 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,9 @@ async def SandboxGetTaskId(self, stream):
async def SandboxStdinWrite(self, stream):
request: api_pb2.SandboxStdinWriteRequest = await stream.recv_message()

if self.sandbox.returncode is not None:
raise GRPCError(Status.FAILED_PRECONDITION, "Sandbox has already terminated")

self.sandbox.stdin.write(request.input)
await self.sandbox.stdin.drain()

Expand Down
14 changes: 12 additions & 2 deletions test/sandbox_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,21 @@ def test_sandbox_stdin_write_str(app, servicer):


@skip_non_linux
def test_sandbox_stdin_write_after_eof(app, servicer):
def test_sandbox_stdin_write_after_terminate(app, servicer):
sb = Sandbox.create("bash", "-c", "echo foo", app=app)
sb.wait()
with pytest.raises(ValueError):
sb.stdin.write(b"foo")
sb.stdin.drain()


@skip_non_linux
def test_sandbox_stdin_write_after_eof(app, servicer):
sb = Sandbox.create(app=app)
sb.stdin.write_eof()
with pytest.raises(EOFError):
with pytest.raises(ValueError):
sb.stdin.write(b"foo")
sb.terminate()


@skip_non_linux
Expand Down

0 comments on commit 6fed77c

Please sign in to comment.