Skip to content

Commit

Permalink
Fix script decorator typing
Browse files Browse the repository at this point in the history
* Script decorator still suggests Script kwargs
* Function signature is now left intact when called elsewhere
* Still need to fix call site typing to use Step/Task type hints

Signed-off-by: Elliot Gunton <[email protected]>
  • Loading branch information
elliotgunton committed Apr 4, 2024
1 parent 91df4de commit 83dd240
Showing 1 changed file with 32 additions and 41 deletions.
73 changes: 32 additions & 41 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
TypeVar,
Union,
cast,
overload,
)

from typing_extensions import ParamSpec, get_args, get_origin
Expand Down Expand Up @@ -576,40 +575,27 @@ def _output_annotations_used(source: Callable) -> bool:
FuncIns = ParamSpec("FuncIns") # For input types of given func to script decorator
FuncR = TypeVar("FuncR") # For return type of given func to script decorator
ScriptIns = ParamSpec("ScriptIns") # For attribute types of Script
# SubNodeIns = ParamSpec("SubNodeIns") # For attribute types of TemplateInvocatorSubNodeMixin (Step and Task)
StepIns = ParamSpec("StepIns") # For attribute types of Step
TaskIns = ParamSpec("TaskIns") # For attribute types of Task


def _take_annotation_from(
_: Callable[
ScriptIns,
Callable[[Callable[FuncIns, FuncR]], Union[Callable[FuncIns, FuncR], Callable[ScriptIns, Union[Task, Step]]]],
],
) -> Callable[
[Callable],
Callable[
ScriptIns,
Callable[[Callable[FuncIns, FuncR]], Union[Callable[FuncIns, FuncR], Callable[ScriptIns, Union[Task, Step]]]],
],
]:
def decorator(
real_function: Callable,
) -> Callable[
ScriptIns,
Callable[[Callable[FuncIns, FuncR]], Union[Callable[FuncIns, FuncR], Callable[ScriptIns, Union[Task, Step]]]],
]:
def new_function(
*args: ScriptIns.args, **kwargs: ScriptIns.kwargs
) -> Callable[
[Callable[FuncIns, FuncR]], Union[Callable[FuncIns, FuncR], Callable[ScriptIns, Union[Task, Step]]]
]:
return real_function(*args, **kwargs)

return new_function

return decorator


@_take_annotation_from(Script) # type: ignore
def script(**script_kwargs):
def _use_type_hints_from(
_type_hints_to_use: Callable[ScriptIns, None],
) -> Callable[[Callable[ScriptIns, None]], Callable[ScriptIns, None]]:
"""Pass in a Pydantic class or function to _type_hints_to_use to make its type hints accessible to decorated function.
Accessible via ScriptIns.args/ScriptIns.kwargs.
"""

def dummy_decorator(function: Callable[ScriptIns, None]) -> Callable[ScriptIns, None]:
return function

return dummy_decorator


@_use_type_hints_from(Script)
def script(**script_kwargs: ScriptIns.kwargs):
"""A decorator that wraps a function into a Script object.
Using this decorator users can define a function that will be executed as a script in a container. Once the
Expand All @@ -630,8 +616,8 @@ def script(**script_kwargs):

def script_wrapper(
func: Callable[FuncIns, FuncR],
) -> Union[Callable[FuncIns, FuncR], Callable[ScriptIns, Union[Task, Step]]]:
"""Wraps the given callable into a `Script` object that can be invoked.
) -> Union[Callable[FuncIns, FuncR], Callable[StepIns, Step], Callable[TaskIns, Task]]:
"""Wraps the given callable so it can be invoked as a Step or Task.
Parameters
----------
Expand Down Expand Up @@ -662,16 +648,21 @@ def script_wrapper(

s = Script(name=name, source=source, **script_kwargs)

@overload
def task_wrapper(*args: FuncIns.args, **kwargs: FuncIns.kwargs) -> FuncR: # type: ignore
...
# if TYPE_CHECKING:
# @overload
# def task_wrapper(*args: FuncIns.args, **kwargs: FuncIns.kwargs) -> FuncR: ...

# @overload
# @_use_type_hints_from(Step)
# def task_wrapper(*args: StepIns.args, **kwargs: StepIns.kwargs) -> Step: ...

@overload
def task_wrapper(*args: ScriptIns.args, **kwargs: ScriptIns.kwargs) -> Union[Step, Task]: ...
# @overload
# @_use_type_hints_from(Task)
# def task_wrapper(*args: TaskIns.args, **kwargs: TaskIns.kwargs) -> Task: ...

@wraps(func)
def task_wrapper(*args, **kwargs):
"""Invokes a `Script` object's `__call__` method using the given `task_params`."""
"""Invokes a `Script` object's `__call__` method using the given SubNode (Step or Task) args/kwargs."""
if _context.active:
return s.__call__(*args, **kwargs)
return func(*args, **kwargs)
Expand Down

0 comments on commit 83dd240

Please sign in to comment.