diff --git a/metricflow-semantics/metricflow_semantics/mf_logging/pretty_formattable.py b/metricflow-semantics/metricflow_semantics/mf_logging/pretty_formattable.py new file mode 100644 index 0000000000..d4ebcda88f --- /dev/null +++ b/metricflow-semantics/metricflow_semantics/mf_logging/pretty_formattable.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + + +class MetricFlowPrettyFormattable(ABC): + """Changes behavior for pretty-formatting using `MetricFlowPrettyFormatter`. + + This interface is pending updates to allow for additional configuration and structured return types. + """ + + @property + @abstractmethod + def pretty_format(self) -> Optional[str]: + """Return the pretty-formatted version of this object, or None if the default approach should be used.""" + raise NotImplementedError diff --git a/metricflow-semantics/metricflow_semantics/mf_logging/pretty_print.py b/metricflow-semantics/metricflow_semantics/mf_logging/pretty_print.py index d29472c577..69277b278a 100644 --- a/metricflow-semantics/metricflow_semantics/mf_logging/pretty_print.py +++ b/metricflow-semantics/metricflow_semantics/mf_logging/pretty_print.py @@ -10,6 +10,7 @@ from dsi_pydantic_shim import BaseModel from metricflow_semantics.mf_logging.formatting import indent +from metricflow_semantics.mf_logging.pretty_formattable import MetricFlowPrettyFormattable logger = logging.getLogger(__name__) @@ -27,7 +28,7 @@ def __init__( ) -> None: """See mf_pformat() for argument descriptions.""" self._indent_prefix = indent_prefix - if max_line_length <= 0: + if not max_line_length > 0: raise ValueError(f"max_line_length must be > 0 as required by pprint.pformat(). Got {max_line_length}") self._max_line_width = max_line_length self._include_object_field_names = include_object_field_names @@ -336,6 +337,10 @@ def _handle_any_obj(self, obj: Any, remaining_line_width: Optional[int]) -> str: remaining_line_width=remaining_line_width, ) + if isinstance(obj, MetricFlowPrettyFormattable): + if obj.pretty_format is not None: + return obj.pretty_format + if is_dataclass(obj): # dataclasses.asdict() seems to exclude None fields, so doing this instead. mapping = {field.name: getattr(obj, field.name) for field in fields(obj)} diff --git a/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py b/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py index 462a7d4f7b..364032c0ed 100644 --- a/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py +++ b/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py @@ -2,12 +2,16 @@ import logging import textwrap +from dataclasses import dataclass +from typing import Optional from dbt_semantic_interfaces.implementations.elements.dimension import PydanticDimension from dbt_semantic_interfaces.type_enums import DimensionType from metricflow_semantics.mf_logging.formatting import indent +from metricflow_semantics.mf_logging.pretty_formattable import MetricFlowPrettyFormattable from metricflow_semantics.mf_logging.pretty_print import mf_pformat, mf_pformat_many from metricflow_semantics.test_helpers.metric_time_dimension import MTD_SPEC_DAY +from typing_extensions import override logger = logging.getLogger(__name__) @@ -179,3 +183,19 @@ def test_pformat_many_with_strings() -> None: # noqa: D103 ).rstrip() == result ) + + +def test_custom_pretty_print() -> None: + """Check that `MetricFlowPrettyFormattable` can be used to override the result when using MF's pretty-printer.""" + + @dataclass(frozen=True) + class _ExampleDataclass(MetricFlowPrettyFormattable): + field_0: float + + @property + @override + def pretty_format(self) -> Optional[str]: + """Only show 2 decimal points when pretty printing.""" + return f"{self.__class__.__name__}({self.field_0:.2f})" + + assert mf_pformat(_ExampleDataclass(1.2345)) == f"{_ExampleDataclass.__name__}(1.23)"