diff --git a/.github/workflows/build-push.yaml b/.github/workflows/build-push.yaml index 554020cb..ad6604de 100644 --- a/.github/workflows/build-push.yaml +++ b/.github/workflows/build-push.yaml @@ -20,9 +20,10 @@ jobs: example_directories: [ "examples/map/even_odd", "examples/map/flatmap", "examples/map/forward_message", "examples/map/multiproc_map", "examples/mapstream/flatmap_stream", "examples/reduce/counter", - "examples/reducestream/counter", "examples/reducestream/sum", "examples/sideinput/simple_sideinput", - "examples/sideinput/simple_sideinput/udf", "examples/sink/async_log", "examples/sink/log", - "examples/source/async_source", "examples/source/simple_source", "examples/sourcetransform/event_time_filter" + "examples/reducestream/counter", "examples/reducestream/sum", "examples/sideinput/simple-sideinput", + "examples/sideinput/simple-sideinput/udf", "examples/sink/async_log", "examples/sink/log", + "examples/source/async-source", "examples/source/simple-source", "examples/sourcetransform/event_time_filter", + "examples/batchmap/flatmap" ] steps: diff --git a/Makefile b/Makefile index d5bd7c13..2dbf00d9 100644 --- a/Makefile +++ b/Makefile @@ -34,6 +34,7 @@ proto: python3 -m grpc_tools.protoc -I=pynumaflow/proto/sourcetransformer --python_out=pynumaflow/proto/sourcetransformer --grpc_python_out=pynumaflow/proto/sourcetransformer pynumaflow/proto/sourcetransformer/*.proto python3 -m grpc_tools.protoc -I=pynumaflow/proto/sideinput --python_out=pynumaflow/proto/sideinput --grpc_python_out=pynumaflow/proto/sideinput pynumaflow/proto/sideinput/*.proto python3 -m grpc_tools.protoc -I=pynumaflow/proto/sourcer --python_out=pynumaflow/proto/sourcer --grpc_python_out=pynumaflow/proto/sourcer pynumaflow/proto/sourcer/*.proto + python3 -m grpc_tools.protoc -I=pynumaflow/proto/batchmapper --python_out=pynumaflow/proto/batchmapper --grpc_python_out=pynumaflow/proto/batchmapper pynumaflow/proto/batchmapper/*.proto sed -i '' 's/^\(import.*_pb2\)/from . 
\1/' pynumaflow/proto/*/*.py diff --git a/README.md b/README.md index 17978a04..c7f397f2 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ pre-commit install - [Map](https://github.com/numaproj/numaflow-python/tree/main/examples/map) - [Reduce](https://github.com/numaproj/numaflow-python/tree/main/examples/reduce) - [Map Stream](https://github.com/numaproj/numaflow-python/tree/main/examples/mapstream) + - [Batch Map](https://github.com/numaproj/numaflow-python/tree/main/examples/batchmap) - [Implement User Defined Sinks](https://github.com/numaproj/numaflow-python/tree/main/examples/sink) - [Implement User Defined SideInputs](https://github.com/numaproj/numaflow-python/tree/main/examples/sideinput) @@ -95,7 +96,7 @@ This could be an alternative to creating multiple replicas of the same UDF conta Thus this server type is useful for UDFs which are CPU intensive. ``` -grpc_server = MapMultiProcServer(handler) +grpc_server = MapMultiprocServer(mapper_instance=handler, server_count=2) ``` #### Currently Supported Server Types for each functionality @@ -111,6 +112,8 @@ These are the class names for the server types supported by each of the function - ReduceAsyncServer - MapStream - MapStreamAsyncServer + - BatchMap + - BatchMapAsyncServer - Source Transform - SourceTransformServer - SourceTransformMultiProcServer @@ -147,6 +150,8 @@ The list of base handler classes for each of the functionalities is given below - MapStreamer - Source Transform - SourceTransformer + - Batch Map + - BatchMapper - UDSource - Sourcer - UDSink diff --git a/examples/batchmap/README.md b/examples/batchmap/README.md new file mode 100644 index 00000000..8892edcc --- /dev/null +++ b/examples/batchmap/README.md @@ -0,0 +1,66 @@ +## BatchMap Interface +The BatchMap interface allows developers to +process multiple data items together in a single UDF handler. + + +### What is BatchMap? 
+BatchMap is an interface that allows developers to process multiple data items +in a single UDF call, rather than each item in separate calls. + + +The BatchMap interface can be helpful in scenarios +where performing operations on a group of data can be more efficient. + + +### Understanding the User Interface +The BatchMap interface requires developers to implement a handler with a specific signature. +Here is the signature of the BatchMap handler: + +```python +async def handler(datums: AsyncIterable[Datum]) -> BatchResponses: +``` +The handler takes an async iterable of `Datum` objects and returns +`BatchResponses`. +The `BatchResponses` object is a list of the *same length* as the input +datums, with each item corresponding to the response for one request datum. + +To clarify, let's say we have three data items: + +``` +data_1 = {"name": "John", "age": 25} +data_2 = {"name": "Jane", "age": 30} +data_3 = {"name": "Bob", "age": 45} +``` + +These data items will be grouped together by numaflow and +passed to the handler as an iterable: + +```python +result = await handler([data_1, data_2, data_3]) +``` + +The result will be a BatchResponses object, which is a list of responses corresponding to each input data item's processing. + +### Important Considerations +When using BatchMap, there are a few important considerations to keep in mind: + +- Ensure that each `BatchResponse` object is tagged with the *correct request ID*. +Each Datum has a unique ID tag, which will be used by Numaflow to ensure correctness. + +```python +async for datum in datums: + batch_response = BatchResponse.from_id(datum.id) +``` + + +- Ensure that the length of the `BatchResponses` +list is equal to the number of requests received. +**This means that for every input data item**, there should be a corresponding +response in the BatchResponses list. + +Use batch processing only when it makes sense. 
In some +scenarios, batch processing may not be the most +efficient approach, and processing data items one by one +could be a better option. +The burden of concurrent processing of the data will rely on the +UDF implementation in this use case. diff --git a/examples/batchmap/flatmap/Dockerfile b/examples/batchmap/flatmap/Dockerfile new file mode 100644 index 00000000..e22a0108 --- /dev/null +++ b/examples/batchmap/flatmap/Dockerfile @@ -0,0 +1,54 @@ +#################################################################################################### +# builder: install needed dependencies +#################################################################################################### + +FROM python:3.10-slim-bullseye AS builder + +ENV PYTHONFAULTHANDLER=1 \ + PYTHONUNBUFFERED=1 \ + PYTHONHASHSEED=random \ + PIP_NO_CACHE_DIR=on \ + PIP_DISABLE_PIP_VERSION_CHECK=on \ + PIP_DEFAULT_TIMEOUT=100 \ + POETRY_VERSION=1.2.2 \ + POETRY_HOME="/opt/poetry" \ + POETRY_VIRTUALENVS_IN_PROJECT=true \ + POETRY_NO_INTERACTION=1 \ + PYSETUP_PATH="/opt/pysetup" + +ENV EXAMPLE_PATH="$PYSETUP_PATH/examples/batchmap/flatmap" +ENV VENV_PATH="$EXAMPLE_PATH/.venv" +ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH" + +RUN apt-get update \ + && apt-get install --no-install-recommends -y \ + curl \ + wget \ + # deps for building python deps + build-essential \ + && apt-get install -y git \ + && apt-get clean && rm -rf /var/lib/apt/lists/* \ + \ + # install dumb-init + && wget -O /dumb-init https://github.com/Yelp/dumb-init/releases/download/v1.2.5/dumb-init_1.2.5_x86_64 \ + && chmod +x /dumb-init \ + && curl -sSL https://install.python-poetry.org | python3 - + +#################################################################################################### +# udf: used for running the udf vertices +#################################################################################################### +FROM builder AS udf + +WORKDIR $PYSETUP_PATH +COPY ./ ./ + +WORKDIR $EXAMPLE_PATH +RUN 
poetry install --no-cache --no-root && \ + rm -rf ~/.cache/pypoetry/ + +RUN chmod +x entry.sh + +ENTRYPOINT ["/dumb-init", "--"] +CMD ["sh", "-c", "$EXAMPLE_PATH/entry.sh"] + +EXPOSE 5000 diff --git a/examples/batchmap/flatmap/Makefile b/examples/batchmap/flatmap/Makefile new file mode 100644 index 00000000..6b5e4d44 --- /dev/null +++ b/examples/batchmap/flatmap/Makefile @@ -0,0 +1,22 @@ +TAG ?= stable +PUSH ?= false +IMAGE_REGISTRY = quay.io/numaio/numaflow-python/batch-map-flatmap:${TAG} +DOCKER_FILE_PATH = examples/batchmap/flatmap/Dockerfile + +.PHONY: update +update: + poetry update -vv + +.PHONY: image-push +image-push: update + cd ../../../ && docker buildx build \ + -f ${DOCKER_FILE_PATH} \ + -t ${IMAGE_REGISTRY} \ + --platform linux/amd64,linux/arm64 . --push + +.PHONY: image +image: update + cd ../../../ && docker build \ + -f ${DOCKER_FILE_PATH} \ + -t ${IMAGE_REGISTRY} . + @if [ "$(PUSH)" = "true" ]; then docker push ${IMAGE_REGISTRY}; fi diff --git a/examples/batchmap/flatmap/entry.sh b/examples/batchmap/flatmap/entry.sh new file mode 100644 index 00000000..073b05e3 --- /dev/null +++ b/examples/batchmap/flatmap/entry.sh @@ -0,0 +1,4 @@ +#!/bin/sh +set -eux + +python example.py diff --git a/examples/batchmap/flatmap/example.py b/examples/batchmap/flatmap/example.py new file mode 100644 index 00000000..ee7455fa --- /dev/null +++ b/examples/batchmap/flatmap/example.py @@ -0,0 +1,46 @@ +from collections.abc import AsyncIterable + +from pynumaflow.batchmapper import ( + Message, + Datum, + BatchMapper, + BatchMapAsyncServer, + BatchResponses, + BatchResponse, +) + + +class Flatmap(BatchMapper): + """ + This is a class that inherits from the BatchMapper class. 
+ It implements a flatmap operation over a batch of input messages + """ + + async def handler( + self, + datums: AsyncIterable[Datum], + ) -> BatchResponses: + batch_responses = BatchResponses() + async for datum in datums: + val = datum.value + _ = datum.event_time + _ = datum.watermark + strs = val.decode("utf-8").split(",") + batch_response = BatchResponse.from_id(datum.id) + if len(strs) == 0: + batch_response.append(Message.to_drop()) + else: + for s in strs: + batch_response.append(Message(str.encode(s))) + batch_responses.append(batch_response) + + return batch_responses + + +if __name__ == "__main__": + """ + This example shows how to use the Batch Map Flatmap. + We use a class as handler, but a function can be used as well. + """ + grpc_server = BatchMapAsyncServer(Flatmap()) + grpc_server.start() diff --git a/examples/batchmap/flatmap/pipeline.yaml b/examples/batchmap/flatmap/pipeline.yaml new file mode 100644 index 00000000..d7d37db4 --- /dev/null +++ b/examples/batchmap/flatmap/pipeline.yaml @@ -0,0 +1,33 @@ +apiVersion: numaflow.numaproj.io/v1alpha1 +kind: Pipeline +metadata: + name: flatmap +spec: + vertices: + - name: in + source: + # A self data generating source + generator: + rpu: 500 + duration: 1s + - name: batch-flatmap + partitions: 2 + metadata: + annotations: + numaflow.numaproj.io/batch-map: "true" + scale: + min: 1 + udf: + container: + image: quay.io/numaio/numaflow-python/batch-map-flatmap:stable + imagePullPolicy: Always + - name: sink + scale: + min: 1 + sink: + log: {} + edges: + - from: in + to: batch-flatmap + - from: batch-flatmap + to: sink diff --git a/examples/batchmap/flatmap/pyproject.toml b/examples/batchmap/flatmap/pyproject.toml new file mode 100644 index 00000000..20b28b76 --- /dev/null +++ b/examples/batchmap/flatmap/pyproject.toml @@ -0,0 +1,15 @@ +[tool.poetry] +name = "batch-map-flatmap" +version = "0.1.0" +description = "" +authors = ["Numaflow developers"] + +[tool.poetry.dependencies] +python = "~3.10" +pynumaflow 
= { path = "../../../"} + +[tool.poetry.dev-dependencies] + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/pynumaflow/_constants.py b/pynumaflow/_constants.py index 2cf9755c..cbbba8e4 100644 --- a/pynumaflow/_constants.py +++ b/pynumaflow/_constants.py @@ -17,6 +17,7 @@ SOURCE_SOCK_PATH = "/var/run/numaflow/source.sock" MULTIPROC_MAP_SOCK_ADDR = "/var/run/numaflow/multiproc" FALLBACK_SINK_SOCK_PATH = "/var/run/numaflow/fb-sink.sock" +BATCH_MAP_SOCK_PATH = "/var/run/numaflow/batchmap.sock" # Server information file configs MAP_SERVER_INFO_FILE_PATH = "/var/run/numaflow/mapper-server-info" @@ -28,6 +29,7 @@ SIDE_INPUT_SERVER_INFO_FILE_PATH = "/var/run/numaflow/sideinput-server-info" SOURCE_SERVER_INFO_FILE_PATH = "/var/run/numaflow/sourcer-server-info" FALLBACK_SINK_SERVER_INFO_FILE_PATH = "/var/run/numaflow/fb-sinker-server-info" +BATCH_MAP_SERVER_INFO_FILE_PATH = "/var/run/numaflow/batchmapper-server-info" ENV_UD_CONTAINER_TYPE = "NUMAFLOW_UD_CONTAINER_TYPE" UD_CONTAINER_FALLBACK_SINK = "fb-udsink" @@ -43,7 +45,10 @@ DROP = "U+005C__DROP__" _PROCESS_COUNT = os.cpu_count() -MAX_THREADS = int(os.getenv("MAX_THREADS", "4")) +# Cap max value to 16 +MAX_NUM_THREADS = 16 +# If the MAX_THREADS env is not set, default to 4 +NUM_THREADS_DEFAULT = int(os.getenv("MAX_THREADS", "4")) _LOGGER = setup_logging(__name__) if os.getenv("PYTHONDEBUG"): diff --git a/pynumaflow/batchmapper/__init__.py b/pynumaflow/batchmapper/__init__.py new file mode 100644 index 00000000..0438d8f6 --- /dev/null +++ b/pynumaflow/batchmapper/__init__.py @@ -0,0 +1,20 @@ +from pynumaflow._constants import DROP + +from pynumaflow.batchmapper._dtypes import ( + Message, + Datum, + BatchMapper, + BatchResponses, + BatchResponse, +) +from pynumaflow.batchmapper.async_server import BatchMapAsyncServer + +__all__ = [ + "Message", + "Datum", + "DROP", + "BatchMapAsyncServer", + "BatchMapper", + "BatchResponses", + "BatchResponse", +] diff --git 
a/pynumaflow/batchmapper/_dtypes.py b/pynumaflow/batchmapper/_dtypes.py new file mode 100644 index 00000000..91762a48 --- /dev/null +++ b/pynumaflow/batchmapper/_dtypes.py @@ -0,0 +1,226 @@ +from abc import ABCMeta, abstractmethod +from collections.abc import Iterator, Sequence +from dataclasses import dataclass +from datetime import datetime +from typing import TypeVar, Callable, Union, Optional +from collections.abc import AsyncIterable +from collections.abc import Awaitable + +from pynumaflow._constants import DROP + +M = TypeVar("M", bound="Message") +B = TypeVar("B", bound="BatchResponse") +Bs = TypeVar("Bs", bound="BatchResponses") + + +@dataclass(init=False) +class Message: + """ + Basic datatype for data passing to the next vertex/vertices. + + Args: + value: data in bytes + keys: []string keys for vertex (optional) + tags: []string tags for conditional forwarding (optional) + """ + + __slots__ = ("_value", "_keys", "_tags") + + _value: bytes + _keys: list[str] + _tags: list[str] + + def __init__( + self, value: bytes, keys: Optional[list[str]] = None, tags: Optional[list[str]] = None + ): + """ + Creates a Message object to send value to a vertex. + """ + self._keys = keys or [] + self._tags = tags or [] + self._value = value or b"" + + # returns the Message Object which will be dropped + @classmethod + def to_drop(cls: type[M]) -> M: + return cls(b"", None, [DROP]) + + @property + def value(self) -> bytes: + return self._value + + @property + def keys(self) -> list[str]: + return self._keys + + @property + def tags(self) -> list[str]: + return self._tags + + +@dataclass(init=False) +class Datum: + """ + Class to define the important information for the event. + Args: + keys: the keys of the event. + value: the payload of the event. + event_time: the event time of the event. + watermark: the watermark of the event. + headers: the headers of the event. 
+ id: the unique ID for this request + """ + + __slots__ = ("_keys", "_value", "_event_time", "_watermark", "_headers", "_id") + + _keys: list[str] + _value: bytes + _event_time: datetime + _watermark: datetime + _headers: dict[str, str] + _id: str + + def __init__( + self, + id: str, + keys: list[str], + value: bytes, + event_time: datetime, + watermark: datetime, + headers: Optional[dict[str, str]] = None, + ): + self._id = id + self._keys = keys or list() + self._value = value or b"" + if not isinstance(event_time, datetime): + raise TypeError(f"Wrong data type: {type(event_time)} for Datum.event_time") + self._event_time = event_time + if not isinstance(watermark, datetime): + raise TypeError(f"Wrong data type: {type(watermark)} for Datum.watermark") + self._watermark = watermark + self._headers = headers or {} + + def keys(self) -> list[str]: + """Returns the keys of the event""" + return self._keys + + @property + def value(self) -> bytes: + """Returns the value of the event.""" + return self._value + + @property + def event_time(self) -> datetime: + """Returns the event time of the event.""" + return self._event_time + + @property + def watermark(self) -> datetime: + """Returns the watermark of the event.""" + return self._watermark + + @property + def headers(self) -> dict[str, str]: + """Returns the headers of the event.""" + return self._headers.copy() + + @property + def id(self) -> str: + """Returns the id of the event.""" + return self._id + + +@dataclass +class BatchResponse: + """ + Basic datatype for Batch map response. + + Args: + id: the id of the request. 
+ messages: list of responses for corresponding to the request id + """ + + _id: str + messages: list[M] + + __slots__ = ("_id", "messages") + + @classmethod + def from_id(cls: type[B], id_: str) -> B: + return BatchResponse(_id=id_, messages=[]) + + @classmethod + def with_msgs(cls: type[B], id_: str, msgs: list[M]) -> B: + return BatchResponse(_id=id_, messages=msgs) + + def append(self, message: Message) -> None: + self.messages.append(message) + + def items(self) -> list[Message]: + return self.messages + + @property + def id(self) -> str: + return self._id + + +class BatchResponses(Sequence[B]): + """ + Class to define a list of Batch Response objects. + + Args: + responses: list of Batch Response objects. + """ + + __slots__ = ("_responses",) + + def __init__(self, *responses: B): + self._responses = list(responses) or [] + + def __str__(self) -> str: + return str(self._responses) + + def __repr__(self) -> str: + return str(self) + + def __len__(self) -> int: + return len(self._responses) + + def __iter__(self) -> Iterator[B]: + return iter(self._responses) + + def __getitem__(self, index: int) -> M: + return self._responses[index] + + def append(self, response: BatchResponse) -> None: + self._responses.append(response) + + def items(self) -> list[BatchResponse]: + return self._responses + + +class BatchMapper(metaclass=ABCMeta): + """ + Provides an interface to write a Batch Mapper + which will be exposed over a gRPC server. + + Args: + + """ + + def __call__(self, *args, **kwargs): + """ + Allow to call handler function directly if class instance is sent + """ + return self.handler(*args, **kwargs) + + @abstractmethod + async def handler(self, datums: AsyncIterable[Datum]) -> BatchResponses: + """ + Implement this handler function which implements the BatchMapAsyncCallable interface. 
+ """ + pass + + +BatchMapAsyncCallable = Callable[[AsyncIterable[Datum]], Awaitable[BatchResponses]] +BatchMapCallable = Union[BatchMapper, BatchMapAsyncCallable] diff --git a/pynumaflow/batchmapper/async_server.py b/pynumaflow/batchmapper/async_server.py new file mode 100644 index 00000000..861f6a12 --- /dev/null +++ b/pynumaflow/batchmapper/async_server.py @@ -0,0 +1,106 @@ +import aiorun +import grpc + +from pynumaflow._constants import ( + MAX_MESSAGE_SIZE, + NUM_THREADS_DEFAULT, + _LOGGER, + BATCH_MAP_SOCK_PATH, + BATCH_MAP_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, +) +from pynumaflow.batchmapper._dtypes import BatchMapCallable +from pynumaflow.batchmapper.servicer.async_servicer import AsyncBatchMapServicer +from pynumaflow.proto.batchmapper import batchmap_pb2_grpc +from pynumaflow.shared.server import NumaflowServer, start_async_server + + +class BatchMapAsyncServer(NumaflowServer): + """ + Class for a new Batch Map Async Server instance. + """ + + def __init__( + self, + batch_mapper_instance: BatchMapCallable, + sock_path=BATCH_MAP_SOCK_PATH, + max_message_size=MAX_MESSAGE_SIZE, + max_threads=NUM_THREADS_DEFAULT, + server_info_file=BATCH_MAP_SERVER_INFO_FILE_PATH, + ): + """ + Create a new grpc Async Batch Map Server instance. + A new servicer instance is created and attached to the server. + The server instance is returned. 
+ Args: + batch_mapper_instance: The batch map stream instance to be used for Batch Map UDF + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + defaults to 4 and max capped at 16 + + Example invocation: + class Flatmap(BatchMapper): + async def handler( + self, + datums: AsyncIterable[Datum], + ) -> BatchResponses: + batch_responses = BatchResponses() + async for datum in datums: + val = datum.value + _ = datum.event_time + _ = datum.watermark + strs = val.decode("utf-8").split(",") + batch_response = BatchResponse.from_id(datum.id) + if len(strs) == 0: + batch_response.append(Message.to_drop()) + else: + for s in strs: + batch_response.append(Message(str.encode(s))) + batch_responses.append(batch_response) + + return batch_responses + + if __name__ == "__main__": + grpc_server = BatchMapAsyncServer(Flatmap()) + grpc_server.start() + """ + self.batch_mapper_instance: BatchMapCallable = batch_mapper_instance + self.sock_path = f"unix://{sock_path}" + self.max_threads = min(max_threads, MAX_NUM_THREADS) + self.max_message_size = max_message_size + self.server_info_file = server_info_file + + self._server_options = [ + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), + ] + + self.servicer = AsyncBatchMapServicer(handler=self.batch_mapper_instance) + + def start(self): + """ + Starter function for the Async Batch Map server, we need a separate caller + to the aexec so that all the async coroutines can be started from a single context + """ + aiorun.run(self.aexec(), use_uvloop=True) + + async def aexec(self): + """ + Starts the Async gRPC server on the given UNIX socket with + given max threads. 
+ """ + # As the server is async, we need to create a new server instance in the + # same thread as the event loop so that all the async calls are made in the + # same context + # Create a new async server instance and add the servicer to it + server = grpc.aio.server() + server.add_insecure_port(self.sock_path) + batchmap_pb2_grpc.add_BatchMapServicer_to_server( + self.servicer, + server, + ) + _LOGGER.info("Starting Batch Map Server") + await start_async_server( + server, self.sock_path, self.max_threads, self._server_options, self.server_info_file + ) diff --git a/pynumaflow/batchmapper/servicer/__init__.py b/pynumaflow/batchmapper/servicer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/batchmapper/servicer/async_servicer.py b/pynumaflow/batchmapper/servicer/async_servicer.py new file mode 100644 index 00000000..6daf7b5a --- /dev/null +++ b/pynumaflow/batchmapper/servicer/async_servicer.py @@ -0,0 +1,134 @@ +import asyncio +from collections.abc import AsyncIterable + +import grpc +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow.batchmapper import Datum +from pynumaflow.batchmapper._dtypes import BatchMapCallable +from pynumaflow.proto.batchmapper import batchmap_pb2, batchmap_pb2_grpc +from pynumaflow.shared.asynciter import NonBlockingIterator +from pynumaflow.shared.server import exit_on_error +from pynumaflow.types import NumaflowServicerContext +from pynumaflow._constants import _LOGGER, STREAM_EOF + + +async def datum_generator( + request_iterator: AsyncIterable[batchmap_pb2.BatchMapRequest], +) -> AsyncIterable[Datum]: + """ + This function is used to create an async generator + from the gRPC request iterator. + It yields a Datum instance for each request received which is then + forwarded to the UDF. 
+ """ + async for d in request_iterator: + request = Datum( + keys=d.keys, + value=d.value, + event_time=d.event_time.ToDatetime(), + watermark=d.watermark.ToDatetime(), + headers=dict(d.headers), + id=d.id, + ) + yield request + + +class AsyncBatchMapServicer(batchmap_pb2_grpc.BatchMapServicer): + """ + This class is used to create a new grpc Batch Map Servicer instance. + It implements the BatchMapServicer interface from the proto + batchmap_pb2_grpc.py file. + Provides the functionality for the required rpc methods. + """ + + def __init__( + self, + handler: BatchMapCallable, + ): + self.background_tasks = set() + self.__batch_map_handler: BatchMapCallable = handler + + async def BatchMapFn( + self, + request_iterator: AsyncIterable[batchmap_pb2.BatchMapRequest], + context: NumaflowServicerContext, + ) -> batchmap_pb2.BatchMapResponse: + """ + Applies a batch map function to a BatchMapRequest stream in a batching mode. + The pascal case function name comes from the proto batchmap_pb2_grpc.py file. + """ + # Create an async iterator from the request iterator + datum_iterator = datum_generator(request_iterator=request_iterator) + + try: + # invoke the UDF call for batch map + responses, request_counter = await self.invoke_batch_map(datum_iterator) + + # If the number of responses received does not align with the request batch size, + # we will not be able to process the data correctly. + # This should be marked as an error and raised to the user. 
+ if len(responses) != request_counter: + err_msg = "batchMapFn: mismatch between length of batch requests and responses" + raise Exception(err_msg) + + # iterate over the responses received and covert to the required proto format + for batch_response in responses: + single_req_resp = [] + for msg in batch_response.messages: + single_req_resp.append( + batchmap_pb2.BatchMapResponse.Result( + keys=msg.keys, value=msg.value, tags=msg.tags + ) + ) + + # send the response for a given ID back to the stream + yield batchmap_pb2.BatchMapResponse(id=batch_response.id, results=single_req_resp) + + except BaseException as err: + _LOGGER.critical("UDFError, re-raising the error", exc_info=True) + await asyncio.gather( + context.abort(grpc.StatusCode.UNKNOWN, details=repr(err)), return_exceptions=True + ) + exit_on_error(context, repr(err)) + return + + async def invoke_batch_map(self, datum_iterator: AsyncIterable[Datum]): + """ + # iterate over the incoming requests, and keep sending to the user code + # once all messages have been sent, we wait for the responses + """ + # create a message queue to send to the user code + niter = NonBlockingIterator() + riter = niter.read_iterator() + # create a task for invoking the UDF handler + task = asyncio.create_task(self.__batch_map_handler(riter)) + # Save a reference to the result of this function, to avoid a + # task disappearing mid-execution. 
+ self.background_tasks.add(task) + task.add_done_callback(lambda t: self.background_tasks.remove(t)) + + req_count = 0 + # start streaming the messages to the UDF code, and increment the request counter + async for datum in datum_iterator: + await niter.put(datum) + req_count += 1 + + # once all messages have been exhausted, send an EOF to indicate end of messages + # to the UDF + await niter.put(STREAM_EOF) + + # wait for all the responses + await task + + # return the result from the UDF, along with the request_counter + return task.result(), req_count + + async def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> batchmap_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto batchmap_pb2_grpc.py file. + """ + return batchmap_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/mapper/async_server.py b/pynumaflow/mapper/async_server.py index 53f30b0c..58bd6c04 100644 --- a/pynumaflow/mapper/async_server.py +++ b/pynumaflow/mapper/async_server.py @@ -1,13 +1,12 @@ -import os - import aiorun import grpc from pynumaflow._constants import ( - MAX_THREADS, + NUM_THREADS_DEFAULT, MAX_MESSAGE_SIZE, MAP_SOCK_PATH, MAP_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.mapper._dtypes import MapAsyncCallable from pynumaflow.mapper.servicer.async_servicer import AsyncMapServicer @@ -26,7 +25,7 @@ class MapAsyncServer(NumaflowServer): sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 Example invocation: from pynumaflow.mapper import Messages, Message, Datum, MapAsyncServer @@ -50,7 +49,7 @@ def __init__( mapper_instance: MapAsyncCallable, sock_path=MAP_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + 
max_threads=NUM_THREADS_DEFAULT, server_info_file=MAP_SERVER_INFO_FILE_PATH, ): """ @@ -62,10 +61,10 @@ def __init__( sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 """ self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file diff --git a/pynumaflow/mapper/multiproc_server.py b/pynumaflow/mapper/multiproc_server.py index a198b5e7..b5a70d02 100644 --- a/pynumaflow/mapper/multiproc_server.py +++ b/pynumaflow/mapper/multiproc_server.py @@ -1,12 +1,11 @@ -import os - from pynumaflow._constants import ( - MAX_THREADS, + NUM_THREADS_DEFAULT, MAX_MESSAGE_SIZE, MAP_SOCK_PATH, UDFType, _PROCESS_COUNT, MAP_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.mapper._dtypes import MapSyncCallable from pynumaflow.mapper.servicer.sync_servicer import SyncMapServicer @@ -27,7 +26,7 @@ def __init__( server_count: int = _PROCESS_COUNT, sock_path=MAP_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, server_info_file=MAP_SERVER_INFO_FILE_PATH, ): """ @@ -40,7 +39,7 @@ def __init__( sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 Example invocation: import math @@ -66,9 +65,7 @@ def handler(self, keys: list[str], datum: Datum) -> Messages: return messages if __name__ == "__main__": - # To set the env server_count value set the env variable - # NUM_CPU_MULTIPROC="N" - server_count = 
int(os.getenv("NUM_CPU_MULTIPROC", "2")) + server_count = 2 prime_class = PrimeMap() # Server count is the number of server processes to start grpc_server = MapMultiprocServer(prime_class, server_count=server_count) @@ -76,7 +73,7 @@ def handler(self, keys: list[str], datum: Datum) -> Messages: """ self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file @@ -89,7 +86,7 @@ def handler(self, keys: list[str], datum: Datum) -> Messages: ("grpc.so_reuseaddr", 1), ] # Set the number of processes to be spawned to the number of CPUs or - # the value of the env var NUM_CPU_MULTIPROC defined by the user + # the value of the parameter server_count defined by the user # Setting the max value to 2 * CPU count # Used for multiproc server self._process_count = min(server_count, 2 * _PROCESS_COUNT) @@ -99,9 +96,8 @@ def start(self) -> None: """ Starts the N grpc servers gRPC serves on the with given max threads. - where N = The number of CPUs or the - value of the env var NUM_CPU_MULTIPROC defined by the user. The max value - is set to 2 * CPU count. + where N = The number of CPUs or the value of the parameter server_count + defined by the user. The max value is capped to 2 * CPU count. 
""" # Start the multiproc server diff --git a/pynumaflow/mapper/sync_server.py b/pynumaflow/mapper/sync_server.py index c5377efa..48b909f5 100644 --- a/pynumaflow/mapper/sync_server.py +++ b/pynumaflow/mapper/sync_server.py @@ -1,15 +1,13 @@ -import os - - from pynumaflow.mapper.servicer.sync_servicer import SyncMapServicer from pynumaflow._constants import ( - MAX_THREADS, + NUM_THREADS_DEFAULT, MAX_MESSAGE_SIZE, _LOGGER, MAP_SOCK_PATH, UDFType, MAP_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.mapper._dtypes import MapSyncCallable @@ -27,7 +25,7 @@ class MapServer(NumaflowServer): sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 Example Invocation: from pynumaflow.mapper import Messages, Message, Datum, MapServer, Mapper @@ -65,7 +63,7 @@ def __init__( mapper_instance: MapSyncCallable, sock_path=MAP_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, server_info_file=MAP_SERVER_INFO_FILE_PATH, ): """ @@ -77,10 +75,10 @@ def __init__( sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 """ self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file diff --git a/pynumaflow/mapstreamer/async_server.py b/pynumaflow/mapstreamer/async_server.py index eb09a181..db368147 100644 --- a/pynumaflow/mapstreamer/async_server.py +++ b/pynumaflow/mapstreamer/async_server.py @@ -1,5 +1,3 @@ -import 
os - import aiorun import grpc @@ -9,9 +7,10 @@ from pynumaflow._constants import ( MAP_STREAM_SOCK_PATH, MAX_MESSAGE_SIZE, - MAX_THREADS, + NUM_THREADS_DEFAULT, _LOGGER, MAP_STREAM_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.mapstreamer._dtypes import MapStreamCallable @@ -29,7 +28,7 @@ def __init__( map_stream_instance: MapStreamCallable, sock_path=MAP_STREAM_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, server_info_file=MAP_STREAM_SERVER_INFO_FILE_PATH, ): """ @@ -41,7 +40,7 @@ def __init__( sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 server_type: The type of server to be used Example invocation: @@ -87,7 +86,7 @@ async def map_stream_handler(_: list[str], datum: Datum) -> AsyncIterable[Messag """ self.map_stream_instance: MapStreamCallable = map_stream_instance self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file diff --git a/pynumaflow/proto/batchmapper/__init__.py b/pynumaflow/proto/batchmapper/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/proto/batchmapper/batchmap.proto b/pynumaflow/proto/batchmapper/batchmap.proto new file mode 100644 index 00000000..82d5672d --- /dev/null +++ b/pynumaflow/proto/batchmapper/batchmap.proto @@ -0,0 +1,50 @@ +syntax = "proto3"; + +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; + +package batchmap.v1; + +service BatchMap { + // IsReady is the heartbeat endpoint for gRPC. 
+ rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); + + // BatchMapFn is a bi-directional streaming rpc which applies a + // Map function on each BatchMapRequest element of the stream and then returns streams + // back BatchMapResponse elements. + rpc BatchMapFn(stream BatchMapRequest) returns (stream BatchMapResponse); +} + +/** + * BatchMapRequest represents a request element. + */ +message BatchMapRequest { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + map headers = 5; + // This ID is used uniquely identify a map request + string id = 6; +} + +/** + * BatchMapResponse represents a response element. + */ +message BatchMapResponse { + message Result { + repeated string keys = 1; + bytes value = 2; + repeated string tags = 3; + } + repeated Result results = 1; + // This ID is used to refer the responses to the request it corresponds to. + string id = 2; +} + +/** + * ReadyResponse is the health check result. + */ +message ReadyResponse { + bool ready = 1; +} \ No newline at end of file diff --git a/pynumaflow/proto/batchmapper/batchmap_pb2.py b/pynumaflow/proto/batchmapper/batchmap_pb2.py new file mode 100644 index 00000000..b25383d5 --- /dev/null +++ b/pynumaflow/proto/batchmapper/batchmap_pb2.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: batchmap.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0e\x62\x61tchmap.proto\x12\x0b\x62\x61tchmap.v1\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\x85\x02\n\x0f\x42\x61tchMapRequest\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\twatermark\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12:\n\x07headers\x18\x05 \x03(\x0b\x32).batchmap.v1.BatchMapRequest.HeadersEntry\x12\n\n\x02id\x18\x06 \x01(\t\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"\x8a\x01\n\x10\x42\x61tchMapResponse\x12\x35\n\x07results\x18\x01 \x03(\x0b\x32$.batchmap.v1.BatchMapResponse.Result\x12\n\n\x02id\x18\x02 \x01(\t\x1a\x33\n\x06Result\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x0c\n\x04tags\x18\x03 \x03(\t"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08\x32\x98\x01\n\x08\x42\x61tchMap\x12=\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x1a.batchmap.v1.ReadyResponse\x12M\n\nBatchMapFn\x12\x1c.batchmap.v1.BatchMapRequest\x1a\x1d.batchmap.v1.BatchMapResponse(\x01\x30\x01\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "batchmap_pb2", _globals) +if 
_descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals["_BATCHMAPREQUEST_HEADERSENTRY"]._options = None + _globals["_BATCHMAPREQUEST_HEADERSENTRY"]._serialized_options = b"8\001" + _globals["_BATCHMAPREQUEST"]._serialized_start = 94 + _globals["_BATCHMAPREQUEST"]._serialized_end = 355 + _globals["_BATCHMAPREQUEST_HEADERSENTRY"]._serialized_start = 309 + _globals["_BATCHMAPREQUEST_HEADERSENTRY"]._serialized_end = 355 + _globals["_BATCHMAPRESPONSE"]._serialized_start = 358 + _globals["_BATCHMAPRESPONSE"]._serialized_end = 496 + _globals["_BATCHMAPRESPONSE_RESULT"]._serialized_start = 445 + _globals["_BATCHMAPRESPONSE_RESULT"]._serialized_end = 496 + _globals["_READYRESPONSE"]._serialized_start = 498 + _globals["_READYRESPONSE"]._serialized_end = 528 + _globals["_BATCHMAP"]._serialized_start = 531 + _globals["_BATCHMAP"]._serialized_end = 683 +# @@protoc_insertion_point(module_scope) diff --git a/pynumaflow/proto/batchmapper/batchmap_pb2_grpc.py b/pynumaflow/proto/batchmapper/batchmap_pb2_grpc.py new file mode 100644 index 00000000..d3614d5a --- /dev/null +++ b/pynumaflow/proto/batchmapper/batchmap_pb2_grpc.py @@ -0,0 +1,128 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . import batchmap_pb2 as batchmap__pb2 +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 + + +class BatchMapStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.IsReady = channel.unary_unary( + "/batchmap.v1.BatchMap/IsReady", + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_deserializer=batchmap__pb2.ReadyResponse.FromString, + ) + self.BatchMapFn = channel.stream_stream( + "/batchmap.v1.BatchMap/BatchMapFn", + request_serializer=batchmap__pb2.BatchMapRequest.SerializeToString, + response_deserializer=batchmap__pb2.BatchMapResponse.FromString, + ) + + +class BatchMapServicer(object): + """Missing associated documentation comment in .proto file.""" + + def IsReady(self, request, context): + """IsReady is the heartbeat endpoint for gRPC.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def BatchMapFn(self, request_iterator, context): + """BatchMapFn is a bi-directional streaming rpc which applies a + Map function on each BatchMapRequest element of the stream and then returns streams + back BatchMapResponse elements. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_BatchMapServicer_to_server(servicer, server): + rpc_method_handlers = { + "IsReady": grpc.unary_unary_rpc_method_handler( + servicer.IsReady, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=batchmap__pb2.ReadyResponse.SerializeToString, + ), + "BatchMapFn": grpc.stream_stream_rpc_method_handler( + servicer.BatchMapFn, + request_deserializer=batchmap__pb2.BatchMapRequest.FromString, + response_serializer=batchmap__pb2.BatchMapResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "batchmap.v1.BatchMap", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. 
+class BatchMap(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def IsReady( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/batchmap.v1.BatchMap/IsReady", + google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + batchmap__pb2.ReadyResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def BatchMapFn( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_stream( + request_iterator, + target, + "/batchmap.v1.BatchMap/BatchMapFn", + batchmap__pb2.BatchMapRequest.SerializeToString, + batchmap__pb2.BatchMapResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/pynumaflow/reducer/async_server.py b/pynumaflow/reducer/async_server.py index 3bb43414..57f9256b 100644 --- a/pynumaflow/reducer/async_server.py +++ b/pynumaflow/reducer/async_server.py @@ -11,9 +11,10 @@ from pynumaflow._constants import ( REDUCE_SOCK_PATH, MAX_MESSAGE_SIZE, - MAX_THREADS, + NUM_THREADS_DEFAULT, _LOGGER, REDUCE_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.reducer._dtypes import ( @@ -60,7 +61,7 @@ class ReduceAsyncServer(NumaflowServer): sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 Example invocation: import os from 
collections.abc import AsyncIterable @@ -118,7 +119,7 @@ def __init__( init_kwargs: dict = None, sock_path=REDUCE_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, server_info_file=REDUCE_SERVER_INFO_FILE_PATH, ): """ @@ -130,7 +131,7 @@ def __init__( sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 server_type: The type of server to be used """ if init_kwargs is None: @@ -138,7 +139,7 @@ def __init__( self.reducer_handler = get_handler(reducer_handler, init_args, init_kwargs) self.sock_path = f"unix://{sock_path}" self.max_message_size = max_message_size - self.max_threads = max_threads + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.server_info_file = server_info_file self._server_options = [ diff --git a/pynumaflow/reducestreamer/async_server.py b/pynumaflow/reducestreamer/async_server.py index b14f50e6..e7d5bc7e 100644 --- a/pynumaflow/reducestreamer/async_server.py +++ b/pynumaflow/reducestreamer/async_server.py @@ -10,10 +10,11 @@ from pynumaflow._constants import ( MAX_MESSAGE_SIZE, - MAX_THREADS, + NUM_THREADS_DEFAULT, _LOGGER, REDUCE_STREAM_SOCK_PATH, REDUCE_STREAM_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.reducestreamer._dtypes import ( @@ -66,7 +67,7 @@ class ReduceStreamAsyncServer(NumaflowServer): sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 server_info_file: The path to the server info file Example invocation: import os @@ -131,7 +132,7 @@ def __init__( init_kwargs: dict = None, sock_path=REDUCE_STREAM_SOCK_PATH, 
max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, server_info_file=REDUCE_STREAM_SERVER_INFO_FILE_PATH, ): """ @@ -147,7 +148,7 @@ def __init__( sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 server_info_file: The path to the server info file """ if init_kwargs is None: @@ -155,7 +156,7 @@ def __init__( self.reduce_stream_handler = get_handler(reduce_stream_handler, init_args, init_kwargs) self.sock_path = f"unix://{sock_path}" self.max_message_size = max_message_size - self.max_threads = max_threads + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.server_info_file = server_info_file self._server_options = [ diff --git a/pynumaflow/sideinput/server.py b/pynumaflow/sideinput/server.py index 554f7c64..901d1ef6 100644 --- a/pynumaflow/sideinput/server.py +++ b/pynumaflow/sideinput/server.py @@ -1,16 +1,16 @@ -import os from pynumaflow.shared import NumaflowServer from pynumaflow.shared.server import sync_server_start from pynumaflow.sideinput._dtypes import RetrieverCallable from pynumaflow.sideinput.servicer.servicer import SideInputServicer from pynumaflow._constants import ( - MAX_THREADS, + NUM_THREADS_DEFAULT, MAX_MESSAGE_SIZE, SIDE_INPUT_SOCK_PATH, _LOGGER, UDFType, SIDE_INPUT_DIR_PATH, SIDE_INPUT_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) @@ -54,12 +54,12 @@ def __init__( side_input_instance: RetrieverCallable, sock_path=SIDE_INPUT_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, side_input_dir_path=SIDE_INPUT_DIR_PATH, server_info_file=SIDE_INPUT_SERVER_INFO_FILE_PATH, ): self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_threads = min(max_threads, 
MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file diff --git a/pynumaflow/sinker/async_server.py b/pynumaflow/sinker/async_server.py index 4d272e28..e3080f86 100644 --- a/pynumaflow/sinker/async_server.py +++ b/pynumaflow/sinker/async_server.py @@ -10,13 +10,14 @@ from pynumaflow._constants import ( SINK_SOCK_PATH, MAX_MESSAGE_SIZE, - MAX_THREADS, + NUM_THREADS_DEFAULT, SINK_SERVER_INFO_FILE_PATH, ENV_UD_CONTAINER_TYPE, UD_CONTAINER_FALLBACK_SINK, _LOGGER, FALLBACK_SINK_SOCK_PATH, FALLBACK_SINK_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.shared.server import NumaflowServer, start_async_server @@ -34,7 +35,7 @@ class SinkAsyncServer(NumaflowServer): sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 Example invocation: import os @@ -76,7 +77,7 @@ def __init__( sinker_instance: AsyncSinkCallable, sock_path=SINK_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, server_info_file=SINK_SERVER_INFO_FILE_PATH, ): # If the container type is fallback sink, then use the fallback sink address and path. 
@@ -86,7 +87,7 @@ def __init__( server_info_file = FALLBACK_SINK_SERVER_INFO_FILE_PATH self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file diff --git a/pynumaflow/sinker/server.py b/pynumaflow/sinker/server.py index a501971e..67c7f646 100644 --- a/pynumaflow/sinker/server.py +++ b/pynumaflow/sinker/server.py @@ -6,7 +6,7 @@ from pynumaflow._constants import ( SINK_SOCK_PATH, MAX_MESSAGE_SIZE, - MAX_THREADS, + NUM_THREADS_DEFAULT, _LOGGER, UDFType, SINK_SERVER_INFO_FILE_PATH, @@ -14,6 +14,7 @@ UD_CONTAINER_FALLBACK_SINK, FALLBACK_SINK_SOCK_PATH, FALLBACK_SINK_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.shared.server import NumaflowServer, sync_server_start @@ -30,7 +31,7 @@ def __init__( sinker_instance: SyncSinkCallable, sock_path=SINK_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, server_info_file=SINK_SERVER_INFO_FILE_PATH, ): """ @@ -42,7 +43,7 @@ def __init__( sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 Example invocation: import os from collections.abc import Iterator @@ -83,7 +84,7 @@ def udsink_handler(datums: Iterator[Datum]) -> Responses: server_info_file = FALLBACK_SINK_SERVER_INFO_FILE_PATH self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file diff --git a/pynumaflow/sourcer/async_server.py b/pynumaflow/sourcer/async_server.py index 5f9a2686..c61080a9 100644 --- 
a/pynumaflow/sourcer/async_server.py +++ b/pynumaflow/sourcer/async_server.py @@ -1,5 +1,3 @@ -import os - import aiorun import grpc from pynumaflow.sourcer.servicer.async_servicer import AsyncSourceServicer @@ -7,8 +5,9 @@ from pynumaflow._constants import ( SOURCE_SOCK_PATH, MAX_MESSAGE_SIZE, - MAX_THREADS, + NUM_THREADS_DEFAULT, SOURCE_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.proto.sourcer import source_pb2_grpc @@ -26,7 +25,7 @@ def __init__( sourcer_instance: SourceCallable, sock_path=SOURCE_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, server_info_file=SOURCE_SERVER_INFO_FILE_PATH, ): """ @@ -38,7 +37,7 @@ def __init__( sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 Example invocation: from collections.abc import AsyncIterable @@ -103,7 +102,7 @@ async def partitions_handler(self) -> PartitionsResponse: """ self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file diff --git a/pynumaflow/sourcer/server.py b/pynumaflow/sourcer/server.py index 8c4699c3..d4961191 100644 --- a/pynumaflow/sourcer/server.py +++ b/pynumaflow/sourcer/server.py @@ -1,12 +1,11 @@ -import os - from pynumaflow._constants import ( SOURCE_SOCK_PATH, MAX_MESSAGE_SIZE, - MAX_THREADS, + NUM_THREADS_DEFAULT, _LOGGER, UDFType, SOURCE_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.shared.server import NumaflowServer, sync_server_start from pynumaflow.sourcer._dtypes import SourceCallable @@ -23,7 +22,7 @@ def __init__( sourcer_instance: SourceCallable, sock_path=SOURCE_SOCK_PATH, 
max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, server_info_file=SOURCE_SERVER_INFO_FILE_PATH, ): """ @@ -35,7 +34,7 @@ def __init__( sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 Example invocation: from collections.abc import Iterable @@ -101,7 +100,7 @@ def partitions_handler(self) -> PartitionsResponse: grpc_server.start() """ self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file diff --git a/pynumaflow/sourcetransformer/multiproc_server.py b/pynumaflow/sourcetransformer/multiproc_server.py index 5c617a4a..42c09fe5 100644 --- a/pynumaflow/sourcetransformer/multiproc_server.py +++ b/pynumaflow/sourcetransformer/multiproc_server.py @@ -1,5 +1,3 @@ -import os - from pynumaflow.sourcetransformer.servicer.server import SourceTransformServicer from pynumaflow.shared.server import start_multiproc_server @@ -7,10 +5,11 @@ from pynumaflow._constants import ( MAX_MESSAGE_SIZE, SOURCE_TRANSFORMER_SOCK_PATH, - MAX_THREADS, + NUM_THREADS_DEFAULT, UDFType, _PROCESS_COUNT, SOURCE_TRANSFORMER_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable @@ -29,20 +28,21 @@ def __init__( server_count: int = _PROCESS_COUNT, sock_path=SOURCE_TRANSFORMER_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, server_info_file=SOURCE_TRANSFORMER_SERVER_INFO_FILE_PATH, ): """ - Create a new grpc Source Transformer Server instance. + Create a new grpc Source Transformer Multiproc Server instance. 
A new servicer instance is created and attached to the server. The server instance is returned. Args: source_transform_instance: The source transformer instance to be used for Source Transformer UDF sock_path: The UNIX socket path to be used for the server + server_count: The number of grpc server instances to be forked for multiproc max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 Example invocation: import datetime @@ -93,11 +93,12 @@ def my_handler(keys: list[str], datum: Datum) -> Messages: return messages if __name__ == "__main__": - grpc_server = SourceTransformServer(my_handler) + grpc_server = SourceTransformMultiProcServer(source_transform_instance=my_handler + ,server_count = 2) grpc_server.start() """ self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file @@ -110,7 +111,7 @@ def my_handler(keys: list[str], datum: Datum) -> Messages: ("grpc.so_reuseaddr", 1), ] # Set the number of processes to be spawned to the number of CPUs or - # the value of the env var NUM_CPU_MULTIPROC defined by the user + # the value of the parameter server_count defined by the user # Setting the max value to 2 * CPU count # Used for multiproc server self._process_count = min(server_count, 2 * _PROCESS_COUNT) @@ -118,8 +119,10 @@ def my_handler(keys: list[str], datum: Datum) -> Messages: def start(self): """ - Starts the Multiproc gRPC server on the given TCP sockets - with given max threads. + Starts the N grpc servers gRPC serves on the with + given max threads. + where N = The number of CPUs or the value of the parameter server_count + defined by the user. The max value is capped to 2 * CPU count. 
""" start_multiproc_server( max_threads=self.max_threads, diff --git a/pynumaflow/sourcetransformer/server.py b/pynumaflow/sourcetransformer/server.py index 9a921db3..e7e24db2 100644 --- a/pynumaflow/sourcetransformer/server.py +++ b/pynumaflow/sourcetransformer/server.py @@ -1,12 +1,11 @@ -import os - from pynumaflow._constants import ( MAX_MESSAGE_SIZE, SOURCE_TRANSFORMER_SOCK_PATH, - MAX_THREADS, + NUM_THREADS_DEFAULT, _LOGGER, UDFType, SOURCE_TRANSFORMER_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, ) from pynumaflow.shared import NumaflowServer from pynumaflow.shared.server import sync_server_start @@ -24,7 +23,7 @@ def __init__( source_transform_instance: SourceTransformCallable, sock_path=SOURCE_TRANSFORMER_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, + max_threads=NUM_THREADS_DEFAULT, server_info_file=SOURCE_TRANSFORMER_SERVER_INFO_FILE_PATH, ): """ @@ -37,7 +36,7 @@ def __init__( sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to 4 and max capped at 16 Example Invocation: @@ -92,7 +91,7 @@ def my_handler(keys: list[str], datum: Datum) -> Messages: grpc_server.start() """ self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file diff --git a/tests/batchmap/__init__.py b/tests/batchmap/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/batchmap/test_async_batch_map.py b/tests/batchmap/test_async_batch_map.py new file mode 100644 index 00000000..4922ecb2 --- /dev/null +++ b/tests/batchmap/test_async_batch_map.py @@ -0,0 +1,188 @@ +import asyncio +import logging +import threading +import unittest +from collections.abc import 
AsyncIterable + +import grpc +from google.protobuf import empty_pb2 as _empty_pb2 +from grpc.aio._server import Server + +from pynumaflow import setup_logging +from pynumaflow.batchmapper import ( + Message, + Datum, + BatchMapper, + BatchResponses, + BatchResponse, + BatchMapAsyncServer, +) +from pynumaflow.proto.batchmapper import batchmap_pb2_grpc +from tests.batchmap.utils import start_request, request_generator + +LOGGER = setup_logging(__name__) + +listen_addr = "unix:///tmp/batch_map.sock" + +_s: Server = None +_channel = grpc.insecure_channel(listen_addr) +_loop = None + + +def startup_callable(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + +class ExampleClass(BatchMapper): + async def handler( + self, + datums: AsyncIterable[Datum], + ) -> BatchResponses: + batch_responses = BatchResponses() + async for datum in datums: + val = datum.value + _ = datum.event_time + _ = datum.watermark + strs = val.decode("utf-8").split(",") + batch_response = BatchResponse.from_id(datum.id) + if len(strs) == 0: + batch_response.append(Message.to_drop()) + else: + for s in strs: + batch_response.append(Message(str.encode(s))) + batch_responses.append(batch_response) + + return batch_responses + + +async def handler( + datums: AsyncIterable[Datum], +) -> BatchResponses: + batch_responses = BatchResponses() + async for datum in datums: + val = datum.value + _ = datum.event_time + _ = datum.watermark + strs = val.decode("utf-8").split(",") + batch_response = BatchResponse.from_id(datum.id) + if len(strs) == 0: + batch_response.append(Message.to_drop()) + else: + for s in strs: + batch_response.append(Message(str.encode(s))) + batch_responses.append(batch_response) + + return batch_responses + + +def NewAsyncBatchMapper(): + d = ExampleClass() + server_instance = BatchMapAsyncServer(d) + udfs = server_instance.servicer + return udfs + + +async def start_server(udfs): + server = grpc.aio.server() + batchmap_pb2_grpc.add_BatchMapServicer_to_server(udfs, server) + 
server.add_insecure_port(listen_addr) + logging.info("Starting server on %s", listen_addr) + global _s + _s = server + await server.start() + await server.wait_for_termination() + + +class TestAsyncBatchMapper(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + global _loop + loop = asyncio.new_event_loop() + _loop = loop + _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + _thread.start() + udfs = NewAsyncBatchMapper() + asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) + while True: + try: + with grpc.insecure_channel(listen_addr) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") + LOGGER.error(e) + + @classmethod + def tearDownClass(cls) -> None: + try: + _loop.stop() + LOGGER.info("stopped the event loop") + except Exception as e: + LOGGER.error(e) + + def test_batch_map(self) -> None: + stub = self.__stub() + request = start_request() + generator_response = None + + try: + generator_response = stub.BatchMapFn( + request_iterator=request_generator(count=10, request=request) + ) + except grpc.RpcError as e: + logging.error(e) + + # capture the output from the BatchMapFn generator and assert. 
+ count = 0 + for r in generator_response: + self.assertEqual( + bytes( + "test_mock_message", + encoding="utf-8", + ), + r.results[0].value, + ) + _id = r.id + self.assertEqual(_id, str(count)) + count += 1 + + # in our example we should be return 10 messages which is equal to the number + # of requests + self.assertEqual(10, count) + + def test_is_ready(self) -> None: + with grpc.insecure_channel(listen_addr) as channel: + stub = batchmap_pb2_grpc.BatchMapStub(channel) + + request = _empty_pb2.Empty() + response = None + try: + response = stub.IsReady(request=request) + except grpc.RpcError as e: + logging.error(e) + + self.assertTrue(response.ready) + + def test_max_threads(self): + # max cap at 16 + server = BatchMapAsyncServer(batch_mapper_instance=handler, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = BatchMapAsyncServer(batch_mapper_instance=handler, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = BatchMapAsyncServer(batch_mapper_instance=handler) + self.assertEqual(server.max_threads, 4) + + def __stub(self): + return batchmap_pb2_grpc.BatchMapStub(_channel) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/batchmap/test_async_batch_map_err.py b/tests/batchmap/test_async_batch_map_err.py new file mode 100644 index 00000000..6f9c3bd0 --- /dev/null +++ b/tests/batchmap/test_async_batch_map_err.py @@ -0,0 +1,142 @@ +import asyncio +import logging +import threading +import unittest +from unittest.mock import patch + +import grpc + +from grpc.aio._server import Server + +from pynumaflow import setup_logging +from pynumaflow.batchmapper import BatchResponses +from pynumaflow.batchmapper import Datum, BatchMapAsyncServer +from pynumaflow.proto.batchmapper import batchmap_pb2_grpc +from tests.batchmap.utils import start_request +from tests.testing_utils import mock_terminate_on_stop + +LOGGER = 
setup_logging(__name__) + +raise_error = False + + +def request_generator(count, request, resetkey: bool = False): + for i in range(count): + # add the id to the datum + request.id = str(i) + if resetkey: + request.payload.keys.extend([f"key-{i}"]) + yield request + + +# This handler mimics the scenario where batch map UDF throws a runtime error. +async def err_handler(datums: list[Datum]) -> BatchResponses: + if raise_error: + raise RuntimeError("Got a runtime error from batch map handler.") + batch_responses = BatchResponses() + return batch_responses + + +listen_addr = "unix:///tmp/async_batch_map_err.sock" + +_s: Server = None +_channel = grpc.insecure_channel(listen_addr) +_loop = None + + +def startup_callable(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + +# We are mocking the terminate function from the psutil to not exit the program during testing +@patch("psutil.Process.kill", mock_terminate_on_stop) +async def start_server(): + server = grpc.aio.server() + server_instance = BatchMapAsyncServer(err_handler) + udfs = server_instance.servicer + batchmap_pb2_grpc.add_BatchMapServicer_to_server(udfs, server) + server.add_insecure_port(listen_addr) + logging.info("Starting server on %s", listen_addr) + global _s + _s = server + await server.start() + await server.wait_for_termination() + + +# We are mocking the terminate function from the psutil to not exit the program during testing +@patch("psutil.Process.kill", mock_terminate_on_stop) +class TestAsyncServerErrorScenario(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + global _loop + loop = asyncio.new_event_loop() + _loop = loop + _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + _thread.start() + asyncio.run_coroutine_threadsafe(start_server(), loop=loop) + while True: + try: + with grpc.insecure_channel(listen_addr) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except 
grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") + LOGGER.error(e) + + @classmethod + def tearDownClass(cls) -> None: + try: + _loop.stop() + LOGGER.info("stopped the event loop") + except Exception as e: + LOGGER.error(e) + + def test_batch_map_error(self) -> None: + global raise_error + raise_error = True + stub = self.__stub() + try: + generator_response = stub.BatchMapFn( + request_iterator=request_generator(count=10, request=start_request()) + ) + counter = 0 + for _ in generator_response: + counter += 1 + except Exception as err: + self.assertTrue("Got a runtime error from batch map handler." in err.__str__()) + return + self.fail("Expected an exception.") + + def test_batch_map_length_error(self) -> None: + global raise_error + raise_error = False + stub = self.__stub() + try: + generator_response = stub.BatchMapFn( + request_iterator=request_generator(count=10, request=start_request()) + ) + counter = 0 + for _ in generator_response: + counter += 1 + except Exception as err: + self.assertTrue( + "batchMapFn: mismatch between length of batch requests and responses" + in err.__str__() + ) + return + self.fail("Expected an exception.") + + def __stub(self): + return batchmap_pb2_grpc.BatchMapStub(_channel) + + def test_invalid_input(self): + with self.assertRaises(TypeError): + BatchMapAsyncServer() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/batchmap/test_datatypes.py b/tests/batchmap/test_datatypes.py new file mode 100644 index 00000000..8649732a --- /dev/null +++ b/tests/batchmap/test_datatypes.py @@ -0,0 +1,113 @@ +import unittest + +from google.protobuf import timestamp_pb2 as _timestamp_pb2 + +from pynumaflow.batchmapper._dtypes import ( + Datum, +) +from tests.testing_utils import ( + mock_message, + mock_event_time, + mock_watermark, +) + +TEST_KEYS = ["test"] +TEST_ID = "test_id" +TEST_HEADERS = {"key1": "value1", "key2": "value2"} + + +class 
TestDatum(unittest.TestCase): + def test_err_event_time(self): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + with self.assertRaises(Exception) as context: + Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=ts, + watermark=ts, + headers=TEST_HEADERS, + id=TEST_ID, + ) + self.assertEqual( + "Wrong data type: " + "for Datum.event_time", + str(context.exception), + ) + + def test_err_watermark(self): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + with self.assertRaises(Exception) as context: + Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=ts, + headers=TEST_HEADERS, + id=TEST_ID, + ) + self.assertEqual( + "Wrong data type: " + "for Datum.watermark", + str(context.exception), + ) + + def test_value(self): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + headers=TEST_HEADERS, + id=TEST_ID, + ) + self.assertEqual(mock_message(), d.value) + + def test_key(self): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + id=TEST_ID, + ) + self.assertEqual(TEST_KEYS, d.keys()) + + def test_event_time(self): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + headers=TEST_HEADERS, + id=TEST_ID, + ) + self.assertEqual(mock_event_time(), d.event_time) + self.assertEqual(TEST_HEADERS, d.headers) + + def test_watermark(self): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + id=TEST_ID, + ) + self.assertEqual(mock_watermark(), d.watermark) + self.assertEqual({}, d.headers) + + def test_id(self): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + id=TEST_ID, + ) + self.assertEqual(TEST_ID, d.id) + self.assertEqual({}, d.headers) + + +if __name__ == "__main__": + 
unittest.main() diff --git a/tests/batchmap/test_messages.py b/tests/batchmap/test_messages.py new file mode 100644 index 00000000..351db3b5 --- /dev/null +++ b/tests/batchmap/test_messages.py @@ -0,0 +1,99 @@ +import unittest + +from pynumaflow.batchmapper import Message, DROP, BatchResponse, BatchResponses +from tests.batchmap.test_datatypes import TEST_ID +from tests.testing_utils import mock_message + + +class TestBatchResponses(unittest.TestCase): + @staticmethod + def mock_message_object(): + value = mock_message() + return Message(value=value) + + def test_init(self): + batch_responses = BatchResponses() + batch_response1 = BatchResponse.from_id(TEST_ID) + batch_response2 = BatchResponse.from_id(TEST_ID + "2") + batch_responses.append(batch_response1) + batch_responses.append(batch_response2) + self.assertEqual(2, len(batch_responses)) + # test indexing + self.assertEqual(batch_responses[0].id, TEST_ID) + self.assertEqual(batch_responses[1].id, TEST_ID + "2") + # test slicing + resp = batch_responses[0:1] + self.assertEqual(resp[0].id, TEST_ID) + + +class TestBatchResponse(unittest.TestCase): + @staticmethod + def mock_message_object(): + value = mock_message() + return Message(value=value) + + def test_init(self): + batch_response = BatchResponse.from_id(TEST_ID) + self.assertEqual(batch_response.id, TEST_ID) + + def test_invalid_input(self): + with self.assertRaises(TypeError): + BatchResponse() + + def test_append(self): + batch_response = BatchResponse.from_id(TEST_ID) + self.assertEqual(0, len(batch_response.items())) + batch_response.append(self.mock_message_object()) + self.assertEqual(1, len(batch_response.items())) + batch_response.append(self.mock_message_object()) + self.assertEqual(2, len(batch_response.items())) + + def test_items(self): + mock_obj = [ + mock_message(), + mock_message(), + ] + msgs = BatchResponse.with_msgs(TEST_ID, mock_obj) + self.assertEqual(len(mock_obj), len(msgs.items())) + self.assertEqual(mock_obj[0], msgs.items()[0]) + 
+ +class TestMessage(unittest.TestCase): + def test_key(self): + mock_obj = {"Keys": ["test-key"], "Value": mock_message()} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) + print(msg) + self.assertEqual(mock_obj["Keys"], msg.keys) + + def test_value(self): + mock_obj = {"Keys": ["test-key"], "Value": mock_message()} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) + self.assertEqual(mock_obj["Value"], msg.value) + + def test_message_to_all(self): + mock_obj = {"Keys": [], "Value": mock_message(), "Tags": []} + msg = Message(mock_obj["Value"]) + self.assertEqual(Message, type(msg)) + self.assertEqual(mock_obj["Keys"], msg.keys) + self.assertEqual(mock_obj["Value"], msg.value) + self.assertEqual(mock_obj["Tags"], msg.tags) + + def test_message_to_drop(self): + mock_obj = {"Keys": [], "Value": b"", "Tags": [DROP]} + msg = Message(b"").to_drop() + self.assertEqual(Message, type(msg)) + self.assertEqual(mock_obj["Keys"], msg.keys) + self.assertEqual(mock_obj["Value"], msg.value) + self.assertEqual(mock_obj["Tags"], msg.tags) + + def test_message_to(self): + mock_obj = {"Keys": ["__KEY__"], "Value": mock_message(), "Tags": ["__TAG__"]} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"], tags=mock_obj["Tags"]) + self.assertEqual(Message, type(msg)) + self.assertEqual(mock_obj["Keys"], msg.keys) + self.assertEqual(mock_obj["Value"], msg.value) + self.assertEqual(mock_obj["Tags"], msg.tags) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/batchmap/utils.py b/tests/batchmap/utils.py new file mode 100644 index 00000000..31f87feb --- /dev/null +++ b/tests/batchmap/utils.py @@ -0,0 +1,22 @@ +from pynumaflow.batchmapper import Datum +from pynumaflow.proto.batchmapper import batchmap_pb2 +from tests.testing_utils import get_time_args, mock_message + + +def request_generator(count, request, resetkey: bool = False): + for i in range(count): + # add the id to the datum + request.id = str(i) + if resetkey: + 
request.payload.keys.extend([f"key-{i}"]) + yield request + + +def start_request() -> Datum: + event_time_timestamp, watermark_timestamp = get_time_args() + request = batchmap_pb2.BatchMapRequest( + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + ) + return request diff --git a/tests/map/test_async_mapper.py b/tests/map/test_async_mapper.py index 41c2693a..3ec67e1e 100644 --- a/tests/map/test_async_mapper.py +++ b/tests/map/test_async_mapper.py @@ -219,6 +219,19 @@ def test_invalid_input(self): def __stub(self): return map_pb2_grpc.MapStub(_channel) + def test_max_threads(self): + # max cap at 16 + server = MapAsyncServer(mapper_instance=async_map_handler, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = MapAsyncServer(mapper_instance=async_map_handler, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = MapAsyncServer(mapper_instance=async_map_handler) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/tests/map/test_sync_mapper.py b/tests/map/test_sync_mapper.py index 1e9b011c..f58c216d 100644 --- a/tests/map/test_sync_mapper.py +++ b/tests/map/test_sync_mapper.py @@ -149,6 +149,19 @@ def test_invalid_input(self): with self.assertRaises(TypeError): MapServer() + def test_max_threads(self): + # max cap at 16 + server = MapServer(mapper_instance=map_handler, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = MapServer(mapper_instance=map_handler, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = MapServer(mapper_instance=map_handler) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": unittest.main() diff --git a/tests/mapstream/test_async_map_stream.py b/tests/mapstream/test_async_map_stream.py index 107289a6..1558ccf7 100644 --- 
a/tests/mapstream/test_async_map_stream.py +++ b/tests/mapstream/test_async_map_stream.py @@ -133,6 +133,19 @@ def test_is_ready(self) -> None: def __stub(self): return mapstream_pb2_grpc.MapStreamStub(_channel) + def test_max_threads(self): + # max cap at 16 + server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/tests/reduce/test_async_reduce.py b/tests/reduce/test_async_reduce.py index b8f399ca..91e3c210 100644 --- a/tests/reduce/test_async_reduce.py +++ b/tests/reduce/test_async_reduce.py @@ -250,6 +250,19 @@ class ExampleBadClass: with self.assertRaises(TypeError): ReduceAsyncServer(reducer_handler=ExampleBadClass) + def test_max_threads(self): + # max cap at 16 + server = ReduceAsyncServer(reducer_handler=ExampleClass, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = ReduceAsyncServer(reducer_handler=ExampleClass, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = ReduceAsyncServer(reducer_handler=ExampleClass) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/tests/reducestreamer/test_async_reduce.py b/tests/reducestreamer/test_async_reduce.py index c00e27c7..327f8fa9 100644 --- a/tests/reducestreamer/test_async_reduce.py +++ b/tests/reducestreamer/test_async_reduce.py @@ -281,6 +281,19 @@ class ExampleBadClass: with self.assertRaises(TypeError): ReduceStreamAsyncServer(reduce_stream_handler=ExampleBadClass) + def test_max_threads(self): + # 
max cap at 16 + server = ReduceStreamAsyncServer(reduce_stream_handler=ExampleClass, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = ReduceStreamAsyncServer(reduce_stream_handler=ExampleClass, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = ReduceStreamAsyncServer(reduce_stream_handler=ExampleClass) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/tests/sideinput/test_side_input_server.py b/tests/sideinput/test_side_input_server.py index 085b0b7a..50886c79 100644 --- a/tests/sideinput/test_side_input_server.py +++ b/tests/sideinput/test_side_input_server.py @@ -145,6 +145,19 @@ def test_invalid_input(self): with self.assertRaises(TypeError): SideInputServer() + def test_max_threads(self): + # max cap at 16 + server = SideInputServer(retrieve_side_input_handler, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = SideInputServer(retrieve_side_input_handler, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = SideInputServer(retrieve_side_input_handler) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": unittest.main() diff --git a/tests/sink/test_async_sink.py b/tests/sink/test_async_sink.py index 3878d500..f04230cd 100644 --- a/tests/sink/test_async_sink.py +++ b/tests/sink/test_async_sink.py @@ -194,6 +194,19 @@ def test_start_fallback_sink(self): self.assertEqual(server.sock_path, f"unix://{FALLBACK_SINK_SOCK_PATH}") self.assertEqual(server.server_info_file, FALLBACK_SINK_SERVER_INFO_FILE_PATH) + def test_max_threads(self): + # max cap at 16 + server = SinkAsyncServer(sinker_instance=udsink_handler, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = SinkAsyncServer(sinker_instance=udsink_handler, max_threads=5) + self.assertEqual(server.max_threads, 5) + 
+ # defaults to 4 + server = SinkAsyncServer(sinker_instance=udsink_handler) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/tests/sink/test_server.py b/tests/sink/test_server.py index c511730b..d5fa8b17 100644 --- a/tests/sink/test_server.py +++ b/tests/sink/test_server.py @@ -200,6 +200,19 @@ def test_start_fallback_sink(self): self.assertEqual(server.sock_path, f"unix://{FALLBACK_SINK_SOCK_PATH}") self.assertEqual(server.server_info_file, FALLBACK_SINK_SERVER_INFO_FILE_PATH) + def test_max_threads(self): + # max cap at 16 + server = SinkServer(sinker_instance=udsink_handler, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = SinkServer(sinker_instance=udsink_handler, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = SinkServer(sinker_instance=udsink_handler) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": unittest.main() diff --git a/tests/source/test_async_source.py b/tests/source/test_async_source.py index fdf5e756..08e7a758 100644 --- a/tests/source/test_async_source.py +++ b/tests/source/test_async_source.py @@ -170,6 +170,20 @@ def test_partitions(self) -> None: def __stub(self): return source_pb2_grpc.SourceStub(_channel) + def test_max_threads(self): + class_instance = AsyncSource() + # max cap at 16 + server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = SourceAsyncServer(sourcer_instance=class_instance) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/tests/source/test_sync_source.py b/tests/source/test_sync_source.py index 5e6c2ca8..36dc5b68 100644 --- 
a/tests/source/test_sync_source.py +++ b/tests/source/test_sync_source.py @@ -143,6 +143,19 @@ def test_source_partition(self): response, metadata, code, details = method.termination() self.assertEqual(response.result.partitions, mock_partitions()) + def test_max_threads(self): + # max cap at 16 + server = SourceServer(sourcer_instance=SyncSource(), max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = SourceServer(sourcer_instance=SyncSource(), max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = SourceServer(sourcer_instance=SyncSource()) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": unittest.main() diff --git a/tests/sourcetransform/test_multiproc.py b/tests/sourcetransform/test_multiproc.py index 594f5cfa..725da4fe 100644 --- a/tests/sourcetransform/test_multiproc.py +++ b/tests/sourcetransform/test_multiproc.py @@ -146,6 +146,23 @@ def test_invalid_input(self): with self.assertRaises(TypeError): SourceTransformMultiProcServer() + def test_max_threads(self): + # max cap at 16 + server = SourceTransformMultiProcServer( + source_transform_instance=transform_handler, max_threads=32 + ) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = SourceTransformMultiProcServer( + source_transform_instance=transform_handler, max_threads=5 + ) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": unittest.main() diff --git a/tests/sourcetransform/test_sync_server.py b/tests/sourcetransform/test_sync_server.py index 32e510a6..7b5d174b 100644 --- a/tests/sourcetransform/test_sync_server.py +++ b/tests/sourcetransform/test_sync_server.py @@ -138,6 +138,19 @@ def test_invalid_input(self): with self.assertRaises(TypeError): SourceTransformServer() + def test_max_threads(self): + # max 
cap at 16 + server = SourceTransformServer(source_transform_instance=transform_handler, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = SourceTransformServer(source_transform_instance=transform_handler, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = SourceTransformServer(source_transform_instance=transform_handler) + self.assertEqual(server.max_threads, 4) + if __name__ == "__main__": unittest.main()