diff --git a/scaleapi/__init__.py b/scaleapi/__init__.py index 838edfb..240ac1a 100644 --- a/scaleapi/__init__.py +++ b/scaleapi/__init__.py @@ -452,7 +452,7 @@ def get_tasks( def get_tasks_count( self, - project_name: str, + project_name: str = None, batch_name: str = None, task_type: TaskType = None, status: TaskStatus = None, @@ -470,7 +470,7 @@ def get_tasks_count( """Returns number of tasks with given filters. Args: - project_name (str): + project_name (str, optional): Project Name batch_name (str, optional): @@ -529,6 +529,11 @@ def get_tasks_count( Returns number of tasks """ + if not project_name and not batch_name: + raise ValueError( + "At least one of project_name or batch_name must be provided." + ) + tasks_args = self._process_tasks_endpoint_args( project_name, batch_name, diff --git a/scaleapi/_version.py b/scaleapi/_version.py index f04eeb8..64dfb12 100644 --- a/scaleapi/_version.py +++ b/scaleapi/_version.py @@ -1,2 +1,2 @@ -__version__ = "2.15.12" +__version__ = "2.15.13" __package_name__ = "scaleapi" diff --git a/tests/test_client.py b/tests/test_client.py index e3a007d..f2e88c8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -417,6 +417,13 @@ def test_get_tasks_count(): assert tasks_count == get_tasks_count +def test_get_tasks_count_with_only_batch(): + batch = create_a_batch() + tasks_count = client.tasks(batch=batch.name).total + get_tasks_count = client.get_tasks_count(batch_name=batch.name) + assert tasks_count == get_tasks_count + + def test_finalize_batch(): batch = create_a_batch() batch = client.finalize_batch(batch.name)