Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Sidhant Kohli <[email protected]>
  • Loading branch information
kohlisid committed Oct 1, 2024
1 parent 6cb40e9 commit bd5d22b
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 66 deletions.
2 changes: 1 addition & 1 deletion examples/source/async_source/Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
TAG ?= testsv2
TAG ?= stable
PUSH ?= false
IMAGE_REGISTRY = quay.io/numaio/numaflow-python/async-source:${TAG}
DOCKER_FILE_PATH = examples/source/async_source/Dockerfile
Expand Down
4 changes: 2 additions & 2 deletions examples/source/async_source/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ spec:
udsource:
container:
# A simple user-defined async source
image: quay.io/numaio/numaflow-python/async-source:testsv2
# imagePullPolicy: Always
image: quay.io/numaio/numaflow-python/async-source:stable
imagePullPolicy: Always
limits:
readBatchSize: 2
- name: out
Expand Down
23 changes: 1 addition & 22 deletions pynumaflow/sourcer/servicer/async_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,6 @@
from pynumaflow._constants import _LOGGER, STREAM_EOF


# async def read_datum_generator(
# request_iterator: AsyncIterable[source_pb2.ReadRequest],
# ) -> AsyncIterable[ReadRequest]:
# """
# This function is used to create an async generator
# from the gRPC request iterator.
# It yields a Datum instance for each request received which is then
# forwarded to the UDF.
# """
# async for d in request_iterator:
# _LOGGER.info("d %s", d)
# request = ReadRequest(
# num_records=d.request.num_records,
# timeout_in_ms=d.request.timeout_in_ms,
# )
# yield request


class AsyncSourceServicer(source_pb2_grpc.SourceServicer):
"""
This class is used to create a new grpc Source servicer instance.
Expand Down Expand Up @@ -122,7 +104,7 @@ async def ReadFn(
exit_on_error(context, str(err))

async def invoke_read(self, req, niter):
# TODO(source-stream): check with this timeout
# TODO(source-stream): check with this timeout in ms
try:
await self.__source_read_handler(
ReadRequest(
Expand All @@ -143,9 +125,6 @@ async def AckFn(
"""
Applies an Ack function in User Defined Source
"""
# proto repeated field(offsets) is of type google._upb._message.RepeatedScalarContainer
# we need to explicitly convert it to list

try:
need_handshake = True
async for req in request_iterator:
Expand Down
81 changes: 45 additions & 36 deletions tests/source/test_async_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
ack_req_source_fn,
mock_partitions,
AsyncSource,
mock_offset,
)

LOGGER = setup_logging(__name__)

# if set to true, map handler will raise a `ValueError` exception.
raise_error_from_map = False

server_port = "unix:///tmp/async_source.sock"

_s: Server = None
Expand Down Expand Up @@ -63,13 +61,6 @@ def request_generator(count, request, req_type, resetkey: bool = False):
elif req_type == "ack":
yield source_pb2.AckRequest(handshake=source_pb2.Handshake(sot=True))
yield source_pb2.AckRequest(request=request)
# if resetkey:
# request.payload.keys.extend([f"key-{i}"])
#
# if i % 2:
# request.operation.event = reduce_pb2.ReduceRequest.WindowOperation.Event.OPEN
# else:
# request.operation.event = reduce_pb2.ReduceRequest.WindowOperation.Event.APPEND


class TestAsyncSourcer(unittest.TestCase):
Expand Down Expand Up @@ -115,30 +106,42 @@ def test_read_source(self) -> None:
logging.error(e)

counter = 0
first = True
# capture the output from the ReadFn generator and assert.
for r in generator_response:
counter += 1
print("R ", r)
# self.assertEqual(
# bytes("payload:test_mock_message", encoding="utf-8"),
# r.result.payload,
# )
# self.assertEqual(
# ["test_key"],
# r.result.keys,
# )
# self.assertEqual(
# mock_offset().offset,
# r.result.offset.offset,
# )
# self.assertEqual(
# mock_offset().partition_id,
# r.result.offset.partition_id,
# )
# """Assert that the generator was called 10 times in the stream"""
# self.assertEqual(10, counter)

print(counter)
if first:
self.assertEqual(True, r.handshake.sot)
first = False
continue

if r.status.eot:
last = True
continue

self.assertEqual(
bytes("payload:test_mock_message", encoding="utf-8"),
r.result.payload,
)
self.assertEqual(
["test_key"],
r.result.keys,
)
self.assertEqual(
mock_offset().offset,
r.result.offset.offset,
)
self.assertEqual(
mock_offset().partition_id,
r.result.offset.partition_id,
)

self.assertFalse(first)
self.assertTrue(last)

# Assert that the generator was called 12
# (10 data messages + handshake + eot) times in the stream
self.assertEqual(12, counter)

def test_is_ready(self) -> None:
with grpc.insecure_channel(server_port) as channel:
Expand All @@ -162,12 +165,18 @@ def test_ack(self) -> None:
except grpc.RpcError as e:
print(e)

responses = []
count = 0
first = True
for r in response:
responses.append(r)

self.assertEqual(len(responses), 2)
# TODO(source-stream): add exact check with handshake etc
count += 1
if first:
self.assertEqual(True, r.handshake.sot)
first = False
continue
self.assertTrue(r.result.success)

self.assertEqual(count, 2)
self.assertFalse(first)

def test_pending(self) -> None:
with grpc.insecure_channel(server_port) as channel:
Expand Down
10 changes: 5 additions & 5 deletions tests/source/test_async_source_err.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def tearDownClass(cls) -> None:
LOGGER.error(e)

def test_read_error(self) -> None:
grpcException = None
grpc_exception = None
with grpc.insecure_channel(server_port) as channel:
stub = source_pb2_grpc.SourceStub(channel)
request = read_req_source_fn()
Expand All @@ -89,17 +89,17 @@ def test_read_error(self) -> None:
generator_response = stub.ReadFn(
request_iterator=request_generator(1, request, "read")
)
for r in generator_response:
print(r)
for _ in generator_response:
pass
except BaseException as e:
self.assertTrue("Got a runtime error from read handler." in e.__str__())
return
except grpc.RpcError as e:
grpcException = e
grpc_exception = e
self.assertEqual(grpc.StatusCode.UNKNOWN, e.code())
print(e.details())

self.assertIsNotNone(grpcException)
self.assertIsNotNone(grpc_exception)
self.fail("Expected an exception.")

def test_ack_error(self) -> None:
Expand Down

0 comments on commit bd5d22b

Please sign in to comment.