class SerializableDataclass(ABC):
    """Describes a dataclass that can be serialized using `DataclassSerializer`.

    Previously, Pydantic has been used for defining objects as it provides built-in support for serialization and
    deserialization. However, Pydantic object initialization is slow compared to dataclass initialization, with tests
    showing 10x-100x slower performance. This is an issue if many objects are created, which can happen during plan
    generation. Using BaseModel.construct() is still not as fast as dataclass initialization, and it also makes for
    an awkward developer interface. Because of this, MF implements a simple custom serializer / deserializer to work
    with the built-in Python dataclass.

    The dataclass must have concrete types for all fields and not all types are supported. Please see implementation
    details in DataclassSerializer. Not adding post_init checks as there have been previous issues with slow object
    initialization.
    """

    # Registry of the non-abstract subclasses seen by `__init_subclass__`. It stays `None` until the first subclass
    # is defined. The self-reference is quoted so the annotation is valid whether or not postponed evaluation of
    # annotations is enabled for this module.
    _concrete_subclass_registry: ClassVar[Optional[Set[Type["SerializableDataclass"]]]] = None

    @classmethod
    def concrete_subclasses_for_testing(cls) -> Sequence[Type["SerializableDataclass"]]:
        """Returns the registered non-abstract subclasses, ordered by module name and then class name.

        This is intended to be used in tests to verify the ability to serialize the class. An empty sequence is
        returned if no subclass has been registered yet.
        """
        return sorted(
            cls._concrete_subclass_registry or (),
            key=lambda class_type: (class_type.__module__, class_type.__name__),
        )

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Adds the newly defined subclass to the registry unless it is abstract.

        It would be helpful to also check that the fields of the dataclass are concrete fields, but that would need
        to be done after class initialization, and checking in `__post_init__` adds significant overhead. No such
        check is done here.
        """
        super().__init_subclass__(**kwargs)

        # The registry is created lazily on the base class so that every subclass shares the same set.
        if SerializableDataclass._concrete_subclass_registry is None:
            SerializableDataclass._concrete_subclass_registry = set()

        # Abstract intermediate classes can't be instantiated, so there is nothing to serialize for them.
        if not inspect.isabstract(cls):
            SerializableDataclass._concrete_subclass_registry.add(cls)
def assert_includes_all_serializable_dataclass_types(
    instances: Sequence[SerializableDataclass], excluded_classes: Iterable[Type[SerializableDataclass]]
) -> None:
    """Verify that every known serializable dataclass type is represented in the given instances.

    Each class returned by `SerializableDataclass.concrete_subclasses_for_testing()` must either be the exact type of
    at least one of the instances or be listed in `excluded_classes`. Otherwise, the assertion fails with a message
    naming the classes that are missing.
    """
    covered_types = {type(instance) for instance in instances}

    # Start from everything in the registry, then drop what is covered or explicitly excluded.
    uncovered_types = set(SerializableDataclass.concrete_subclasses_for_testing())
    uncovered_types.difference_update(covered_types)
    uncovered_types.difference_update(excluded_classes)

    missing_type_names = sorted(uncovered_type.__name__ for uncovered_type in uncovered_types)
    assert (
        not missing_type_names
    ), f"Missing instances of the following classes: {missing_type_names}. Please add them."
# Defined in dbt_semantic_interfaces/test_helpers/dataclass_serialization.py
def assert_serializable(instances: Sequence[SerializableDataclass]) -> None:
    """Verify that each of the given instances survives a serialize / deserialize round trip.

    An `AssertionError` is raised if serialization fails, if deserialization fails, or if the deserialized object
    does not compare equal to the original instance. Serialization and deserialization are checked separately so
    that the error message names the step that actually failed.
    """
    serializer = DataclassSerializer()
    deserializer = DataClassDeserializer()

    for instance in instances:
        try:
            serialized_output = serializer.pydantic_serialize(instance)
        except Exception as e:
            raise AssertionError(f"Error serializing {instance=}") from e

        try:
            deserialized_instance = deserializer.pydantic_deserialize(type(instance), serialized_output)
        except Exception as e:
            raise AssertionError(f"Error deserializing {instance=} from {serialized_output=}") from e

        assert (
            instance == deserialized_instance
        ), f"Deserialized object does not match the original: {instance=} {deserialized_instance=}"


# Defined in tests/serialization/test_serializable_dataclass_subclasses.py
logger = logging.getLogger(__name__)


def test_serializable_dataclass_subclasses() -> None:
    """Verify that all subclasses of `SerializableDataclass` are serializable."""
    # A shared counter gives every constructed reference a distinct field value.
    counter = itertools.count(start=0)

    def _get_next_field_str() -> str:
        return f"field_{next(counter)}"

    instances = [
        LinkableElementReference(_get_next_field_str()),
        ElementReference(_get_next_field_str()),
        SemanticModelElementReference(_get_next_field_str(), _get_next_field_str()),
        EntityReference(_get_next_field_str()),
        SemanticModelReference(_get_next_field_str()),
        TimeDimensionReference(_get_next_field_str()),
        MetricReference(_get_next_field_str()),
        GroupByMetricReference(_get_next_field_str()),
        MetricModelReference(_get_next_field_str()),
        DimensionReference(_get_next_field_str()),
        MeasureReference(_get_next_field_str()),
        ModelReference(),
    ]

    # Fails if a registered subclass has no instance above, so new subclasses must be added to this test.
    assert_includes_all_serializable_dataclass_types(
        instances=instances,
        # These are classes defined and used in a separate test.
        excluded_classes=[
            DataclassWithDataclassDefault,
            DataclassWithDefaultTuple,
            DataclassWithOptional,
            DataclassWithPrimitiveTypes,
            DataclassWithTuple,
            DeeplyNestedDataclass,
            NestedDataclass,
            NestedDataclassWithProtocol,
            SimpleClassWithProtocol,
            SimpleDataclass,
        ],
    )
    assert_serializable(instances)