Skip to content

Commit

Permalink
query interface updates
Browse files Browse the repository at this point in the history
  • Loading branch information
DevonFulcher committed Sep 29, 2023
1 parent 65664b5 commit 8df2758
Show file tree
Hide file tree
Showing 12 changed files with 453 additions and 139 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230929-123932.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: 'Backport: Add support for Dimension(...).grain(...) syntax in where parameter'
time: 2023-09-29T12:39:32.834352-05:00
custom:
Author: DevonFulcher
Issue: None
6 changes: 3 additions & 3 deletions dbt_semantic_interfaces/call_parameter_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@

@dataclass(frozen=True)
class DimensionCallParameterSet:
"""When 'dimension(...)' is used in the Jinja template of the where filter, the parameters to that call."""
"""When 'Dimension(...)' is used in the Jinja template of the where filter, the parameters to that call."""

entity_path: Tuple[EntityReference, ...]
dimension_reference: DimensionReference


@dataclass(frozen=True)
class TimeDimensionCallParameterSet:
"""When 'time_dimension(...)' is used in the Jinja template of the where filter, the parameters to that call."""
"""When 'TimeDimension(...)' is used in the Jinja template of the where filter, the parameters to that call."""

entity_path: Tuple[EntityReference, ...]
time_dimension_reference: TimeDimensionReference
Expand All @@ -30,7 +30,7 @@ class TimeDimensionCallParameterSet:

@dataclass(frozen=True)
class EntityCallParameterSet:
"""When 'entity(...)' is used in the Jinja template of the where filter, the parameters to that call."""
"""When 'Entity(...)' is used in the Jinja template of the where filter, the parameters to that call."""

entity_path: Tuple[EntityReference, ...]
entity_reference: EntityReference
Expand Down
7 changes: 7 additions & 0 deletions dbt_semantic_interfaces/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,10 @@ class ModelTransformError(Exception):
"""Exception to represent errors related to model transformations."""

pass


class InvalidQuerySyntax(Exception):
"""Raised when query syntax is invalid."""

def __init__(self, msg: str) -> None: # noqa: D
super().__init__(msg)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
PydanticCustomInputParser,
PydanticParseableValueType,
)
from dbt_semantic_interfaces.parsing.where_filter_parser import WhereFilterParser
from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import (
WhereFilterParser,
)


class PydanticWhereFilter(PydanticCustomInputParser, HashableBaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import Sequence

from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
EntityCallParameterSet,
ParseWhereFilterException,
TimeDimensionCallParameterSet,
)
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.naming.keywords import (
METRIC_TIME_ELEMENT_NAME,
is_metric_time_name,
)
from dbt_semantic_interfaces.references import (
DimensionReference,
EntityReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums import TimeGranularity


class ParameterSetFactory:
"""Creates parameter sets for use in the Jinja sandbox."""

@staticmethod
def _exception_message_for_incorrect_format(element_name: str) -> str:
return (
f"Name is in an incorrect format: '{element_name}'. It should be of the form: "
f"<primary entity name>__<dimension_name>"
)

@staticmethod
def create_time_dimension(
time_dimension_name: str, time_granularity_name: str, entity_path: Sequence[str] = ()
) -> TimeDimensionCallParameterSet:
"""Gets called by Jinja when rendering {{ TimeDimension(...) }}."""
group_by_item_name = DunderedNameFormatter.parse_name(time_dimension_name)

# metric_time is the only time dimension that does not have an associated primary entity, so the
# GroupByItemName would not have any entity links.
if is_metric_time_name(group_by_item_name.element_name):
if len(group_by_item_name.entity_links) != 0 or group_by_item_name.time_granularity is not None:
raise ParseWhereFilterException(
f"Name is in an incorrect format: {time_dimension_name} "
f"When referencing {METRIC_TIME_ELEMENT_NAME},"
"the name should not have any dunders (double underscores, or __)."
)
else:
if len(group_by_item_name.entity_links) != 1 or group_by_item_name.time_granularity is not None:
raise ParseWhereFilterException(
ParameterSetFactory._exception_message_for_incorrect_format(time_dimension_name)
)

return TimeDimensionCallParameterSet(
time_dimension_reference=TimeDimensionReference(element_name=group_by_item_name.element_name),
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + group_by_item_name.entity_links
),
time_granularity=TimeGranularity(time_granularity_name),
)

@staticmethod
def create_dimension(dimension_name: str, entity_path: Sequence[str] = ()) -> DimensionCallParameterSet:
"""Gets called by Jinja when rendering {{ Dimension(...) }}."""
group_by_item_name = DunderedNameFormatter.parse_name(dimension_name)
if is_metric_time_name(group_by_item_name.element_name):
raise ParseWhereFilterException(
f"{METRIC_TIME_ELEMENT_NAME} is a time dimension, so it should be referenced using "
f"TimeDimension(...) or Dimension(...).grain(...)"
)

if len(group_by_item_name.entity_links) != 1:
raise ParseWhereFilterException(ParameterSetFactory._exception_message_for_incorrect_format(dimension_name))

return DimensionCallParameterSet(
dimension_reference=DimensionReference(element_name=group_by_item_name.element_name),
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + group_by_item_name.entity_links
),
)

@staticmethod
def create_entity(entity_name: str, entity_path: Sequence[str] = ()) -> EntityCallParameterSet:
"""Gets called by Jinja when rendering {{ Entity(...) }}."""
group_by_item_name = DunderedNameFormatter.parse_name(entity_name)
if len(group_by_item_name.entity_links) > 0 or group_by_item_name.time_granularity is not None:
ParameterSetFactory._exception_message_for_incorrect_format(
f"Name is in an incorrect format: {entity_name} "
f"When referencing entities, the name should not have any dunders (double underscores, or __)."
)

return EntityCallParameterSet(
entity_path=tuple(EntityReference(element_name=arg) for arg in entity_path),
entity_reference=EntityReference(element_name=entity_name),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

from typing import List, Optional, Sequence

from typing_extensions import override

from dbt_semantic_interfaces.errors import InvalidQuerySyntax
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.protocols.query_interface import (
QueryInterfaceDimension,
QueryInterfaceDimensionFactory,
)


class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]):
"""A dimension that is passed in through the where filter parameter."""

@override
def _implements_protocol(self) -> QueryInterfaceDimension:
return self

def __init__( # noqa
self,
name: str,
entity_path: Sequence[str],
) -> None:
self.name = name
self.entity_path = entity_path
self.time_granularity_name: Optional[str] = None

def grain(self, time_granularity: str) -> QueryInterfaceDimension:
"""The time granularity."""
self.time_granularity_name = time_granularity
return self

def descending(self, _is_descending: bool) -> QueryInterfaceDimension:
"""Set the sort order for order-by."""
raise InvalidQuerySyntax("descending is invalid in the where parameter and filter spec")

def date_part(self, _date_part: str) -> QueryInterfaceDimension:
"""Date part to extract from the dimension."""
raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter and filter spec")


class WhereFilterDimensionFactory(ProtocolHint[QueryInterfaceDimensionFactory]):
"""Creates a WhereFilterDimension.
Each call to `create` adds a WhereFilterDimension to `created`.
"""

@override
def _implements_protocol(self) -> QueryInterfaceDimensionFactory:
return self

def __init__(self) -> None: # noqa
self.created: List[WhereFilterDimension] = []

def create(self, dimension_name: str, entity_path: Sequence[str] = ()) -> WhereFilterDimension:
"""Gets called by Jinja when rendering {{ Dimension(...) }}."""
dimension = WhereFilterDimension(dimension_name, entity_path)
self.created.append(dimension)
return dimension
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from typing import List, Sequence

from typing_extensions import override

from dbt_semantic_interfaces.call_parameter_sets import EntityCallParameterSet
from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import (
ParameterSetFactory,
)
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.protocols.query_interface import (
QueryInterfaceEntity,
QueryInterfaceEntityFactory,
)


class EntityStub(ProtocolHint[QueryInterfaceEntity]):
"""An Entity implementation that just satisfies the protocol.
QueryInterfaceEntity currently has no methods and the parameter set is created in the factory.
So, there is nothing to do here.
"""

@override
def _implements_protocol(self) -> QueryInterfaceEntity:
return self


class WhereFilterEntityFactory(ProtocolHint[QueryInterfaceEntityFactory]):
"""Executes in the Jinja sandbox to produce parameter sets and append them to a list."""

@override
def _implements_protocol(self) -> QueryInterfaceEntityFactory:
return self

def __init__(self) -> None: # noqa
self.entity_call_parameter_sets: List[EntityCallParameterSet] = []

def create(self, entity_name: str, entity_path: Sequence[str] = ()) -> EntityStub:
"""Gets called by Jinja when rendering {{ Entity(...) }}."""
self.entity_call_parameter_sets.append(ParameterSetFactory.create_entity(entity_name, entity_path))
return EntityStub()
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

from jinja2 import StrictUndefined
from jinja2.exceptions import SecurityError, TemplateSyntaxError, UndefinedError
from jinja2.sandbox import SandboxedEnvironment

from dbt_semantic_interfaces.call_parameter_sets import (
FilterCallParameterSets,
ParseWhereFilterException,
)
from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import (
ParameterSetFactory,
)
from dbt_semantic_interfaces.parsing.where_filter.where_filter_dimension import (
WhereFilterDimensionFactory,
)
from dbt_semantic_interfaces.parsing.where_filter.where_filter_entity import (
WhereFilterEntityFactory,
)
from dbt_semantic_interfaces.parsing.where_filter.where_filter_time_dimension import (
WhereFilterTimeDimensionFactory,
)


class WhereFilterParser:
"""Parses the template in the WhereFilter into FilterCallParameterSets."""

@staticmethod
def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSets:
"""Return the result of extracting the semantic objects referenced in the where SQL template string."""
time_dimension_factory = WhereFilterTimeDimensionFactory()
dimension_factory = WhereFilterDimensionFactory()
entity_factory = WhereFilterEntityFactory()

try:
# the string that the sandbox renders is unused
SandboxedEnvironment(undefined=StrictUndefined).from_string(where_sql_template).render(
Dimension=dimension_factory.create,
TimeDimension=time_dimension_factory.create,
Entity=entity_factory.create,
)
except (UndefinedError, TemplateSyntaxError, SecurityError) as e:
raise ParseWhereFilterException(f"Error while parsing Jinja template:\n{where_sql_template}") from e

"""
Dimensions that are created with a grain parameter, Dimension(...).grain(...), are
added to time_dimension_call_parameter_sets otherwise they are add to dimension_call_parameter_sets
"""
dimension_call_parameter_sets = []
for dimension in dimension_factory.created:
if dimension.time_granularity_name:
time_dimension_factory.time_dimension_call_parameter_sets.append(
ParameterSetFactory.create_time_dimension(
dimension.name,
dimension.time_granularity_name,
dimension.entity_path,
)
)
else:
dimension_call_parameter_sets.append(
ParameterSetFactory.create_dimension(dimension.name, dimension.entity_path)
)

return FilterCallParameterSets(
dimension_call_parameter_sets=tuple(dimension_call_parameter_sets),
time_dimension_call_parameter_sets=tuple(time_dimension_factory.time_dimension_call_parameter_sets),
entity_call_parameter_sets=tuple(entity_factory.entity_call_parameter_sets),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

from typing import List, Optional, Sequence

from typing_extensions import override

from dbt_semantic_interfaces.call_parameter_sets import TimeDimensionCallParameterSet
from dbt_semantic_interfaces.errors import InvalidQuerySyntax
from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import (
ParameterSetFactory,
)
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.protocols.query_interface import (
QueryInterfaceTimeDimension,
QueryInterfaceTimeDimensionFactory,
)


class TimeDimensionStub(ProtocolHint[QueryInterfaceTimeDimension]):
"""A TimeDimension implementation that just satisfies the protocol.
QueryInterfaceTimeDimension currently has no methods and the parameter set is created in the factory.
So, there is nothing to do here.
"""

@override
def _implements_protocol(self) -> QueryInterfaceTimeDimension:
return self


class WhereFilterTimeDimensionFactory(ProtocolHint[QueryInterfaceTimeDimensionFactory]):
"""Executes in the Jinja sandbox to produce parameter sets and append them to a list."""

@override
def _implements_protocol(self) -> QueryInterfaceTimeDimensionFactory:
return self

def __init__(self) -> None: # noqa
self.time_dimension_call_parameter_sets: List[TimeDimensionCallParameterSet] = []

def create(
self,
time_dimension_name: str,
time_granularity_name: str,
entity_path: Sequence[str] = (),
descending: Optional[bool] = None,
date_part_name: Optional[str] = None,
) -> TimeDimensionStub:
"""Gets called by Jinja when rendering {{ TimeDimension(...) }}."""
if descending is not None:
raise InvalidQuerySyntax("descending is invalid in the where parameter and filter spec")
if date_part_name is not None:
raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter and filter spec")
self.time_dimension_call_parameter_sets.append(
ParameterSetFactory.create_time_dimension(time_dimension_name, time_granularity_name, entity_path)
)
return TimeDimensionStub()
Loading

0 comments on commit 8df2758

Please sign in to comment.