Skip to content

Commit

Permalink
Make project name optional (#92)
Browse files Browse the repository at this point in the history
* Update __init__.py

made project_name optional

* Update _version.py

new release tag

* Update test_client.py

tests first draft

* updated comment

* either project name or batch name must be provided

* updated tests

* black formatted

* removed the explicit project_name=None parameter

* removed test_process_tasks_endpoint_args_with_batch_name

* formatted with black

* updated logic

* formatted with black

---------

Co-authored-by: Alejandro Gonzalez Barriga <[email protected]>
  • Loading branch information
1 parent 5661ee8 commit f6d64cb
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 4 deletions.
16 changes: 13 additions & 3 deletions scaleapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def tasks(self, **kwargs) -> Tasklist:

def get_tasks(
self,
project_name: str,
project_name: str = None,
batch_name: str = None,
task_type: TaskType = None,
status: TaskStatus = None,
Expand All @@ -345,7 +345,7 @@ def get_tasks(
`task_list = list(get_tasks(...))`
Args:
project_name (str):
project_name (str, optional):
Project Name
batch_name (str, optional):
Expand Down Expand Up @@ -412,6 +412,11 @@ def get_tasks(
Yields Task objects, can be iterated.
"""

if not project_name and not batch_name:
raise ValueError(
"At least one of project_name or batch_name must be provided."
)

next_token = None
has_more = True

Expand Down Expand Up @@ -548,7 +553,7 @@ def get_tasks_count(

@staticmethod
def _process_tasks_endpoint_args(
project_name: str,
project_name: str = None,
batch_name: str = None,
task_type: TaskType = None,
status: TaskStatus = None,
Expand All @@ -565,6 +570,11 @@ def _process_tasks_endpoint_args(
limited_response: bool = None,
):
"""Generates args for /tasks endpoint."""
if not project_name and not batch_name:
raise ValueError(
"At least one of project_name or batch_name must be provided."
)

tasks_args = {
"start_time": created_after,
"end_time": created_before,
Expand Down
2 changes: 1 addition & 1 deletion scaleapi/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = "2.15.10"
__version__ = "2.15.11"
__package_name__ = "scaleapi"
38 changes: 38 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,3 +503,41 @@ def test_list_teammates():
# assert len(new_teammates) >= len(
# old_teammates
# ) # needs to sleep for teammates list to be updated


def test_get_tasks_without_project_name():
with pytest.raises(ValueError):
list(client.get_tasks())


def test_get_tasks_with_optional_project_name():
batch = create_a_batch()
tasks = []
for _ in range(3):
tasks.append(make_a_task(batch=batch.name))
task_ids = {task.id for task in tasks}
for task in client.get_tasks(
project_name=None,
batch_name=batch.name,
limit=1,
):
assert task.id in task_ids


def test_process_tasks_endpoint_args_with_optional_project_name():
args = client._process_tasks_endpoint_args(batch_name="test_batch")
assert args["project"] is None
assert args["batch"] == "test_batch"


def test_get_tasks_with_batch_name():
batch = create_a_batch()
tasks = []
for _ in range(3):
tasks.append(make_a_task(batch=batch.name))
task_ids = {task.id for task in tasks}
for task in client.get_tasks(
batch_name=batch.name,
limit=1,
):
assert task.id in task_ids

0 comments on commit f6d64cb

Please sign in to comment.