Skip to content

Commit

Permalink
/* PR_START p--query-resolution-perf 02 */ Add an interface to overri…
Browse files Browse the repository at this point in the history
…de pretty-print behavior.
  • Loading branch information
plypaul committed Jul 11, 2024
1 parent 2a5985d commit f29f6f6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional


class MetricFlowPrettyFormattable(ABC):
"""Changes behavior for pretty-formatting using `MetricFlowPrettyFormatter`.
This interface is pending updates to allow for additional configuration and structured return types.
"""

@property
@abstractmethod
def pretty_format(self) -> Optional[str]:
"""Return the pretty-formatted version of this object, or None if the default approach should be used."""
raise NotImplementedError
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dsi_pydantic_shim import BaseModel

from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.pretty_formattable import MetricFlowPrettyFormattable

logger = logging.getLogger(__name__)

Expand All @@ -27,7 +28,7 @@ def __init__(
) -> None:
"""See mf_pformat() for argument descriptions."""
self._indent_prefix = indent_prefix
if max_line_length <= 0:
if not max_line_length > 0:
raise ValueError(f"max_line_length must be > 0 as required by pprint.pformat(). Got {max_line_length}")
self._max_line_width = max_line_length
self._include_object_field_names = include_object_field_names
Expand Down Expand Up @@ -336,6 +337,10 @@ def _handle_any_obj(self, obj: Any, remaining_line_width: Optional[int]) -> str:
remaining_line_width=remaining_line_width,
)

if isinstance(obj, MetricFlowPrettyFormattable):
if obj.pretty_format is not None:
return obj.pretty_format

if is_dataclass(obj):
# dataclasses.asdict() seems to exclude None fields, so doing this instead.
mapping = {field.name: getattr(obj, field.name) for field in fields(obj)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

import logging
import textwrap
from dataclasses import dataclass
from typing import Optional

from dbt_semantic_interfaces.implementations.elements.dimension import PydanticDimension
from dbt_semantic_interfaces.type_enums import DimensionType
from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.pretty_formattable import MetricFlowPrettyFormattable
from metricflow_semantics.mf_logging.pretty_print import mf_pformat, mf_pformat_many
from metricflow_semantics.test_helpers.metric_time_dimension import MTD_SPEC_DAY
from typing_extensions import override

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -179,3 +183,19 @@ def test_pformat_many_with_strings() -> None: # noqa: D103
).rstrip()
== result
)


def test_custom_pretty_print() -> None:
"""Check that `MetricFlowPrettyFormattable` can be used to override the result when using MF's pretty-printer."""

@dataclass(frozen=True)
class _ExampleDataclass(MetricFlowPrettyFormattable):
field_0: float

@property
@override
def pretty_format(self) -> Optional[str]:
"""Only show 2 decimal points when pretty printing."""
return f"{self.__class__.__name__}({self.field_0:.2f})"

assert mf_pformat(_ExampleDataclass(1.2345)) == f"{_ExampleDataclass.__name__}(1.23)"

0 comments on commit f29f6f6

Please sign in to comment.