diff --git a/.changes/unreleased/Features-20231107-180843.yaml b/.changes/unreleased/Features-20231107-180843.yaml new file mode 100644 index 0000000000..59b3d218ea --- /dev/null +++ b/.changes/unreleased/Features-20231107-180843.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Implemented date_part in where filter. +time: 2023-11-07T18:08:43.67846-06:00 +custom: + Author: DevonFulcher + Issue: None diff --git a/.changes/unreleased/Under the Hood-20231107-184138.yaml b/.changes/unreleased/Under the Hood-20231107-184138.yaml new file mode 100644 index 0000000000..51ff6aeb11 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20231107-184138.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Removed DatePart Enum and change imports to depend on DSI version instead. +time: 2023-11-07T18:41:38.606807-06:00 +custom: + Author: DevonFulcher + Issue: None diff --git a/Makefile b/Makefile index e70cbb94ea..41eb3d1d16 100644 --- a/Makefile +++ b/Makefile @@ -74,6 +74,11 @@ postgresql postgres: regenerate-test-snapshots: hatch -v run dev-env:python metricflow/test/generate_snapshots.py +# Populate persistent source schemas for all relevant SQL engines. +.PHONY: populate-persistent-source-schemas +populate-persistent-source-schemas: + hatch -v run dev-env:python metricflow/test/populate_persistent_source_schemas.py + # Re-generate snapshots for the default SQL engine. .PHONY: test-snap test-snap: diff --git a/metricflow/dataset/convert_semantic_model.py b/metricflow/dataset/convert_semantic_model.py index e435e69c22..cc29f45bca 100644 --- a/metricflow/dataset/convert_semantic_model.py +++ b/metricflow/dataset/convert_semantic_model.py @@ -10,6 +10,7 @@ from dbt_semantic_interfaces.protocols.measure import Measure from dbt_semantic_interfaces.protocols.semantic_model import SemanticModel from dbt_semantic_interfaces.references import SemanticModelElementReference, SemanticModelReference +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.aggregation_properties import AggregationState @@ -47,7 +48,6 @@ SqlSelectStatementNode, SqlTableFromClauseNode, ) -from metricflow.time.date_part import DatePart logger = logging.getLogger(__name__) diff --git a/metricflow/dataset/dataset.py b/metricflow/dataset/dataset.py index ea88d69b72..b32bf28274 100644 --- a/metricflow/dataset/dataset.py +++ b/metricflow/dataset/dataset.py @@ -4,12 +4,12 @@ from typing import Optional, Sequence from dbt_semantic_interfaces.references import TimeDimensionReference +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from dbt_semantic_interfaces.validations.unique_valid_name import MetricFlowReservedKeywords from metricflow.instances import InstanceSet, TimeDimensionInstance from metricflow.specs.specs import TimeDimensionSpec -from metricflow.time.date_part import DatePart logger = logging.getLogger(__name__) diff --git a/metricflow/naming/linkable_spec_name.py b/metricflow/naming/linkable_spec_name.py index 27fa92416b..bfcd1d6d12 100644 --- a/metricflow/naming/linkable_spec_name.py +++ b/metricflow/naming/linkable_spec_name.py @@ -4,10 +4,9 @@ from dataclasses import dataclass from typing import Optional, Tuple +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity -from metricflow.time.date_part import DatePart - DUNDER = "__" logger = logging.getLogger(__name__) diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py index 07bf4922f1..fbe7d6c30a 100644 --- a/metricflow/plan_conversion/instance_converters.py +++ b/metricflow/plan_conversion/instance_converters.py @@ -9,6 +9,7 @@ from typing import Dict, List, Optional, Sequence, Tuple from dbt_semantic_interfaces.references import SemanticModelReference +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from more_itertools import bucket @@ -47,7 +48,6 @@ SqlFunctionExpression, ) from metricflow.sql.sql_plan import SqlSelectColumn -from metricflow.time.date_part import DatePart logger = logging.getLogger(__name__) diff --git a/metricflow/protocols/query_parameter.py b/metricflow/protocols/query_parameter.py index 0424b3f5a1..9285270e75 100644 --- a/metricflow/protocols/query_parameter.py +++ b/metricflow/protocols/query_parameter.py @@ -3,8 +3,7 @@ from typing import Optional, Protocol, Union, runtime_checkable from dbt_semantic_interfaces.type_enums import TimeGranularity - -from metricflow.time.date_part import DatePart +from dbt_semantic_interfaces.type_enums.date_part import DatePart @runtime_checkable diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py index f302a48365..dbec189e57 100644 --- a/metricflow/query/query_parser.py +++ b/metricflow/query/query_parser.py @@ -21,6 +21,7 @@ MetricReference, TimeDimensionReference, ) +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.assert_one_arg import assert_exactly_one_arg_set @@ -55,7 +56,6 @@ WhereFilterSpec, ) from metricflow.specs.where_filter_transform import WhereSpecFactory -from metricflow.time.date_part import DatePart from metricflow.time.time_granularity_solver import ( PartialTimeDimensionSpec, RequestTimeGranularityException, diff --git a/metricflow/specs/dimension_spec_resolver.py b/metricflow/specs/dimension_spec_resolver.py index febbefa478..25caa0ac6c 100644 --- a/metricflow/specs/dimension_spec_resolver.py +++ b/metricflow/specs/dimension_spec_resolver.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from typing import Optional, Sequence from dbt_semantic_interfaces.call_parameter_sets import ( DimensionCallParameterSet, @@ -10,6 +10,7 @@ from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter from dbt_semantic_interfaces.references import DimensionReference, EntityReference, TimeDimensionReference from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.type_enums.date_part import DatePart from metricflow.specs.specs import DEFAULT_TIME_GRANULARITY, DimensionSpec, TimeDimensionSpec @@ -35,16 +36,21 @@ def resolve_dimension_spec(self, name: str, entity_path: Sequence[str]) -> Dimen ) def resolve_time_dimension_spec( - self, name: str, time_granularity_name: TimeGranularity, entity_path: Sequence[str] + self, + name: str, + time_granularity: Optional[TimeGranularity], + entity_path: Sequence[str], + date_part: Optional[DatePart], ) -> TimeDimensionSpec: """Resolve TimeDimension spec with the call_parameter_sets.""" structured_name = DunderedNameFormatter.parse_name(name) call_parameter_set = TimeDimensionCallParameterSet( time_dimension_reference=TimeDimensionReference(element_name=structured_name.element_name), - time_granularity=time_granularity_name, + time_granularity=time_granularity, entity_path=( tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links ), + date_part=date_part, ) assert call_parameter_set in self._call_parameter_sets.time_dimension_call_parameter_sets return TimeDimensionSpec( @@ -56,4 +62,5 @@ def resolve_time_dimension_spec( if call_parameter_set.time_granularity is not None else DEFAULT_TIME_GRANULARITY ), + date_part=call_parameter_set.date_part, ) diff --git a/metricflow/specs/query_param_implementations.py b/metricflow/specs/query_param_implementations.py index bf445f1435..96372c2d79 100644 --- a/metricflow/specs/query_param_implementations.py +++ b/metricflow/specs/query_param_implementations.py @@ -4,13 +4,13 @@ from typing import Optional from dbt_semantic_interfaces.protocols import ProtocolHint +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from typing_extensions import override from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName from metricflow.protocols.query_parameter import InputOrderByParameter from metricflow.protocols.query_parameter import SavedQueryParameter as SavedQueryParameterProtocol -from metricflow.time.date_part import DatePart @dataclass(frozen=True) diff --git a/metricflow/specs/specs.py b/metricflow/specs/specs.py index 2b2e8b9544..601c4d0da0 100644 --- a/metricflow/specs/specs.py +++ b/metricflow/specs/specs.py @@ -26,6 +26,7 @@ TimeDimensionReference, ) from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.aggregation_properties import AggregationState @@ -33,7 +34,6 @@ from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName from metricflow.sql.sql_bind_parameters import SqlBindParameters from metricflow.sql.sql_column_type import SqlColumnType -from metricflow.time.date_part import DatePart from metricflow.visitor import VisitorOutputT diff --git a/metricflow/specs/where_filter_dimension.py b/metricflow/specs/where_filter_dimension.py index 9af2858d8f..f98df6d249 100644 --- a/metricflow/specs/where_filter_dimension.py +++ b/metricflow/specs/where_filter_dimension.py @@ -11,12 +11,13 @@ QueryInterfaceDimensionFactory, ) from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.type_enums.date_part import DatePart from typing_extensions import override from metricflow.errors.errors import InvalidQuerySyntax from metricflow.specs.column_assoc import ColumnAssociationResolver from metricflow.specs.dimension_spec_resolver import DimensionSpecResolver -from metricflow.specs.specs import TimeDimensionSpec +from metricflow.specs.specs import DimensionSpec, InstanceSpec, TimeDimensionSpec class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]): @@ -37,32 +38,48 @@ def __init__( # noqa self._column_association_resolver = column_association_resolver self._name = name self._entity_path = entity_path - self.dimension_spec = self._dimension_spec_resolver.resolve_dimension_spec(name, entity_path) - self.time_dimension_spec: Optional[TimeDimensionSpec] = None + self.dimension_spec: DimensionSpec = self._dimension_spec_resolver.resolve_dimension_spec( + self._name, self._entity_path + ) + self.date_part_name: Optional[str] = None + self.time_granularity_name: Optional[str] = None + + @property + def time_dimension_spec(self) -> TimeDimensionSpec: + """TimeDimensionSpec that results from the builder-pattern configuration.""" + return self._dimension_spec_resolver.resolve_time_dimension_spec( + self._name, + TimeGranularity(self.time_granularity_name) if self.time_granularity_name else None, + self._entity_path, + DatePart(self.date_part_name) if self.date_part_name else None, + ) def grain(self, time_granularity_name: str) -> QueryInterfaceDimension: """The time granularity.""" - self.time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec( - self._name, TimeGranularity(time_granularity_name), self._entity_path - ) + self.time_granularity_name = time_granularity_name return self - def date_part(self, _date_part: str) -> QueryInterfaceDimension: + def date_part(self, date_part_name: str) -> QueryInterfaceDimension: """The date_part requested to extract.""" - raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter") + self.date_part_name = date_part_name + return self def descending(self, _is_descending: bool) -> QueryInterfaceDimension: """Set the sort order for order-by.""" raise InvalidQuerySyntax("descending is invalid in the where parameter") + def _get_spec(self) -> InstanceSpec: + """Get either the TimeDimensionSpec or DimensionSpec.""" + if self.time_granularity_name or self.date_part_name: + return self.time_dimension_spec + return self.dimension_spec + def __str__(self) -> str: """Returns the column name. Important in the Jinja sandbox. """ - return self._column_association_resolver.resolve_spec( - self.time_dimension_spec or self.dimension_spec - ).column_name + return self._column_association_resolver.resolve_spec(self._get_spec()).column_name class WhereFilterDimensionFactory(ProtocolHint[QueryInterfaceDimensionFactory]): diff --git a/metricflow/specs/where_filter_time_dimension.py b/metricflow/specs/where_filter_time_dimension.py index 75d12d5fbe..a6f840dc92 100644 --- a/metricflow/specs/where_filter_time_dimension.py +++ b/metricflow/specs/where_filter_time_dimension.py @@ -9,6 +9,7 @@ QueryInterfaceTimeDimensionFactory, ) from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.type_enums.date_part import DatePart from typing_extensions import override from metricflow.errors.errors import InvalidQuerySyntax @@ -68,10 +69,11 @@ def create( raise InvalidQuerySyntax( "Can't set descending in the where clause. Try setting descending in the order_by clause instead" ) - if date_part_name: - raise InvalidQuerySyntax("date_part_name isn't currently supported in the where parameter") time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec( - time_dimension_name, TimeGranularity(time_granularity_name), entity_path + time_dimension_name, + TimeGranularity(time_granularity_name) if time_dimension_name else None, + entity_path, + DatePart(date_part_name) if date_part_name else None, ) self.time_dimension_specs.append(time_dimension_spec) column_name = self._column_association_resolver.resolve_spec(time_dimension_spec).column_name diff --git a/metricflow/specs/where_filter_transform.py b/metricflow/specs/where_filter_transform.py index c039cee1f8..418737cead 100644 --- a/metricflow/specs/where_filter_transform.py +++ b/metricflow/specs/where_filter_transform.py @@ -82,12 +82,12 @@ def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec ) """ - Dimensions that are created with a grain parameter, Dimension(...).grain(...), are - added to dimension_specs otherwise they are add to time_dimension_factory.time_dimension_specs + Dimensions that are created with a grain or date_part parameter, Dimension(...).grain(...), are + added to time_dimension_factory.time_dimension_specs otherwise they are add to dimension_specs """ dimension_specs = [] for dimension in dimension_factory.created: - if dimension.time_dimension_spec: + if dimension.time_granularity_name or dimension.date_part_name: time_dimension_factory.time_dimension_specs.append(dimension.time_dimension_spec) else: dimension_specs.append(dimension.dimension_spec) diff --git a/metricflow/sql/render/big_query.py b/metricflow/sql/render/big_query.py index 83a4f82c38..7b78bad9ed 100644 --- a/metricflow/sql/render/big_query.py +++ b/metricflow/sql/render/big_query.py @@ -4,6 +4,7 @@ from typing import Collection from dbt_semantic_interfaces.enum_extension import assert_values_exhausted +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from typing_extensions import override @@ -25,7 +26,6 @@ SqlSubtractTimeIntervalExpression, ) from metricflow.sql.sql_plan import SqlSelectColumn -from metricflow.time.date_part import DatePart class BigQuerySqlExpressionRenderer(DefaultSqlExpressionRenderer): diff --git a/metricflow/sql/render/databricks.py b/metricflow/sql/render/databricks.py index bdd2b58d33..0a9d2de899 100644 --- a/metricflow/sql/render/databricks.py +++ b/metricflow/sql/render/databricks.py @@ -3,6 +3,7 @@ from typing import Collection from dbt_semantic_interfaces.enum_extension import assert_values_exhausted +from dbt_semantic_interfaces.type_enums.date_part import DatePart from typing_extensions import override from metricflow.errors.errors import UnsupportedEngineFeatureError @@ -13,7 +14,6 @@ ) from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer from metricflow.sql.sql_exprs import SqlPercentileExpression, SqlPercentileFunctionType -from metricflow.time.date_part import DatePart class DatabricksSqlExpressionRenderer(DefaultSqlExpressionRenderer): diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index 8c603ffdc4..29fac9b1a1 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -8,6 +8,7 @@ from typing import Collection, List import jinja2 +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from typing_extensions import override @@ -37,7 +38,6 @@ SqlWindowFunctionExpression, ) from metricflow.sql.sql_plan import SqlSelectColumn -from metricflow.time.date_part import DatePart logger = logging.getLogger(__name__) diff --git a/metricflow/sql/render/redshift.py b/metricflow/sql/render/redshift.py index caa7398a3a..251227c209 100644 --- a/metricflow/sql/render/redshift.py +++ b/metricflow/sql/render/redshift.py @@ -3,6 +3,7 @@ from typing import Collection from dbt_semantic_interfaces.enum_extension import assert_values_exhausted +from dbt_semantic_interfaces.type_enums.date_part import DatePart from typing_extensions import override from metricflow.errors.errors import UnsupportedEngineFeatureError @@ -19,7 +20,6 @@ SqlPercentileExpression, SqlPercentileFunctionType, ) -from metricflow.time.date_part import DatePart class RedshiftSqlExpressionRenderer(DefaultSqlExpressionRenderer): diff --git a/metricflow/sql/render/snowflake.py b/metricflow/sql/render/snowflake.py index c368cbdf57..c6328feefe 100644 --- a/metricflow/sql/render/snowflake.py +++ b/metricflow/sql/render/snowflake.py @@ -3,6 +3,7 @@ from typing import Collection from dbt_semantic_interfaces.enum_extension import assert_values_exhausted +from dbt_semantic_interfaces.type_enums.date_part import DatePart from typing_extensions import override from metricflow.errors.errors import UnsupportedEngineFeatureError @@ -14,7 +15,6 @@ from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer from metricflow.sql.sql_bind_parameters import SqlBindParameters from metricflow.sql.sql_exprs import SqlGenerateUuidExpression, SqlPercentileExpression, SqlPercentileFunctionType -from metricflow.time.date_part import DatePart class SnowflakeSqlExpressionRenderer(DefaultSqlExpressionRenderer): diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 2206691e7c..c517c1bae1 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -12,6 +12,7 @@ from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.protocols.measure import MeasureAggregationParameters from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.dag.id_generation import ( @@ -35,7 +36,6 @@ ) from metricflow.dag.mf_dag import DagNode, DisplayedProperty, NodeId from metricflow.sql.sql_bind_parameters import SqlBindParameters -from metricflow.time.date_part import DatePart from metricflow.visitor import Visitable, VisitorOutputT diff --git a/metricflow/test/generate_snapshots.py b/metricflow/test/generate_snapshots.py index ceb1568e84..3be91cb750 100644 --- a/metricflow/test/generate_snapshots.py +++ b/metricflow/test/generate_snapshots.py @@ -37,7 +37,7 @@ import logging import os from dataclasses import dataclass -from typing import Optional, Sequence +from typing import Callable, Optional, Sequence from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.implementations.base import FrozenBaseModel @@ -107,7 +107,11 @@ def run_command(command: str) -> None: # noqa: D raise RuntimeError(f"Error running command: {command}") -def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa: D +def set_engine_env_variables(test_configuration: MetricFlowTestConfiguration) -> None: + """Set connection env variables dynamically for the engine being used. + + Requires MF_TEST_ENGINE_CREDENTIALS env variable to be set with creds for all engines. + """ if test_configuration.credential_set.engine_url is None: if "MF_SQL_ENGINE_URL" in os.environ: del os.environ["MF_SQL_ENGINE_URL"] @@ -120,6 +124,10 @@ def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa: else: os.environ["MF_SQL_ENGINE_PASSWORD"] = test_configuration.credential_set.engine_password + +def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa: D + set_engine_env_variables(test_configuration) + if test_configuration.engine is SqlEngine.DUCKDB: # DuckDB is fast, so generate all snapshots, including the engine-agnostic ones run_command(f"pytest -x -vv -n 4 --overwrite-snapshots -k 'not itest' {TEST_DIRECTORY}") @@ -145,7 +153,7 @@ def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa: assert_values_exhausted(test_configuration.engine) -def run_cli() -> None: # noqa: D +def run_cli(function_to_run: Callable) -> None: # noqa: D # Setup logging. dev_format = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d [%(threadName)s] - %(message)s" logging.basicConfig(level=logging.INFO, format=dev_format) @@ -165,8 +173,8 @@ def run_cli() -> None: # noqa: D logger.info( f"Running tests for {test_configuration.engine} with URL: {test_configuration.credential_set.engine_url}" ) - run_tests(test_configuration) + function_to_run(test_configuration) if __name__ == "__main__": - run_cli() + run_cli(run_tests) diff --git a/metricflow/test/integration/test_configured_cases.py b/metricflow/test/integration/test_configured_cases.py index 95579a11d1..de7faaeb4a 100644 --- a/metricflow/test/integration/test_configured_cases.py +++ b/metricflow/test/integration/test_configured_cases.py @@ -11,6 +11,7 @@ from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.implementations.elements.measure import PydanticMeasureAggregationParameters from dbt_semantic_interfaces.test_utils import as_datetime +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.engine.metricflow_engine import MetricFlowEngine, MetricFlowQueryRequest @@ -43,7 +44,6 @@ from metricflow.test.time.configurable_time_source import ( ConfigurableTimeSource, ) -from metricflow.time.date_part import DatePart logger = logging.getLogger(__name__) diff --git a/metricflow/test/model/test_where_filter_spec.py b/metricflow/test/model/test_where_filter_spec.py index 80dd4cc0fa..3d6e905259 100644 --- a/metricflow/test/model/test_where_filter_spec.py +++ b/metricflow/test/model/test_where_filter_spec.py @@ -5,6 +5,7 @@ import pytest from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter from dbt_semantic_interfaces.references import EntityReference +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.query.query_exceptions import InvalidQueryException @@ -98,6 +99,94 @@ def test_time_dimension_in_filter( # noqa: D ) +def test_date_part_in_filter( # noqa: D + column_association_resolver: ColumnAssociationResolver, +) -> None: + where_filter = PydanticWhereFilter(where_sql_template="{{ Dimension('metric_time').date_part('year') }} = '2020'") + + where_filter_spec = WhereSpecFactory( + column_association_resolver=column_association_resolver, + ).create_from_where_filter(where_filter) + + assert where_filter_spec.where_sql == "metric_time__extract_year = '2020'" + assert where_filter_spec.linkable_spec_set == LinkableSpecSet( + dimension_specs=(), + time_dimension_specs=( + TimeDimensionSpec( + element_name="metric_time", + entity_links=(), + time_granularity=TimeGranularity.DAY, + date_part=DatePart.YEAR, + ), + ), + entity_specs=(), + ) + + +@pytest.mark.parametrize( + "where_sql", + ( + ("{{ TimeDimension('metric_time', 'WEEK', date_part_name='year') }} = '2020'"), + ("{{ Dimension('metric_time').date_part('year').grain('WEEK') }} = '2020'"), + ("{{ Dimension('metric_time').grain('WEEK').date_part('year') }} = '2020'"), + ), +) +def test_date_part_and_grain_in_filter( # noqa: D + column_association_resolver: ColumnAssociationResolver, where_sql: str +) -> None: + where_filter = PydanticWhereFilter(where_sql_template=where_sql) + + where_filter_spec = WhereSpecFactory( + column_association_resolver=column_association_resolver, + ).create_from_where_filter(where_filter) + + assert where_filter_spec.where_sql == "metric_time__extract_year = '2020'" + assert where_filter_spec.linkable_spec_set == LinkableSpecSet( + dimension_specs=(), + time_dimension_specs=( + TimeDimensionSpec( + element_name="metric_time", + entity_links=(), + time_granularity=TimeGranularity.WEEK, + date_part=DatePart.YEAR, + ), + ), + entity_specs=(), + ) + + +@pytest.mark.parametrize( + "where_sql", + ( + ("{{ TimeDimension('metric_time', 'WEEK', date_part_name='day') }} = '2020'"), + ("{{ Dimension('metric_time').date_part('day').grain('WEEK') }} = '2020'"), + ("{{ Dimension('metric_time').grain('WEEK').date_part('day') }} = '2020'"), + ), +) +def test_date_part_less_than_grain_in_filter( # noqa: D + column_association_resolver: ColumnAssociationResolver, where_sql: str +) -> None: + where_filter = PydanticWhereFilter(where_sql_template=where_sql) + + where_filter_spec = WhereSpecFactory( + column_association_resolver=column_association_resolver, + ).create_from_where_filter(where_filter) + + assert where_filter_spec.where_sql == "metric_time__extract_day = '2020'" + assert where_filter_spec.linkable_spec_set == LinkableSpecSet( + dimension_specs=(), + time_dimension_specs=( + TimeDimensionSpec( + element_name="metric_time", + entity_links=(), + time_granularity=TimeGranularity.WEEK, + date_part=DatePart.DAY, + ), + ), + entity_specs=(), + ) + + def test_entity_in_filter( # noqa: D column_association_resolver: ColumnAssociationResolver, ) -> None: diff --git a/metricflow/test/populate_persistent_source_schemas.py b/metricflow/test/populate_persistent_source_schemas.py new file mode 100644 index 0000000000..0f0ff6f59c --- /dev/null +++ b/metricflow/test/populate_persistent_source_schemas.py @@ -0,0 +1,45 @@ +"""Script to help generate persistent source schemas with test data for all relevant engines.""" + +from __future__ import annotations + +import logging +import os + +from dbt_semantic_interfaces.enum_extension import assert_values_exhausted + +from metricflow.protocols.sql_client import SqlEngine +from metricflow.test.generate_snapshots import ( + MetricFlowTestConfiguration, + run_cli, + run_command, + set_engine_env_variables, +) + +logger = logging.getLogger(__name__) + + +def populate_schemas(test_configuration: MetricFlowTestConfiguration) -> None: # noqa: D + set_engine_env_variables(test_configuration) + + if test_configuration.engine is SqlEngine.DUCKDB or test_configuration.engine is SqlEngine.POSTGRES: + # DuckDB & Postgres don't use persistent source schema + return None + elif ( + test_configuration.engine is SqlEngine.SNOWFLAKE + or test_configuration.engine is SqlEngine.BIGQUERY + or test_configuration.engine is SqlEngine.DATABRICKS + or test_configuration.engine is SqlEngine.REDSHIFT + ): + engine_name = test_configuration.engine.value.lower() + os.environ["MF_TEST_ADAPTER_TYPE"] = engine_name + hatch_env = f"{engine_name}-env" + run_command( + f"hatch -v run {hatch_env}:pytest -vv --use-persistent-source-schema " + "metricflow/test/source_schema_tools.py::populate_source_schema" + ) + else: + assert_values_exhausted(test_configuration.engine) + + +if __name__ == "__main__": + run_cli(populate_schemas) diff --git a/metricflow/test/query/test_query_parser.py b/metricflow/test/query/test_query_parser.py index 03e5ab9ccb..9a0815707f 100644 --- a/metricflow/test/query/test_query_parser.py +++ b/metricflow/test/query/test_query_parser.py @@ -8,6 +8,7 @@ from dbt_semantic_interfaces.parsing.objects import YamlConfigFile from dbt_semantic_interfaces.references import EntityReference from dbt_semantic_interfaces.test_utils import as_datetime +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.errors.errors import UnableToSatisfyQueryError @@ -30,7 +31,6 @@ from metricflow.test.fixtures.model_fixtures import query_parser_from_yaml from metricflow.test.model.example_project_configuration import EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE from metricflow.test.time.metric_time_dimension import MTD -from metricflow.time.date_part import DatePart from metricflow.time.time_granularity_solver import RequestTimeGranularityException logger = logging.getLogger(__name__) diff --git a/metricflow/test/query_rendering/test_granularity_date_part_rendering.py b/metricflow/test/query_rendering/test_granularity_date_part_rendering.py index 548d33b088..38d2cb5c8e 100644 --- a/metricflow/test/query_rendering/test_granularity_date_part_rendering.py +++ b/metricflow/test/query_rendering/test_granularity_date_part_rendering.py @@ -8,6 +8,7 @@ import pytest from _pytest.fixtures import FixtureRequest +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder @@ -20,7 +21,6 @@ ) from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState from metricflow.test.query_rendering.compare_rendered_query import convert_and_check -from metricflow.time.date_part import DatePart @pytest.mark.sql_engine_snapshot diff --git a/metricflow/test/sql/test_sql_expr_render.py b/metricflow/test/sql/test_sql_expr_render.py index 5d3e309b36..c883313a4d 100644 --- a/metricflow/test/sql/test_sql_expr_render.py +++ b/metricflow/test/sql/test_sql_expr_render.py @@ -4,6 +4,7 @@ import textwrap import pytest +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.sql.render.expr_renderer import DefaultSqlExpressionRenderer @@ -30,7 +31,6 @@ SqlWindowFunctionExpression, SqlWindowOrderByArgument, ) -from metricflow.time.date_part import DatePart logger = logging.getLogger(__name__) diff --git a/metricflow/test/sql_clients/test_date_time_operations.py b/metricflow/test/sql_clients/test_date_time_operations.py index c0f0a7689e..60d77f6bae 100644 --- a/metricflow/test/sql_clients/test_date_time_operations.py +++ b/metricflow/test/sql_clients/test_date_time_operations.py @@ -21,6 +21,7 @@ import pandas as pd import pytest +from dbt_semantic_interfaces.type_enums.date_part import DatePart from metricflow.protocols.sql_client import SqlClient from metricflow.sql.sql_exprs import ( @@ -29,7 +30,6 @@ SqlExtractExpression, SqlStringLiteralExpression, ) -from metricflow.time.date_part import DatePart from metricflow.time.time_granularity import TimeGranularity logger = logging.getLogger(__name__) diff --git a/metricflow/time/date_part.py b/metricflow/time/date_part.py deleted file mode 100644 index cc1b5768d1..0000000000 --- a/metricflow/time/date_part.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations - -from enum import Enum -from typing import List - -from dbt_semantic_interfaces.enum_extension import assert_values_exhausted -from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity - - -class DatePart(Enum): - """Date parts able to be extracted from a time dimension. - - Note this does not support WEEK (aka WEEKOFYEAR), because week numbering is very strange. - The ISO spec calls for weeks to start on Monday. Fair enough. It also calls for years to - start on Monday, but only about 1 out of every 7 do. In order to ensure years start on - Monday, the ISO decided that the first day of any given year is the Monday of the week - containing the first Thursday of that year. Consequently, the ISO standard produces - weeks numbered 1-53, but any days belonging to the preceding calendar year but in the - first week of the new year are part of the new ISO year. This is not really what people - expect. - - But there's more - different SQL engines also have different implementations of week of year. - When not using ISO, you get either 0-53, 1-54, or 1-53 with different ways of deciding - how to count the first few days in any given year. As such, we just don't support this. - - When the time comes, we can support week using whatever standard makes the most sense for - our usage context, but as it is not clear what that standard looks like we simply don't - support date_part = week for now. - - TODO: add support for hour, minute, second once those granularities are available - """ - - YEAR = "year" - QUARTER = "quarter" - MONTH = "month" - DAY = "day" - DOW = "dow" - DOY = "doy" - - def to_int(self) -> int: - """Convert to an int so that the size of the granularity can be easily compared.""" - if self is DatePart.DAY: - return TimeGranularity.DAY.to_int() - elif self is DatePart.DOW: - return TimeGranularity.DAY.to_int() - elif self is DatePart.DOY: - return TimeGranularity.DAY.to_int() - elif self is DatePart.MONTH: - return TimeGranularity.MONTH.to_int() - elif self is DatePart.QUARTER: - return TimeGranularity.QUARTER.to_int() - elif self is DatePart.YEAR: - return TimeGranularity.YEAR.to_int() - else: - assert_values_exhausted(self) - - @property - def compatible_granularities(self) -> List[TimeGranularity]: - """Granularities that can be queried with this date part.""" - return [granularity for granularity in TimeGranularity if granularity.to_int() >= self.to_int()] diff --git a/metricflow/time/time_granularity_solver.py b/metricflow/time/time_granularity_solver.py index 802fc7fd45..55e33ca865 100644 --- a/metricflow/time/time_granularity_solver.py +++ b/metricflow/time/time_granularity_solver.py @@ -13,6 +13,7 @@ MetricReference, TimeDimensionReference, ) +from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver @@ -23,7 +24,6 @@ from metricflow.specs.specs import ( TimeDimensionSpec, ) -from metricflow.time.date_part import DatePart from metricflow.time.time_granularity import ( adjust_to_end_of_period, adjust_to_start_of_period,