From 101427f8fa8727608f4ea1b37323f257442af4cb Mon Sep 17 00:00:00 2001 From: ebonnal Date: Tue, 21 Jan 2025 21:13:51 +0000 Subject: [PATCH] `.skip`/`.truncate`: make no-op if arguments are None --- streamable/functions.py | 4 ++-- streamable/stream.py | 4 ++-- streamable/util/validationtools.py | 19 ++++--------------- tests/test_stream.py | 24 ++++++++++++------------ 4 files changed, 20 insertions(+), 31 deletions(-) diff --git a/streamable/functions.py b/streamable/functions.py index c2cdb20..24e1f7e 100644 --- a/streamable/functions.py +++ b/streamable/functions.py @@ -41,10 +41,10 @@ validate_group_interval, validate_group_size, validate_iterator, + validate_optional_count, validate_skip_args, validate_throttle_interval, validate_throttle_per_period, - validate_truncate_args, ) with suppress(ImportError): @@ -210,7 +210,7 @@ def truncate( when: Optional[Callable[[T], Any]] = None, ) -> Iterator[T]: validate_iterator(iterator) - validate_truncate_args(count, when) + validate_optional_count(count) if count is not None: iterator = CountTruncateIterator(iterator, count) if when is not None: diff --git a/streamable/stream.py b/streamable/stream.py index 759da5a..f9c74d0 100644 --- a/streamable/stream.py +++ b/streamable/stream.py @@ -28,10 +28,10 @@ validate_concurrency, validate_group_interval, validate_group_size, + validate_optional_count, validate_skip_args, validate_throttle_interval, validate_throttle_per_period, - validate_truncate_args, validate_via, ) @@ -489,7 +489,7 @@ def truncate( Returns: Stream[T]: A stream of at most `count` upstream elements not satisfying the `when` predicate. """ - validate_truncate_args(count, when) + validate_optional_count(count) return TruncateStream(self, count, when) diff --git a/streamable/util/validationtools.py b/streamable/util/validationtools.py index b54d192..2cb2aad 100644 --- a/streamable/util/validationtools.py +++ b/streamable/util/validationtools.py @@ -46,6 +46,9 @@ def validate_count(count: int): if count >= sys.maxsize: raise ValueError(f"`count` must be < sys.maxsize but got {count}") +def validate_optional_count(count: Optional[int]): + if count is not None: + validate_count(count) def validate_throttle_per_period(per_period_arg_name: str, value: int) -> None: if value < 1: @@ -56,24 +59,10 @@ def validate_throttle_interval(interval: datetime.timedelta) -> None: if interval < datetime.timedelta(0): raise ValueError(f"`interval` must be >= 0 but got {repr(interval)}") - -def validate_truncate_args( - count: Optional[int] = None, when: Optional[Callable[[T], Any]] = None -) -> None: - if count is None: - if when is None: - raise ValueError("`count` and `when` cannot both be None") - else: - validate_count(count) - - def validate_skip_args( count: Optional[int] = None, until: Optional[Callable[[T], Any]] = None ) -> None: - if count is None: - if until is None: - raise ValueError("`count` and `until` cannot both be None") - else: + if count is not None: if until is not None: raise ValueError("`count` and `until` cannot both be set") validate_count(count) diff --git a/tests/test_stream.py b/tests/test_stream.py index af4f38d..b799104 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -820,12 +820,11 @@ def test_skip(self) -> None: ): Stream(src).skip(0, until=bool) - with self.assertRaisesRegex( - ValueError, - "`count` and `until` cannot both be None", - msg="`skip` must raise ValueError if both `count` and `until` are None", - ): - Stream(src).skip() + self.assertListEqual( + list(Stream(src).skip()), + list(src), + msg="`skip` must be no-op if both `count` and `until` are None", + ) for count in [0, 1, 3]: self.assertListEqual( @@ -857,17 +856,18 @@ def test_skip(self) -> None: ) def test_truncate(self) -> None: - with self.assertRaisesRegex( - ValueError, - "`count` and `when` cannot both be None", - ): - Stream(src).truncate() - self.assertListEqual( list(Stream(src).truncate(N * 2)), list(src), msg="`truncate` must be ok with count >= stream length", ) + + self.assertListEqual( + list(Stream(src).truncate()), + list(src), + msg="`truncate must be no-op if both `count` and `when` are None", + ) + self.assertListEqual( list(Stream(src).truncate(2)), [0, 1],