Skip to content

Commit

Permalink
Made kwargs default value persistent in the Hub
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinmessiaen committed Oct 26, 2023
1 parent 6edf7bc commit 0ce3c85
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 37 deletions.
71 changes: 40 additions & 31 deletions giskard/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,26 +168,27 @@ def __init__(
self.tags = self.populate_tags(tags)

parameters = self.extract_parameters(callable_obj)
for param in parameters:
param.default = serialize_parameter(param.default)

self.args = {
parameter.name: FunctionArgument(
name=parameter.name,
type=extract_optional(parameter.annotation).__qualname__,
optional=parameter.default != inspect.Parameter.empty,
default=serialize_parameter(parameter.default),
argOrder=idx,
)
for idx, parameter in enumerate(parameters.values())
if name != "self"
}
self.args = {param.name: param for param in parameters}

def extract_parameters(self, callable_obj):
def extract_parameters(self, callable_obj) -> List[FunctionArgument]:
if inspect.isclass(callable_obj):
parameters = list(inspect.signature(callable_obj.__init__).parameters.values())[1:]
else:
parameters = list(inspect.signature(callable_obj).parameters.values())

return parameters
return [
FunctionArgument(
name=parameter.name,
type=extract_optional(parameter.annotation).__qualname__,
optional=parameter.default != inspect.Parameter.empty,
default=parameter.default,
argOrder=idx,
)
for idx, parameter in enumerate(parameters)
]

@staticmethod
def extract_module_doc(func_doc):
Expand Down Expand Up @@ -292,10 +293,8 @@ def __init__(
super().__init__(callable_obj, name, tags, version, type)
self.debug_description = debug_description

def extract_parameters(self, callable_obj):
parameters = unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj))

return {p.name: p for p in parameters}
def extract_parameters(self, callable_obj) -> List[FunctionArgument]:
return unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj))

def to_json(self):
json = super().to_json()
Expand Down Expand Up @@ -345,10 +344,8 @@ def __init__(
else:
self.column_type = None

def extract_parameters(self, callable_obj):
parameters = unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj)[1:])

return {p.name: p for p in parameters}
def extract_parameters(self, callable_obj) -> List[FunctionArgument]:
return unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj)[1:])

def to_json(self):
json = super().to_json()
Expand All @@ -372,25 +369,37 @@ def init_from_json(self, json: Dict[str, Any]):
SMT = TypeVar("SMT", bound=SavableMeta)


def unknown_annotations_to_kwargs(parameters: List[inspect.Parameter]) -> List[inspect.Parameter]:
def unknown_annotations_to_kwargs(parameters: List[FunctionArgument]) -> List[FunctionArgument]:
from giskard.models.base import BaseModel
from giskard.datasets.base import Dataset
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction
from giskard.ml_worker.testing.registry.transformation_function import TransformationFunction

allowed_types = [str, bool, int, float, BaseModel, Dataset, SlicingFunction, TransformationFunction]
allowed_types = allowed_types + list(map(lambda x: Optional[x], allowed_types))
allowed_types = list(map(lambda x: x.__qualname__, allowed_types))

has_kwargs = any(
[param for param in parameters if not any([param.annotation == allowed_type for allowed_type in allowed_types])]
)
kwargs = [param for param in parameters if not any([param.type == allowed_type for allowed_type in allowed_types])]

parameters = [
param for param in parameters if any([param.annotation == allowed_type for allowed_type in allowed_types])
]
parameters = [param for param in parameters if any([param.type == allowed_type for allowed_type in allowed_types])]

if has_kwargs:
parameters.append(inspect.Parameter(name="kwargs", kind=4, annotation=Kwargs))
for idx, parameter in enumerate(parameters):
parameter.argOrder = idx

if any(kwargs) > 0:
kwargs_with_default = [param for param in kwargs if param.default != inspect.Parameter.empty]
default_value = (
dict({param.name: param.default for param in kwargs_with_default}) if any(kwargs_with_default) else None
)

parameters.append(
FunctionArgument(
name="kwargs",
type="Kwargs",
default=default_value,
optional=len(kwargs_with_default) == len(kwargs),
argOrder=len(parameters),
)
)

return parameters

Expand Down
10 changes: 5 additions & 5 deletions giskard/testing/tests/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from giskard.datasets.base import Dataset
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction
from giskard.ml_worker.testing.test_result import TestResult
from giskard.testing.tests.debug_slicing_functions import incorrect_rows_slicing_fn, nlargest_abs_err_rows_slicing_fn
from giskard.ml_worker.testing.utils import Direction
from giskard.ml_worker.testing.utils import check_slice_not_empty
from giskard.models.base import BaseModel
from giskard.models.utils import np_type_to_native
from . import debug_prefix, debug_description_prefix
from giskard.testing.tests.debug_slicing_functions import incorrect_rows_slicing_fn, nlargest_abs_err_rows_slicing_fn
from . import debug_description_prefix, debug_prefix


def _verify_target_availability(dataset):
Expand Down Expand Up @@ -149,11 +149,11 @@ def _test_diff_prediction(
" reference_dataset is equal to zero"
)

if direction == Direction.Invariant:
if direction == Direction.Invariant or direction == Direction.Invariant.value:
passed = abs(rel_change) < threshold
elif direction == Direction.Decreasing:
elif direction == Direction.Decreasing or direction == Direction.Decreasing.value:
passed = rel_change < threshold
elif direction == Direction.Increasing:
elif direction == Direction.Increasing or direction == Direction.Increasing.value:
passed = rel_change > threshold
else:
raise ValueError(f"Invalid direction: {direction}")
Expand Down
13 changes: 12 additions & 1 deletion giskard/utils/artifacts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import uuid
from typing import Any, Optional, Union
from enum import Enum
from typing import Any, Optional, Union, Dict

try:
from types import NoneType
Expand All @@ -18,13 +19,23 @@ def _serialize_artifact(artifact, artifact_uuid: Optional[Union[str, uuid.UUID]]
return str(artifact_uuid)


def repr_parameter(value: Any) -> str:
if isinstance(value, Enum):
return repr(value.value)

return repr(value)


def serialize_parameter(default_value: Any) -> PRIMITIVES:
if default_value == inspect.Parameter.empty:
return None

if isinstance(default_value, PRIMITIVES.__args__):
return default_value

if isinstance(default_value, Dict):
return "\n".join(f"kwargs[{repr(key)}] = {repr_parameter(value)}" for key, value in default_value.items())

from ..ml_worker.core.savable import Artifact

if isinstance(default_value, Artifact):
Expand Down

0 comments on commit 0ce3c85

Please sign in to comment.