Skip to content

Commit

Permalink
Migrate to model_validator from deprecated __get_validators__
Browse files Browse the repository at this point in the history
  • Loading branch information
esciara committed Nov 25, 2023
1 parent cab6d6d commit 77c7ebb
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 55 deletions.
26 changes: 12 additions & 14 deletions dbt_semantic_interfaces/implementations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, ClassVar, Generator, Generic, Type, TypeVar
from typing import Any, ClassVar, Dict, Generic, Type, TypeVar

from pydantic import BaseModel, ConfigDict, model_validator

Expand Down Expand Up @@ -119,23 +119,20 @@ class PydanticCustomInputParser(ABC, Generic[ModelObjectT_co]):
and validation of that model object itself.
"""

@model_validator(mode="before")
@classmethod
def __get_validators__(
cls: Type[PydanticCustomInputParser[ModelObjectT_co]],
) -> Generator[Callable[[PydanticParseableValueType], PydanticCustomInputParser[ModelObjectT_co]], None, None]:
def _model_validator(
cls: Type[PydanticCustomInputParser[ModelObjectT_co]], input: PydanticParseableValueType
) -> Dict[str, Any]:
"""Pydantic magic method for allowing parsing of arbitrary input on validate_model invocation.
This allows for parsing and validation prior to object initialization. Most classes implementing this
interface in our model are doing so because the input value from user-supplied YAML will be a string
representation rather than the structured object type.
"""
yield cls.__parse_with_custom_handling
@classmethod
def __parse_with_custom_handling(
cls: Type[PydanticCustomInputParser[ModelObjectT_co]], input: PydanticParseableValueType
) -> PydanticCustomInputParser[ModelObjectT_co]:
"""Core method for handling common valid - or easily validated - input types.
the previous and next docstrings were from two different methods, which have been combined here.
Core method for handling common valid - or easily validated - input types.
Pydantic objects can commonly appear as JSON object types (from, e.g., deserializing a Pydantic-serialized
model) or direct instances of the model object class (from, e.g., initializing an object and passing it in
Expand All @@ -149,16 +146,17 @@ def __parse_with_custom_handling(
to the caller to be pre-validated, and so we do not bother guarding against that here.
"""
if isinstance(input, dict):
return cls(**input) # type: ignore
elif isinstance(input, cls):
return input
elif isinstance(input, cls):
# TODO: find a better way to avoid mypy type ignore
return input.model_dump() # type: ignore[attr-defined]
else:
return cls._from_yaml_value(input)

@classmethod
@abstractmethod
def _from_yaml_value(
cls: Type[PydanticCustomInputParser[ModelObjectT_co]], input: PydanticParseableValueType
) -> PydanticCustomInputParser[ModelObjectT_co]:
) -> Dict[str, Any]:
"""Abstract method for providing object-specific parsing logic."""
raise NotImplementedError()
31 changes: 10 additions & 21 deletions dbt_semantic_interfaces/implementations/filters/where_filter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from typing import Callable, Generator, List, Tuple
from typing import Any, Dict, List, Tuple

from typing_extensions import Self
from pydantic import model_validator

from dbt_semantic_interfaces.call_parameter_sets import (
FilterCallParameterSets,
Expand Down Expand Up @@ -34,17 +34,14 @@ class PydanticWhereFilter(PydanticCustomInputParser, HashableBaseModel):
where_sql_template: str

@classmethod
def _from_yaml_value(
cls,
input: PydanticParseableValueType,
) -> PydanticWhereFilter:
def _from_yaml_value(cls, input: PydanticParseableValueType) -> Dict[str, Any]:
"""Parses a WhereFilter from a string found in a user-provided model specification.
User-provided constraint strings are SQL snippets conforming to the expectations of SQL WHERE clauses,
and as such we parse them using our standard parse method below.
"""
if isinstance(input, str):
return PydanticWhereFilter(where_sql_template=input)
return {"where_sql_template": input}
else:
raise ValueError(f"Expected input to be of type string, but got type {type(input)} with value: {input}")

Expand All @@ -62,17 +59,9 @@ class PydanticWhereFilterIntersection(HashableBaseModel):

where_filters: List[PydanticWhereFilter]

@model_validator(mode="before")
@classmethod
def __get_validators__(cls) -> Generator[Callable[[PydanticParseableValueType], Self], None, None]:
"""Pydantic magic method for allowing handling of arbitrary input on model_validate invocation.
This class requires more subtle handling of input deserialized object types (dicts), and so it cannot
extend the common interface via _from_yaml_values.
"""
yield cls._convert_legacy_and_yaml_input

@classmethod
def _convert_legacy_and_yaml_input(cls, input: PydanticParseableValueType) -> Self:
def _convert_legacy_and_yaml_input(cls, input: PydanticParseableValueType) -> Dict[str, Any]:
"""Specifies raw input conversion rules to ensure serialized semantic manifests will parse correctly.
The original spec for where filters relied on a raw WhereFilter object, but this has now been updated to
Expand Down Expand Up @@ -101,13 +90,13 @@ def _convert_legacy_and_yaml_input(cls, input: PydanticParseableValueType) -> Se
is_legacy_where_filter = isinstance(input, str) or isinstance(input, PydanticWhereFilter) or has_legacy_keys

if is_legacy_where_filter:
return cls(where_filters=[input])
return {"where_filters": [input]}
elif isinstance(input, list):
return cls(where_filters=input)
return {"where_filters": input}
elif isinstance(input, dict):
return cls(**input)
elif isinstance(input, cls):
return input
elif isinstance(input, cls):
return input.model_dump()
else:
raise ValueError(
f"Expected input to be of type string, list, PydanticWhereFilter, PydanticWhereFilterIntersection, "
Expand Down
30 changes: 13 additions & 17 deletions dbt_semantic_interfaces/implementations/metric.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import List, Optional, Sequence
from typing import Any, Dict, List, Optional, Sequence

from pydantic import Field

Expand Down Expand Up @@ -34,15 +34,15 @@ class PydanticMetricInputMeasure(PydanticCustomInputParser, HashableBaseModel):
fill_nulls_with: Optional[int] = None

@classmethod
def _from_yaml_value(cls, input: PydanticParseableValueType) -> PydanticMetricInputMeasure:
def _from_yaml_value(cls, input: PydanticParseableValueType) -> Dict[str, Any]:
"""Parses a MetricInputMeasure from a string (name only) or object (struct spec) input.
For user input cases, the original YAML spec for a PydanticMetric included measure(s) specified as string names
or lists of string names. As such, configs pre-dating the addition of this model type will only provide the
base name for this object.
"""
if isinstance(input, str):
return PydanticMetricInputMeasure(name=input)
return {"name": input}
else:
raise ValueError(
f"MetricInputMeasure inputs from model configs are expected to be of either type string or "
Expand All @@ -67,26 +67,22 @@ class PydanticMetricTimeWindow(PydanticCustomInputParser, HashableBaseModel):
granularity: TimeGranularity

@classmethod
def _from_yaml_value(cls, input: PydanticParseableValueType) -> PydanticMetricTimeWindow:
def _from_yaml_value(cls, input: PydanticParseableValueType) -> Dict[str, Any]:
"""Parses a MetricTimeWindow from a string input found in a user provided model specification.
The MetricTimeWindow is always expected to be provided as a string in user-defined YAML configs.
Output of the form: (<time unit count>, <time granularity>, <error message>) - error message is None if window
is formatted properly
"""
if isinstance(input, str):
return PydanticMetricTimeWindow.parse(input)
else:
if not isinstance(input, str):
raise ValueError(
f"MetricTimeWindow inputs from model configs are expected to always be of type string, but got "
f"type {type(input)} with value: {input}"
)

@staticmethod
def parse(window: str) -> PydanticMetricTimeWindow:
"""Returns window values if parsing succeeds, None otherwise.
window = input

Output of the form: (<time unit count>, <time granularity>, <error message>) - error message is None if window
is formatted properly
"""
parts = window.split(" ")
if len(parts) != 2:
raise ParsingException(
Expand All @@ -108,10 +104,10 @@ def parse(window: str) -> PydanticMetricTimeWindow:
if not count.isdigit():
raise ParsingException(f"Invalid count ({count}) in cumulative metric window string: ({window})")

return PydanticMetricTimeWindow(
count=int(count),
granularity=TimeGranularity(granularity),
)
return {
"count": int(count),
"granularity": TimeGranularity(granularity),
}


class PydanticMetricInput(HashableBaseModel):
Expand Down
2 changes: 1 addition & 1 deletion dbt_semantic_interfaces/validations/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _validate_cumulative_sum_metric_params(metric: Metric) -> List[ValidationIss
try:
window_str = f"{metric.type_params.window.count} {metric.type_params.window.granularity.value}"
# TODO: Should not call an implementation class.
PydanticMetricTimeWindow.parse(window_str)
PydanticMetricTimeWindow.model_validate(window_str)
except ParsingException as e:
issues.append(
ValidationError(
Expand Down
4 changes: 2 additions & 2 deletions tests/validations/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def test_derived_metric() -> None: # noqa: D
expr="random_metric / random_metric3",
metrics=[
PydanticMetricInput(
name="random_metric", offset_window=PydanticMetricTimeWindow.parse("3 weeks")
name="random_metric", offset_window=PydanticMetricTimeWindow.model_validate("3 weeks")
),
PydanticMetricInput(
name="random_metric", offset_to_grain=TimeGranularity.MONTH, alias="random_metric3"
Expand All @@ -286,7 +286,7 @@ def test_derived_metric() -> None: # noqa: D
metrics=[
PydanticMetricInput(
name="random_metric",
offset_window=PydanticMetricTimeWindow.parse("3 weeks"),
offset_window=PydanticMetricTimeWindow.model_validate("3 weeks"),
offset_to_grain=TimeGranularity.MONTH,
)
],
Expand Down

0 comments on commit 77c7ebb

Please sign in to comment.