Skip to content

Commit

Permalink
Update callsites to cast str to TimeGranularity as custom grain is no…
Browse files Browse the repository at this point in the history
…t supported yet
  • Loading branch information
WilliamDee committed Nov 11, 2024
1 parent 3a40926 commit 7c0f846
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from typing import Optional

from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity


def error_if_not_standard_grain(input_granularity: str, context: Optional[str] = None) -> TimeGranularity:
"""Cast input grainularity string to TimeGranularity, otherwise error.
TODO: Not needed once, custom grain is supported for most things.
"""
try:
time_grain = TimeGranularity(input_granularity)
except ValueError:
error_msg = f"Received a non-standard time granularity, which is not supported at the moment, received: {input_granularity}."
if context:
error_msg += f"\nContext: {context}"
raise ValueError(error_msg)
return time_grain
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow_semantics.errors.custom_grain_not_supported import error_if_not_standard_grain
from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.mf_logging.pretty_print import mf_pformat, mf_pformat_many
Expand Down Expand Up @@ -401,7 +402,13 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> PushDownRe

# If time granularity is not set for the metric, defaults to DAY if available, else the smallest available granularity.
# Note: ignores any granularity set on input metrics.
metric_default_time_granularity = metric_to_use_for_time_granularity_resolution.time_granularity or max(
metric_time_granularity: Optional[TimeGranularity] = None
if metric_to_use_for_time_granularity_resolution.time_granularity is not None:
metric_time_granularity = error_if_not_standard_grain(
context=f"Metric({metric_to_use_for_time_granularity_resolution}).time_granularity",
input_granularity=metric_to_use_for_time_granularity_resolution.time_granularity,
)
metric_default_time_granularity = metric_time_granularity or max(
TimeGranularity.DAY,
self._semantic_manifest_lookup.metric_lookup.get_min_queryable_time_granularity(
MetricReference(metric_to_use_for_time_granularity_resolution.name)
Expand Down
30 changes: 23 additions & 7 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dbt_semantic_interfaces.validations.unique_valid_name import MetricFlowReservedKeywords
from metricflow_semantics.dag.id_prefix import StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId
from metricflow_semantics.errors.custom_grain_not_supported import error_if_not_standard_grain
from metricflow_semantics.errors.error_classes import UnableToSatisfyQueryError
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.mf_logging.formatting import indent
Expand Down Expand Up @@ -512,6 +513,16 @@ def _build_base_metric_output_node(
len(metric.input_measures) == 1
), f"A base metric should not have multiple measures. Got {metric.input_measures}"

cumulative_grain_to_date: Optional[TimeGranularity] = None
if (
metric.type_params.cumulative_type_params
and metric.type_params.cumulative_type_params.grain_to_date is not None
):
cumulative_grain_to_date = error_if_not_standard_grain(
context=f"CumulativeMetric({metric_spec.element_name}).grain_to_date",
input_granularity=metric.type_params.cumulative_type_params.grain_to_date,
)

metric_input_measure_spec = self._build_input_measure_spec(
filter_spec_factory=filter_spec_factory,
metric=metric,
Expand All @@ -526,11 +537,7 @@ def _build_base_metric_output_node(
if metric.type_params.cumulative_type_params
else None
),
cumulative_grain_to_date=(
metric.type_params.cumulative_type_params.grain_to_date
if metric.type_params.cumulative_type_params
else None
),
cumulative_grain_to_date=cumulative_grain_to_date,
)
if metric.type is MetricType.CUMULATIVE
else None
Expand Down Expand Up @@ -1399,6 +1406,13 @@ def _build_input_metric_specs_for_derived_metric(
),
)

input_metric_offset_to_grain: Optional[TimeGranularity] = None
if input_metric.offset_to_grain is not None:
input_metric_offset_to_grain = error_if_not_standard_grain(
context=f"Metric({metric.name}).InputMetric({input_metric.name}).offset_to_grain",
input_granularity=input_metric.offset_to_grain,
)

spec = MetricSpec(
element_name=input_metric.name,
filter_spec_set=filter_spec_set,
Expand All @@ -1411,7 +1425,7 @@ def _build_input_metric_specs_for_derived_metric(
if input_metric.offset_window
else None
),
offset_to_grain=input_metric.offset_to_grain,
offset_to_grain=input_metric_offset_to_grain,
)
input_metric_specs.append(spec)
return tuple(input_metric_specs)
Expand Down Expand Up @@ -1514,7 +1528,9 @@ def _build_aggregated_measure_from_measure_source_node(
granularity: Optional[TimeGranularity] = None
count = 0
if cumulative_window is not None:
granularity = cumulative_window.granularity
granularity = error_if_not_standard_grain(
context="CumulativeMetric.window.granularity", input_granularity=cumulative_window.granularity
)
count = cumulative_window.count
elif cumulative_grain_to_date is not None:
count = 1
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/join_conversion_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOut

@property
def description(self) -> str: # noqa: D102
return f"Find conversions for {self.entity_spec.qualified_name} within the range of {f'{self.window.count} {self.window.granularity.value}' if self.window else 'INF'}"
return f"Find conversions for {self.entity_spec.qualified_name} within the range of {f'{self.window.count} {self.window.granularity}' if self.window else 'INF'}"

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
Expand Down
9 changes: 7 additions & 2 deletions metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
from metricflow_semantics.errors.custom_grain_not_supported import error_if_not_standard_grain
from metricflow_semantics.sql.sql_join_type import SqlJoinType

from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
Expand Down Expand Up @@ -466,7 +467,9 @@ def _make_time_range_window_join_condition(
right_expr=SqlSubtractTimeIntervalExpression.create(
arg=time_comparison_column_expr,
count=window.count,
granularity=window.granularity,
granularity=error_if_not_standard_grain(
input_granularity=window.granularity,
),
),
)
comparison_expressions.append(start_of_range_comparison_expr)
Expand Down Expand Up @@ -551,7 +554,9 @@ def make_join_to_time_spine_join_description(
)
if node.offset_window:
left_expr = SqlSubtractTimeIntervalExpression.create(
arg=left_expr, count=node.offset_window.count, granularity=node.offset_window.granularity
arg=left_expr,
count=node.offset_window.count,
granularity=error_if_not_standard_grain(input_granularity=node.offset_window.granularity),
)
elif node.offset_to_grain:
left_expr = SqlDateTruncExpression.create(time_granularity=node.offset_to_grain, arg=left_expr)
Expand Down

0 comments on commit 7c0f846

Please sign in to comment.