Skip to content

Commit

Permalink
ENH Move some test Task models to use Fields instead of direct parame…
Browse files Browse the repository at this point in the history
…ters
  • Loading branch information
gadorlhiac committed May 7, 2024
1 parent 783b9c7 commit 301bd92
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions lute/io/models/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
]
__author__ = "Gabriel Dorlhiac"

from typing import Dict, Any
from typing import Dict, Any, Optional

from pydantic import (
BaseModel,
Expand All @@ -42,32 +42,47 @@
class TestParameters(TaskParameters):
"""Parameters for the test Task `Test`."""

float_var: float = 0.01
str_var: str = "test"
float_var: float = Field(0.01, description="A floating point number.")
str_var: str = Field("test", description="A string.")

class CompoundVar(BaseModel):
int_var: int = 1
dict_var: Dict[str, str] = {"a": "b"}

compound_var: CompoundVar
throw_error: bool = False
compound_var: CompoundVar = Field(
description=(
"A compound parameter - consists of a `int_var` (int) and `dict_var`"
" (Dict[str, str])."
)
)
throw_error: bool = Field(
False, description="If `True`, raise an exception to test error handling."
)


class TestBinaryParameters(ThirdPartyParameters):
executable: str = "/sdf/home/d/dorlhiac/test_tasks/test_threads"
p_arg1: int = 1
executable: str = Field(
"/sdf/home/d/dorlhiac/test_tasks/test_threads",
description="Multi-threaded test binary.",
)
p_arg1: int = Field(1, descriptions="Number of threads.")


class TestBinaryErrParameters(ThirdPartyParameters):
"""Same as TestBinary, but exits with non-zero code."""

executable: str = "/sdf/home/d/dorlhiac/test_tasks/test_threads_err"
p_arg1: int = 1
executable: str = Field(
"/sdf/home/d/dorlhiac/test_tasks/test_threads_err",
description="Multi-threaded tes tbinary with non-zero exit code.",
)
p_arg1: int = Field(1, description="Number of threads.")


class TestSocketParameters(TaskParameters):
array_size: int = 10000
num_arrays: int = 10
array_size: int = Field(
10000, description="Size of an array to send (number of values) via socket."
)
num_arrays: int = Field(10, description="Number of arrays to send via socket.")


class TestWriteOutputParameters(TaskParameters):
Expand All @@ -83,8 +98,9 @@ class TestReadOutputParameters(TaskParameters):
@validator("in_file", always=True)
def validate_in_file(cls, in_file: str, values: Dict[str, Any]) -> str:
if in_file == "":
filename: str = read_latest_db_entry(
filename: Optional[str] = read_latest_db_entry(
f"{values['lute_config'].work_dir}", "TestWriteOutput", "outfile_name"
)
in_file: str = f"{values['lute_config'].work_dir}/{filename}"
if filename is not None:
return f"{values['lute_config'].work_dir}/{filename}"
return in_file

0 comments on commit 301bd92

Please sign in to comment.