Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise EOF when writing to a closed sandbox over the network #2330

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions modal/io_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import asyncio
from typing import TYPE_CHECKING, AsyncIterator, Literal, Optional, Tuple, Union

from grpclib import Status
from grpclib.exceptions import GRPCError, StreamTerminatedError

from modal_proto import api_pb2
Expand Down Expand Up @@ -226,7 +227,7 @@ def write(self, data: Union[bytes, bytearray, memoryview, str]):
```
"""
if self._is_closed:
raise EOFError("Stdin is closed. Cannot write to it.")
raise ValueError("Stdin is closed. Cannot write to it.")
if isinstance(data, (bytes, bytearray, memoryview, str)):
if isinstance(data, str):
data = data.encode("utf-8")
Expand All @@ -247,27 +248,33 @@ def write_eof(self):

async def drain(self):
"""
Flushes the write buffer and EOF to the running process.
Flushes the write buffer to the running process. Flushes the EOF if the writer is closed.
"""
data = bytes(self._buffer)
self._buffer.clear()
index = self.get_next_index()

if self._object_type == "sandbox":
await retry_transient_errors(
self._client.stub.SandboxStdinWrite,
api_pb2.SandboxStdinWriteRequest(
sandbox_id=self._object_id, index=index, eof=self._is_closed, input=data
),
)
else:
await retry_transient_errors(
self._client.stub.ContainerExecPutInput,
api_pb2.ContainerExecPutInputRequest(
exec_id=self._object_id,
input=api_pb2.RuntimeInputMessage(message=data, message_index=index, eof=self._is_closed),
),
)
try:
if self._object_type == "sandbox":
await retry_transient_errors(
self._client.stub.SandboxStdinWrite,
api_pb2.SandboxStdinWriteRequest(
sandbox_id=self._object_id, index=index, eof=self._is_closed, input=data
),
)
else:
await retry_transient_errors(
self._client.stub.ContainerExecPutInput,
api_pb2.ContainerExecPutInputRequest(
exec_id=self._object_id,
input=api_pb2.RuntimeInputMessage(message=data, message_index=index, eof=self._is_closed),
),
)
except GRPCError as exc:
if exc.status == Status.FAILED_PRECONDITION:
raise ValueError(exc.message)
else:
raise exc


StreamReader = synchronize_api(_StreamReader)
Expand Down
3 changes: 3 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,9 @@ async def SandboxGetTaskId(self, stream):
async def SandboxStdinWrite(self, stream):
request: api_pb2.SandboxStdinWriteRequest = await stream.recv_message()

if self.sandbox.returncode is not None:
raise GRPCError(Status.FAILED_PRECONDITION, "Sandbox has already terminated")

self.sandbox.stdin.write(request.input)
await self.sandbox.stdin.drain()

Expand Down
14 changes: 12 additions & 2 deletions test/sandbox_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,21 @@ def test_sandbox_stdin_write_str(app, servicer):


@skip_non_linux
def test_sandbox_stdin_write_after_eof(app, servicer):
def test_sandbox_stdin_write_after_terminate(app, servicer):
sb = Sandbox.create("bash", "-c", "echo foo", app=app)
sb.wait()
with pytest.raises(ValueError):
sb.stdin.write(b"foo")
sb.stdin.drain()


@skip_non_linux
def test_sandbox_stdin_write_after_eof(app, servicer):
sb = Sandbox.create(app=app)
sb.stdin.write_eof()
with pytest.raises(EOFError):
with pytest.raises(ValueError):
sb.stdin.write(b"foo")
sb.terminate()


@skip_non_linux
Expand Down
Loading