Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

io streams types #2533

Merged
merged 25 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions modal/io_streams.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
# Copyright Modal Labs 2022
import asyncio
from typing import TYPE_CHECKING, AsyncGenerator, Generic, List, Literal, Optional, Tuple, TypeVar, Union
from typing import (
TYPE_CHECKING,
AsyncGenerator,
AsyncIterator,
Generic,
List,
Literal,
Optional,
Tuple,
TypeVar,
Union,
cast,
)

from grpclib import Status
from grpclib.exceptions import GRPCError, StreamTerminatedError
Expand All @@ -18,7 +30,7 @@


async def _sandbox_logs_iterator(
sandbox_id: str, file_descriptor: int, last_entry_id: Optional[str], client: _Client
sandbox_id: str, file_descriptor: "api_pb2.FileDescriptor.ValueType", last_entry_id: str, client: _Client
) -> AsyncGenerator[Tuple[Optional[bytes], str], None]:
req = api_pb2.SandboxGetLogsRequest(
sandbox_id=sandbox_id,
Expand All @@ -37,7 +49,7 @@ async def _sandbox_logs_iterator(


async def _container_process_logs_iterator(
process_id: str, file_descriptor: int, client: _Client
process_id: str, file_descriptor: "api_pb2.FileDescriptor.ValueType", client: _Client
) -> AsyncGenerator[Optional[bytes], None]:
req = api_pb2.ContainerExecGetOutputRequest(
exec_id=process_id, timeout=55, file_descriptor=file_descriptor, get_raw_bytes=True
Expand Down Expand Up @@ -74,9 +86,11 @@ class _StreamReader(Generic[T]):
```
"""

_stream: Optional[AsyncGenerator[Optional[bytes], None]]

def __init__(
self,
file_descriptor: int,
file_descriptor: "api_pb2.FileDescriptor.ValueType",
object_id: str,
object_type: Literal["sandbox", "container_process"],
client: _Client,
Expand All @@ -90,7 +104,7 @@ def __init__(
self._object_id = object_id
self._client = client
self._stream = None
self._last_entry_id: Optional[str] = None
self._last_entry_id: str = ""
self._line_buffer = b""

# Sandbox logs are streamed to the client as strings, so StreamReaders reading
Expand Down Expand Up @@ -145,15 +159,20 @@ async def read(self) -> T:
```

"""
data = "" if self._text else b""
data_str = ""
data_bytes = b""
async for message in self._get_logs():
if message is None:
break
if self._text:
data += message.decode("utf-8")
data_str += message.decode("utf-8")
else:
data += message
return data
data_bytes += message

if self._text:
return cast(T, data_str)
else:
return cast(T, data_bytes)

async def _consume_container_process_stream(self):
"""
Expand Down Expand Up @@ -275,7 +294,7 @@ async def _get_logs_by_line(self) -> AsyncGenerator[Optional[bytes], None]:
line, self._line_buffer = self._line_buffer.split(b"\n", 1)
yield line + b"\n"

def __aiter__(self) -> AsyncGenerator[T, None]:
def __aiter__(self) -> AsyncIterator[T]:
"""mdmd:hidden"""
if not self._stream:
if self._by_line:
Expand All @@ -287,16 +306,17 @@ def __aiter__(self) -> AsyncGenerator[T, None]:
async def __anext__(self) -> T:
"""mdmd:hidden"""
assert self._stream is not None

value = await self._stream.__anext__()

# The stream yields None if it receives an EOF batch.
if value is None:
raise StopAsyncIteration

if self._text:
return value.decode("utf-8")
return cast(T, value.decode("utf-8"))
else:
return value
return cast(T, value)


MAX_BUFFER_SIZE = 2 * 1024 * 1024
Expand Down
1 change: 1 addition & 0 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def type_check(ctx):
"modal/_utils/shell_utils.py",
"test/cls_test.py", # see mypy bug above - but this works with pyright, so we run that instead
"modal/_container_io_manager.py",
"modal/io_streams.py",
]
ctx.run(f"pyright {' '.join(pyright_allowlist)}", pty=True)

Expand Down
Loading