Skip to content

Commit

Permalink
fix: Improve linting/editor experience for hera models. (#950)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [ ] Fixes #<!--issue number goes here-->
- [ ] Tests added
- [ ] Documentation/examples added
- [ ] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**
Today, hera models seem to receive no editor/linter support for the
`__init__` arguments of hera models.

This **seems** to be solely from the necessary try/except imports used
to support both pydantic v1 and v2. Both mypy and pyright will bail on
and not yield useful types when you do this. But it can be hacked around
with some **more** conditionals imports 😅

Today:
<img width="1219" alt="Screenshot 2024-01-31 at 2 19 19 PM"
src="https://github.com/argoproj-labs/hera/assets/701548/6bcddeab-58c3-4205-836b-58ec24c23f03">
<img width="367" alt="Screenshot 2024-01-31 at 2 19 27 PM"
src="https://github.com/argoproj-labs/hera/assets/701548/35148eb6-5a08-4987-a215-ee6580b3ea36">

With this change:
<img width="1224" alt="Screenshot 2024-01-31 at 2 18 27 PM"
src="https://github.com/argoproj-labs/hera/assets/701548/76e06cad-ebe0-4f4a-b6be-25bf8c3070c4">
<img width="1341" alt="Screenshot 2024-01-31 at 2 18 36 PM"
src="https://github.com/argoproj-labs/hera/assets/701548/b711663d-5539-4c16-a043-7bcb83ed2ad6">

Signed-off-by: DanCardin <[email protected]>
  • Loading branch information
DanCardin authored Feb 1, 2024
1 parent 77c50e7 commit d0c132e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 19 deletions.
6 changes: 3 additions & 3 deletions src/hera/shared/_global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

from hera.auth import TokenGenerator
from hera.shared._pydantic import BaseModel, root_validator
from hera.shared._pydantic import BaseModel, get_fields, root_validator

TBase = TypeVar("TBase", bound="BaseMixin")
TypeTBase = Type[TBase]
Expand Down Expand Up @@ -119,7 +119,7 @@ def set_class_defaults(self, cls: Type[TBase], **kwargs: Any) -> None:
cls: The class to set defaults for.
kwargs: The default values to set.
"""
invalid_keys = set(kwargs) - set(cls.__fields__)
invalid_keys = set(kwargs) - set(get_fields(cls))
if invalid_keys:
raise ValueError(f"Invalid keys for class {cls}: {invalid_keys}")
self._defaults[cls].update(kwargs)
Expand All @@ -143,7 +143,7 @@ def _init_private_attributes(self):
this method. We also tried other ways including creating a metaclass that invokes hera_init after init,
but that always broke auto-complete for IDEs like VSCode.
"""
super()._init_private_attributes()
super()._init_private_attributes() # type: ignore
self.__hera_init__()

def __hera_init__(self):
Expand Down
24 changes: 21 additions & 3 deletions src/hera/shared/_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Module that holds the underlying base Pydantic models for Hera objects."""

from typing import Literal
from typing import TYPE_CHECKING, Any, Dict, Literal, Type

_PYDANTIC_VERSION: Literal[1, 2] = 1
# The pydantic v1 interface is used for both pydantic v1 and v2 in order to support
# users across both versions.

try:
from pydantic.v1 import ( # type: ignore
BaseModel as PydanticBaseModel,
Field,
ValidationError,
root_validator,
Expand All @@ -18,7 +17,6 @@
_PYDANTIC_VERSION = 2
except (ImportError, ModuleNotFoundError):
from pydantic import ( # type: ignore[assignment,no-redef]
BaseModel as PydanticBaseModel,
Field,
ValidationError,
root_validator,
Expand All @@ -28,6 +26,26 @@
_PYDANTIC_VERSION = 1


# TYPE_CHECKING-guarding specifically the `BaseModel` import helps the type checkers
# provide proper type checking to models. Without this, both mypy and pyright lose
# native pydantic hinting for `__init__` arguments.
if TYPE_CHECKING:
from pydantic import BaseModel as PydanticBaseModel
else:
try:
from pydantic.v1 import BaseModel as PydanticBaseModel # type: ignore
except (ImportError, ModuleNotFoundError):
from pydantic import BaseModel as PydanticBaseModel # type: ignore[assignment,no-redef]


def get_fields(cls: Type[PydanticBaseModel]) -> Dict[str, Any]:
"""Centralize access to __fields__."""
try:
return cls.model_fields # type: ignore
except AttributeError:
return cls.__fields__ # type: ignore


__all__ = [
"BaseModel",
"Field",
Expand Down
7 changes: 4 additions & 3 deletions src/hera/workflows/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from typing_extensions import Annotated, get_args, get_origin # type: ignore

from hera.shared import BaseMixin, global_config
from hera.shared._pydantic import BaseModel, root_validator, validator
from hera.shared._pydantic import BaseModel, get_fields, root_validator, validator
from hera.shared.serialization import serialize
from hera.workflows._context import SubNodeMixin, _context
from hera.workflows.artifact import Artifact
Expand Down Expand Up @@ -1211,9 +1211,10 @@ def __init__(self, model_path: str, hera_builder: Optional[Callable] = None):
self.model_path = model_path.split(".")
curr_class: Type[BaseModel] = self._get_model_class()
for key in self.model_path:
if key not in curr_class.__fields__:
fields = get_fields(curr_class)
if key not in fields:
raise ValueError(f"Model key '{key}' does not exist in class {curr_class}")
curr_class = curr_class.__fields__[key].outer_type_
curr_class = fields[key].outer_type_

@classmethod
def _get_model_class(cls) -> Type[BaseModel]:
Expand Down
27 changes: 17 additions & 10 deletions src/hera/workflows/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import ChainMap
from typing import Any, List, Optional, Union

from hera.shared._pydantic import BaseModel
from hera.shared._pydantic import BaseModel, get_fields
from hera.shared.serialization import serialize
from hera.workflows.artifact import Artifact
from hera.workflows.parameter import Parameter
Expand Down Expand Up @@ -32,30 +32,36 @@ def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> Lis
parameters = []
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

for field in cls.__fields__:
fields = get_fields(cls)
for field in fields:
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], Parameter):
param = get_args(annotations[field])[1]
if object_override:
param.default = serialize(getattr(object_override, field))
elif cls.__fields__[field].default:
elif fields[field].default:
# Serialize the value (usually done in Parameter's validator)
param.default = serialize(cls.__fields__[field].default)
param.default = serialize(fields[field].default)
parameters.append(param)
else:
# Create a Parameter from basic type annotations
if object_override:
parameters.append(Parameter(name=field, default=serialize(getattr(object_override, field))))
parameters.append(
Parameter(
name=field,
default=serialize(getattr(object_override, field)),
)
)
else:
parameters.append(Parameter(name=field, default=cls.__fields__[field].default))
parameters.append(Parameter(name=field, default=fields[field].default))
return parameters

@classmethod
def _get_artifacts(cls) -> List[Artifact]:
artifacts = []
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

for field in cls.__fields__:
for field in get_fields(cls):
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], Artifact):
artifact = get_args(annotations[field])[1]
Expand All @@ -82,15 +88,16 @@ def _get_outputs(cls) -> List[Union[Artifact, Parameter]]:
outputs = []
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

for field in cls.__fields__:
fields = get_fields(cls)
for field in fields:
if field in {"exit_code", "result"}:
continue
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], (Parameter, Artifact)):
outputs.append(get_args(annotations[field])[1])
else:
# Create a Parameter from basic type annotations
outputs.append(Parameter(name=field, default=cls.__fields__[field].default))
outputs.append(Parameter(name=field, default=fields[field].default))
return outputs

@classmethod
Expand All @@ -102,4 +109,4 @@ def _get_output(cls, field_name: str) -> Union[Artifact, Parameter]:
return get_args(annotation)[1]

# Create a Parameter from basic type annotations
return Parameter(name=field_name, default=cls.__fields__[field_name].default)
return Parameter(name=field_name, default=get_fields(cls)[field_name].default)

0 comments on commit d0c132e

Please sign in to comment.