Skip to content

Commit

Permalink
Add implementation for QueryParameter
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Sep 18, 2023
1 parent 1c8a21e commit 8a63473
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
2 changes: 0 additions & 2 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,6 @@ def _parse_linkable_elements(
elif linkable_elements:
for linkable_element in linkable_elements:
parsed_name = StructuredLinkableSpecName.from_name(linkable_element.name)
if parsed_name.time_granularity:
raise ValueError("Must use object syntax for `grain` parameter if `date_part` is requested.")
structured_name = StructuredLinkableSpecName(
entity_link_names=parsed_name.entity_link_names,
element_name=parsed_name.element_name,
Expand Down
23 changes: 23 additions & 0 deletions metricflow/specs/query_param_implementations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.time.date_part import DatePart


@dataclass(frozen=True)
class DimensionQueryParameter:
"""Time dimension requested in a query."""

name: str
grain: Optional[TimeGranularity] = None
date_part: Optional[DatePart] = None

def __post_init__(self) -> None: # noqa: D
parsed_name = StructuredLinkableSpecName.from_name(self.name)
if parsed_name.time_granularity:
raise ValueError("Must use object syntax for `grain` parameter if `date_part` is requested.")
6 changes: 3 additions & 3 deletions metricflow/test/integration/test_configured_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DunderColumnAssociationResolver,
)
from metricflow.protocols.sql_client import SqlClient
from metricflow.specs.query_param_implementations import DimensionQueryParameter
from metricflow.sql.sql_exprs import (
SqlCastToTimestampExpression,
SqlColumnReference,
Expand All @@ -32,7 +33,6 @@
SqlTimeDeltaExpression,
)
from metricflow.test.compare_df import assert_dataframes_equal
from metricflow.test.conftest import MockQueryParameter
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
from metricflow.test.integration.configured_test_case import (
CONFIGURED_INTEGRATION_TESTS_REPOSITORY,
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_case(

check_query_helpers = CheckQueryHelpers(sql_client)

group_by: List[MockQueryParameter] = []
group_by: List[DimensionQueryParameter] = []
for group_by_kwargs in case.group_by_objs:
kwargs = copy(group_by_kwargs)
date_part = kwargs.get("date_part")
Expand All @@ -264,7 +264,7 @@ def test_case(
kwargs["date_part"] = DatePart(date_part)
if grain:
kwargs["grain"] = TimeGranularity(grain)
group_by.append(MockQueryParameter(**kwargs))
group_by.append(DimensionQueryParameter(**kwargs))
query_result = engine.query(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=case.metrics,
Expand Down

0 comments on commit 8a63473

Please sign in to comment.