From bd5d22b218f271d97a8b8033ae9c6f2991b8764d Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Tue, 1 Oct 2024 15:38:49 -0700 Subject: [PATCH] add tests Signed-off-by: Sidhant Kohli --- examples/source/async_source/Makefile | 2 +- examples/source/async_source/pipeline.yaml | 4 +- pynumaflow/sourcer/servicer/async_servicer.py | 23 +----- tests/source/test_async_source.py | 81 ++++++++++--------- tests/source/test_async_source_err.py | 10 +-- 5 files changed, 54 insertions(+), 66 deletions(-) diff --git a/examples/source/async_source/Makefile b/examples/source/async_source/Makefile index 1f1583d3..c39b171c 100644 --- a/examples/source/async_source/Makefile +++ b/examples/source/async_source/Makefile @@ -1,4 +1,4 @@ -TAG ?= testsv2 +TAG ?= stable PUSH ?= false IMAGE_REGISTRY = quay.io/numaio/numaflow-python/async-source:${TAG} DOCKER_FILE_PATH = examples/source/async_source/Dockerfile diff --git a/examples/source/async_source/pipeline.yaml b/examples/source/async_source/pipeline.yaml index dcb2ae1b..4c148fb0 100644 --- a/examples/source/async_source/pipeline.yaml +++ b/examples/source/async_source/pipeline.yaml @@ -9,8 +9,8 @@ spec: udsource: container: # A simple user-defined async source - image: quay.io/numaio/numaflow-python/async-source:testsv2 -# imagePullPolicy: Always + image: quay.io/numaio/numaflow-python/async-source:stable + imagePullPolicy: Always limits: readBatchSize: 2 - name: out diff --git a/pynumaflow/sourcer/servicer/async_servicer.py b/pynumaflow/sourcer/servicer/async_servicer.py index ee26c59a..e178307e 100644 --- a/pynumaflow/sourcer/servicer/async_servicer.py +++ b/pynumaflow/sourcer/servicer/async_servicer.py @@ -15,24 +15,6 @@ from pynumaflow._constants import _LOGGER, STREAM_EOF -# async def read_datum_generator( -# request_iterator: AsyncIterable[source_pb2.ReadRequest], -# ) -> AsyncIterable[ReadRequest]: -# """ -# This function is used to create an async generator -# from the gRPC request iterator. -# It yields a Datum instance for each request received which is then -# forwarded to the UDF. -# """ -# async for d in request_iterator: -# _LOGGER.info("d %s", d) -# request = ReadRequest( -# num_records=d.request.num_records, -# timeout_in_ms=d.request.timeout_in_ms, -# ) -# yield request - - class AsyncSourceServicer(source_pb2_grpc.SourceServicer): """ This class is used to create a new grpc Source servicer instance. @@ -122,7 +104,7 @@ async def ReadFn( exit_on_error(context, str(err)) async def invoke_read(self, req, niter): - # TODO(source-stream): check with this timeout + # TODO(source-stream): check with this timeout in ms try: await self.__source_read_handler( ReadRequest( @@ -143,9 +125,6 @@ async def AckFn( """ Applies an Ack function in User Defined Source """ - # proto repeated field(offsets) is of type google._upb._message.RepeatedScalarContainer - # we need to explicitly convert it to list - try: need_handshake = True async for req in request_iterator: diff --git a/tests/source/test_async_source.py b/tests/source/test_async_source.py index 7d97e663..6a45900a 100644 --- a/tests/source/test_async_source.py +++ b/tests/source/test_async_source.py @@ -17,13 +17,11 @@ ack_req_source_fn, mock_partitions, AsyncSource, + mock_offset, ) LOGGER = setup_logging(__name__) -# if set to true, map handler will raise a `ValueError` exception. -raise_error_from_map = False - server_port = "unix:///tmp/async_source.sock" _s: Server = None @@ -63,13 +61,6 @@ def request_generator(count, request, req_type, resetkey: bool = False): elif req_type == "ack": yield source_pb2.AckRequest(handshake=source_pb2.Handshake(sot=True)) yield source_pb2.AckRequest(request=request) - # if resetkey: - # request.payload.keys.extend([f"key-{i}"]) - # - # if i % 2: - # request.operation.event = reduce_pb2.ReduceRequest.WindowOperation.Event.OPEN - # else: - # request.operation.event = reduce_pb2.ReduceRequest.WindowOperation.Event.APPEND class TestAsyncSourcer(unittest.TestCase): @@ -115,30 +106,42 @@ def test_read_source(self) -> None: logging.error(e) counter = 0 + first = True # capture the output from the ReadFn generator and assert. for r in generator_response: counter += 1 - print("R ", r) - # self.assertEqual( - # bytes("payload:test_mock_message", encoding="utf-8"), - # r.result.payload, - # ) - # self.assertEqual( - # ["test_key"], - # r.result.keys, - # ) - # self.assertEqual( - # mock_offset().offset, - # r.result.offset.offset, - # ) - # self.assertEqual( - # mock_offset().partition_id, - # r.result.offset.partition_id, - # ) - # """Assert that the generator was called 10 times in the stream""" - # self.assertEqual(10, counter) - - print(counter) + if first: + self.assertEqual(True, r.handshake.sot) + first = False + continue + + if r.status.eot: + last = True + continue + + self.assertEqual( + bytes("payload:test_mock_message", encoding="utf-8"), + r.result.payload, + ) + self.assertEqual( + ["test_key"], + r.result.keys, + ) + self.assertEqual( + mock_offset().offset, + r.result.offset.offset, + ) + self.assertEqual( + mock_offset().partition_id, + r.result.offset.partition_id, + ) + + self.assertFalse(first) + self.assertTrue(last) + + # Assert that the generator was called 12 + # (10 data messages + handshake + eot) times in the stream + self.assertEqual(12, counter) def test_is_ready(self) -> None: with grpc.insecure_channel(server_port) as channel: @@ -162,12 +165,18 @@ def test_ack(self) -> None: except grpc.RpcError as e: print(e) - responses = [] + count = 0 + first = True for r in response: - responses.append(r) - - self.assertEqual(len(responses), 2) - # TODO(source-stream): add exact check with handshake etc + count += 1 + if first: + self.assertEqual(True, r.handshake.sot) + first = False + continue + self.assertTrue(r.result.success) + + self.assertEqual(count, 2) + self.assertFalse(first) def test_pending(self) -> None: with grpc.insecure_channel(server_port) as channel: diff --git a/tests/source/test_async_source_err.py b/tests/source/test_async_source_err.py index 37a7ffa3..5c7525c5 100644 --- a/tests/source/test_async_source_err.py +++ b/tests/source/test_async_source_err.py @@ -80,7 +80,7 @@ def tearDownClass(cls) -> None: LOGGER.error(e) def test_read_error(self) -> None: - grpcException = None + grpc_exception = None with grpc.insecure_channel(server_port) as channel: stub = source_pb2_grpc.SourceStub(channel) request = read_req_source_fn() @@ -89,17 +89,17 @@ def test_read_error(self) -> None: generator_response = stub.ReadFn( request_iterator=request_generator(1, request, "read") ) - for r in generator_response: - print(r) + for _ in generator_response: + pass except BaseException as e: self.assertTrue("Got a runtime error from read handler." in e.__str__()) return except grpc.RpcError as e: - grpcException = e + grpc_exception = e self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) print(e.details()) - self.assertIsNotNone(grpcException) + self.assertIsNotNone(grpc_exception) self.fail("Expected an exception.") def test_ack_error(self) -> None: