Skip to content

Commit

Permalink
feat: add ResultType for validation of operator output
Browse files Browse the repository at this point in the history
  • Loading branch information
kreczko committed Oct 15, 2024
1 parent 97036cb commit 8af41f4
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 8 deletions.
20 changes: 20 additions & 0 deletions src/fasthep_flow/operators/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Definition of the Operator protocol."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Protocol


Expand All @@ -18,3 +20,21 @@ def __repr__(self) -> str:

def configure(self, **kwargs: Any) -> None:
"""General function to configure the operator."""


@dataclass
class ResultType:
"""The result type of an operator. Can add validation here if needed."""

result: Any
stdout: str
stderr: str
exit_code: int

def to_dict(self) -> dict[str, Any]:
return {
"result": self.result,
"stdout": self.stdout,
"stderr": self.stderr,
"exit_code": self.exit_code,
}
6 changes: 4 additions & 2 deletions src/fasthep_flow/operators/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import Any

from .base import Operator
from .base import Operator, ResultType

try:
# try to import plumbum
Expand Down Expand Up @@ -31,7 +31,9 @@ def configure(self, **kwargs: Any) -> None:
def __call__(self, **kwargs: Any) -> dict[str, Any]:
command = plumbum.local[self.bash_command]
exit_code, stdout, stderr = command.run(*self.arguments)
return {"stdout": stdout, "stderr": stderr, "exit_code": exit_code}
return ResultType(
result=None, stdout=stdout, stderr=stderr, exit_code=exit_code
).to_dict()

def __repr__(self) -> str:
return f'LocalBashOperator(bash_command="{self.bash_command}", arguments={self.arguments})'
Expand Down
14 changes: 8 additions & 6 deletions src/fasthep_flow/operators/py_call.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Python related operators."""

from __future__ import annotations

import io
from collections.abc import Callable
from contextlib import redirect_stderr, redirect_stdout
from typing import Any

from .base import Operator
from .base import Operator, ResultType


class PythonOperator(Operator):
Expand All @@ -26,11 +27,12 @@ def __call__(self, **kwargs: Any) -> dict[str, Any]:
with redirect_stdout(stdout), redirect_stderr(stderr):
result = self.python_callable(*self.arguments)
result = self.python_callable(*self.arguments)
return {
"result": result,
"stdout": stdout.getvalue(),
"stderr": stderr.getvalue(),
}
return ResultType(
result=result,
stdout=stdout.getvalue(),
stderr=stderr.getvalue(),
exit_code=0,
).to_dict()

def __repr__(self) -> str:
return f"PythonOperator(python_callable={self.python_callable}, arguments={self.arguments})"

0 comments on commit 8af41f4

Please sign in to comment.