Skip to content

Commit

Permalink
chore: refactor reduce handler argument (#201)
Browse files Browse the repository at this point in the history
Signed-off-by: Sidhant Kohli <[email protected]>
  • Loading branch information
kohlisid authored Nov 8, 2024
1 parent f2f7bf6 commit f46fbed
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
8 changes: 4 additions & 4 deletions pynumaflow/reducer/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_handler(
"""
if inspect.isfunction(reducer_handler):
if len(init_args) > 0 or len(init_kwargs) > 0:
# if the init_args or init_kwargs are passed, then the reducer_handler
# if the init_args or init_kwargs are passed, then the reducer_instance
# can only be of class Reducer type
raise TypeError("Cannot pass function handler with init args or kwargs")
# return the function handler
Expand All @@ -58,7 +58,7 @@ class ReduceAsyncServer(NumaflowServer):
A new servicer instance is created and attached to the server.
The server instance is returned.
Args:
reducer_handler: The reducer instance to be used for Reduce UDF
reducer_instance: The reducer instance to be used for Reduce UDF
sock_path: The UNIX socket path to be used for the server
max_message_size: The max message size in bytes the server can receive and send
max_threads: The max number of threads to be spawned;
Expand Down Expand Up @@ -115,7 +115,7 @@ async def reduce_handler(keys: list[str],

def __init__(
self,
reducer_handler: ReduceCallable,
reducer_instance: ReduceCallable,
init_args: tuple = (),
init_kwargs: dict = None,
sock_path=REDUCE_SOCK_PATH,
Expand All @@ -137,7 +137,7 @@ def __init__(
"""
if init_kwargs is None:
init_kwargs = {}
self.reducer_handler = get_handler(reducer_handler, init_args, init_kwargs)
self.reducer_handler = get_handler(reducer_instance, init_args, init_kwargs)
self.sock_path = f"unix://{sock_path}"
self.max_message_size = max_message_size
self.max_threads = min(max_threads, MAX_NUM_THREADS)
Expand Down
10 changes: 5 additions & 5 deletions pynumaflow/reducestreamer/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_handler(
"""
if inspect.isfunction(reducer_handler):
if init_args or init_kwargs:
# if the init_args or init_kwargs are passed, then the reduce_stream_handler
# if the init_args or init_kwargs are passed, then the reduce_stream_instance
# can only be of class ReduceStreamer type
raise TypeError("Cannot pass function handler with init args or kwargs")
# return the function handler
Expand All @@ -60,7 +60,7 @@ class ReduceStreamAsyncServer(NumaflowServer):
A new servicer instance is created and attached to the server.
The server instance is returned.
Args:
reduce_stream_handler: The reducer instance to be used for
reduce_stream_instance: The reducer instance to be used for
Reduce Streaming UDF
init_args: The arguments to be passed to the reduce_stream_handler
init_kwargs: The keyword arguments to be passed to the
Expand Down Expand Up @@ -128,7 +128,7 @@ async def reduce_handler(

def __init__(
self,
reduce_stream_handler: ReduceStreamCallable,
reduce_stream_instance: ReduceStreamCallable,
init_args: tuple = (),
init_kwargs: dict = None,
sock_path=REDUCE_STREAM_SOCK_PATH,
Expand All @@ -141,7 +141,7 @@ def __init__(
A new servicer instance is created and attached to the server.
The server instance is returned.
Args:
reduce_stream_handler: The reducer instance to be used for
reduce_stream_instance: The reducer instance to be used for
Reduce Streaming UDF
init_args: The arguments to be passed to the reduce_stream_handler
init_kwargs: The keyword arguments to be passed to the
Expand All @@ -154,7 +154,7 @@ def __init__(
"""
if init_kwargs is None:
init_kwargs = {}
self.reduce_stream_handler = get_handler(reduce_stream_handler, init_args, init_kwargs)
self.reduce_stream_handler = get_handler(reduce_stream_instance, init_args, init_kwargs)
self.sock_path = f"unix://{sock_path}"
self.max_message_size = max_message_size
self.max_threads = min(max_threads, MAX_NUM_THREADS)
Expand Down
10 changes: 5 additions & 5 deletions tests/reduce/test_async_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def __stub(self):
return reduce_pb2_grpc.ReduceStub(_channel)

def test_error_init(self):
# Check that reducer_handler in required
# Check that reducer_instance in required
with self.assertRaises(TypeError):
ReduceAsyncServer()
# Check that the init_args and init_kwargs are passed
Expand All @@ -248,19 +248,19 @@ class ExampleBadClass:
pass

with self.assertRaises(TypeError):
ReduceAsyncServer(reducer_handler=ExampleBadClass)
ReduceAsyncServer(reducer_instance=ExampleBadClass)

def test_max_threads(self):
# max cap at 16
server = ReduceAsyncServer(reducer_handler=ExampleClass, max_threads=32)
server = ReduceAsyncServer(reducer_instance=ExampleClass, max_threads=32)
self.assertEqual(server.max_threads, 16)

# use argument provided
server = ReduceAsyncServer(reducer_handler=ExampleClass, max_threads=5)
server = ReduceAsyncServer(reducer_instance=ExampleClass, max_threads=5)
self.assertEqual(server.max_threads, 5)

# defaults to 4
server = ReduceAsyncServer(reducer_handler=ExampleClass)
server = ReduceAsyncServer(reducer_instance=ExampleClass)
self.assertEqual(server.max_threads, 4)


Expand Down
10 changes: 5 additions & 5 deletions tests/reducestreamer/test_async_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def __stub(self):
return reduce_pb2_grpc.ReduceStub(_channel)

def test_error_init(self):
# Check that reducer_handler in required
# Check that reducer_instance in required
with self.assertRaises(TypeError):
ReduceStreamAsyncServer()
# Check that the init_args and init_kwargs are passed
Expand All @@ -279,19 +279,19 @@ class ExampleBadClass:
pass

with self.assertRaises(TypeError):
ReduceStreamAsyncServer(reduce_stream_handler=ExampleBadClass)
ReduceStreamAsyncServer(reduce_stream_instance=ExampleBadClass)

def test_max_threads(self):
# max cap at 16
server = ReduceStreamAsyncServer(reduce_stream_handler=ExampleClass, max_threads=32)
server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass, max_threads=32)
self.assertEqual(server.max_threads, 16)

# use argument provided
server = ReduceStreamAsyncServer(reduce_stream_handler=ExampleClass, max_threads=5)
server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass, max_threads=5)
self.assertEqual(server.max_threads, 5)

# defaults to 4
server = ReduceStreamAsyncServer(reduce_stream_handler=ExampleClass)
server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass)
self.assertEqual(server.max_threads, 4)


Expand Down

0 comments on commit f46fbed

Please sign in to comment.