diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py index b2c6aa4166..8eee9111d4 100644 --- a/metricflow/query/query_parser.py +++ b/metricflow/query/query_parser.py @@ -679,8 +679,6 @@ def _parse_linkable_elements( elif linkable_elements: for linkable_element in linkable_elements: parsed_name = StructuredLinkableSpecName.from_name(linkable_element.name) - if parsed_name.time_granularity: - raise ValueError("Must use object syntax for `grain` parameter if `date_part` is requested.") structured_name = StructuredLinkableSpecName( entity_link_names=parsed_name.entity_link_names, element_name=parsed_name.element_name, diff --git a/metricflow/specs/query_param_implementations.py b/metricflow/specs/query_param_implementations.py new file mode 100644 index 0000000000..62c42d6937 --- /dev/null +++ b/metricflow/specs/query_param_implementations.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity + +from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName +from metricflow.time.date_part import DatePart + + +@dataclass(frozen=True) +class DimensionQueryParameter: + """Time dimension requested in a query.""" + + name: str + grain: Optional[TimeGranularity] = None + date_part: Optional[DatePart] = None + + def __post_init__(self) -> None: # noqa: D + parsed_name = StructuredLinkableSpecName.from_name(self.name) + if parsed_name.time_granularity: + raise ValueError("Must use object syntax for `grain` parameter if `date_part` is requested.") diff --git a/metricflow/test/integration/test_configured_cases.py b/metricflow/test/integration/test_configured_cases.py index 28bfc0112c..af897edcaa 100644 --- a/metricflow/test/integration/test_configured_cases.py +++ b/metricflow/test/integration/test_configured_cases.py @@ -19,6 +19,7 @@ DunderColumnAssociationResolver, ) from metricflow.protocols.sql_client import SqlClient +from metricflow.specs.query_param_implementations import DimensionQueryParameter from metricflow.sql.sql_exprs import ( SqlCastToTimestampExpression, SqlColumnReference, @@ -32,7 +33,6 @@ SqlTimeDeltaExpression, ) from metricflow.test.compare_df import assert_dataframes_equal -from metricflow.test.conftest import MockQueryParameter from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState from metricflow.test.integration.configured_test_case import ( CONFIGURED_INTEGRATION_TESTS_REPOSITORY, @@ -255,7 +255,7 @@ def test_case( check_query_helpers = CheckQueryHelpers(sql_client) - group_by: List[MockQueryParameter] = [] + group_by: List[DimensionQueryParameter] = [] for group_by_kwargs in case.group_by_objs: kwargs = copy(group_by_kwargs) date_part = kwargs.get("date_part") @@ -264,7 +264,7 @@ def test_case( kwargs["date_part"] = DatePart(date_part) if grain: kwargs["grain"] = TimeGranularity(grain) - group_by.append(MockQueryParameter(**kwargs)) + group_by.append(DimensionQueryParameter(**kwargs)) query_result = engine.query( MetricFlowQueryRequest.create_with_random_request_id( metric_names=case.metrics,