diff --git a/modal/io_streams.py b/modal/io_streams.py index ebea695e4..59d2f8008 100644 --- a/modal/io_streams.py +++ b/modal/io_streams.py @@ -1,6 +1,18 @@ # Copyright Modal Labs 2022 import asyncio -from typing import TYPE_CHECKING, AsyncGenerator, Generic, List, Literal, Optional, Tuple, TypeVar, Union +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + AsyncIterator, + Generic, + List, + Literal, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from grpclib import Status from grpclib.exceptions import GRPCError, StreamTerminatedError @@ -18,7 +30,7 @@ async def _sandbox_logs_iterator( - sandbox_id: str, file_descriptor: int, last_entry_id: Optional[str], client: _Client + sandbox_id: str, file_descriptor: "api_pb2.FileDescriptor.ValueType", last_entry_id: str, client: _Client ) -> AsyncGenerator[Tuple[Optional[bytes], str], None]: req = api_pb2.SandboxGetLogsRequest( sandbox_id=sandbox_id, @@ -37,7 +49,7 @@ async def _sandbox_logs_iterator( async def _container_process_logs_iterator( - process_id: str, file_descriptor: int, client: _Client + process_id: str, file_descriptor: "api_pb2.FileDescriptor.ValueType", client: _Client ) -> AsyncGenerator[Optional[bytes], None]: req = api_pb2.ContainerExecGetOutputRequest( exec_id=process_id, timeout=55, file_descriptor=file_descriptor, get_raw_bytes=True @@ -74,9 +86,11 @@ class _StreamReader(Generic[T]): ``` """ + _stream: Optional[AsyncGenerator[Optional[bytes], None]] + def __init__( self, - file_descriptor: int, + file_descriptor: "api_pb2.FileDescriptor.ValueType", object_id: str, object_type: Literal["sandbox", "container_process"], client: _Client, @@ -90,7 +104,7 @@ def __init__( self._object_id = object_id self._client = client self._stream = None - self._last_entry_id: Optional[str] = None + self._last_entry_id: str = "" self._line_buffer = b"" # Sandbox logs are streamed to the client as strings, so StreamReaders reading @@ -145,15 +159,20 @@ async def read(self) -> T: ``` """ - data = "" if self._text else b"" + data_str = "" + data_bytes = b"" async for message in self._get_logs(): if message is None: break if self._text: - data += message.decode("utf-8") + data_str += message.decode("utf-8") else: - data += message - return data + data_bytes += message + + if self._text: + return cast(T, data_str) + else: + return cast(T, data_bytes) async def _consume_container_process_stream(self): """ @@ -275,7 +294,7 @@ async def _get_logs_by_line(self) -> AsyncGenerator[Optional[bytes], None]: line, self._line_buffer = self._line_buffer.split(b"\n", 1) yield line + b"\n" - def __aiter__(self) -> AsyncGenerator[T, None]: + def __aiter__(self) -> AsyncIterator[T]: """mdmd:hidden""" if not self._stream: if self._by_line: @@ -287,6 +306,7 @@ def __aiter__(self) -> AsyncGenerator[T, None]: async def __anext__(self) -> T: """mdmd:hidden""" assert self._stream is not None + value = await self._stream.__anext__() # The stream yields None if it receives an EOF batch. @@ -294,9 +314,9 @@ async def __anext__(self) -> T: raise StopAsyncIteration if self._text: - return value.decode("utf-8") + return cast(T, value.decode("utf-8")) else: - return value + return cast(T, value) MAX_BUFFER_SIZE = 2 * 1024 * 1024 diff --git a/tasks.py b/tasks.py index ebbede6f5..954b3faef 100644 --- a/tasks.py +++ b/tasks.py @@ -149,6 +149,7 @@ def type_check(ctx): "modal/_utils/shell_utils.py", "test/cls_test.py", # see mypy bug above - but this works with pyright, so we run that instead "modal/_container_io_manager.py", + "modal/io_streams.py", ] ctx.run(f"pyright {' '.join(pyright_allowlist)}", pty=True)