diff --git a/Makefile b/Makefile
index e70cbb94ea..41eb3d1d16 100644
--- a/Makefile
+++ b/Makefile
@@ -74,6 +74,11 @@ postgresql postgres:
 regenerate-test-snapshots:
 	hatch -v run dev-env:python metricflow/test/generate_snapshots.py
 
+# Populate persistent source schemas for all relevant SQL engines.
+.PHONY: populate-persistent-source-schemas
+populate-persistent-source-schemas:
+	hatch -v run dev-env:python metricflow/test/populate_persistent_source_schemas.py
+
 # Re-generate snapshots for the default SQL engine.
 .PHONY: test-snap
 test-snap:
diff --git a/metricflow/test/generate_snapshots.py b/metricflow/test/generate_snapshots.py
index ceb1568e84..3be91cb750 100644
--- a/metricflow/test/generate_snapshots.py
+++ b/metricflow/test/generate_snapshots.py
@@ -37,7 +37,7 @@
 import logging
 import os
 from dataclasses import dataclass
-from typing import Optional, Sequence
+from typing import Callable, Optional, Sequence
 
 from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
 from dbt_semantic_interfaces.implementations.base import FrozenBaseModel
@@ -107,7 +107,11 @@ def run_command(command: str) -> None:  # noqa: D
         raise RuntimeError(f"Error running command: {command}")
 
 
-def run_tests(test_configuration: MetricFlowTestConfiguration) -> None:  # noqa: D
+def set_engine_env_variables(test_configuration: MetricFlowTestConfiguration) -> None:
+    """Set connection env variables dynamically for the engine being used.
+
+    Requires MF_TEST_ENGINE_CREDENTIALS env variable to be set with creds for all engines.
+    """
     if test_configuration.credential_set.engine_url is None:
         if "MF_SQL_ENGINE_URL" in os.environ:
             del os.environ["MF_SQL_ENGINE_URL"]
@@ -120,6 +124,10 @@ def run_tests(test_configuration: MetricFlowTestConfiguration) -> None:  # noqa:
     else:
         os.environ["MF_SQL_ENGINE_PASSWORD"] = test_configuration.credential_set.engine_password
 
+
+def run_tests(test_configuration: MetricFlowTestConfiguration) -> None:  # noqa: D
+    set_engine_env_variables(test_configuration)
+
     if test_configuration.engine is SqlEngine.DUCKDB:
         # DuckDB is fast, so generate all snapshots, including the engine-agnostic ones
         run_command(f"pytest -x -vv -n 4 --overwrite-snapshots -k 'not itest' {TEST_DIRECTORY}")
@@ -145,7 +153,7 @@ def run_tests(test_configuration: MetricFlowTestConfiguration) -> None:  # noqa:
         assert_values_exhausted(test_configuration.engine)
 
 
-def run_cli() -> None:  # noqa: D
+def run_cli(function_to_run: Callable) -> None:  # noqa: D
     # Setup logging.
     dev_format = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d [%(threadName)s] - %(message)s"
     logging.basicConfig(level=logging.INFO, format=dev_format)
@@ -165,8 +173,8 @@ def run_cli() -> None:  # noqa: D
         logger.info(
             f"Running tests for {test_configuration.engine} with URL: {test_configuration.credential_set.engine_url}"
         )
-        run_tests(test_configuration)
+        function_to_run(test_configuration)
 
 
 if __name__ == "__main__":
-    run_cli()
+    run_cli(run_tests)
diff --git a/metricflow/test/populate_persistent_source_schemas.py b/metricflow/test/populate_persistent_source_schemas.py
new file mode 100644
index 0000000000..0f0ff6f59c
--- /dev/null
+++ b/metricflow/test/populate_persistent_source_schemas.py
@@ -0,0 +1,45 @@
+"""Script to help generate persistent source schemas with test data for all relevant engines."""
+
+from __future__ import annotations
+
+import logging
+import os
+
+from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
+
+from metricflow.protocols.sql_client import SqlEngine
+from metricflow.test.generate_snapshots import (
+    MetricFlowTestConfiguration,
+    run_cli,
+    run_command,
+    set_engine_env_variables,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def populate_schemas(test_configuration: MetricFlowTestConfiguration) -> None:  # noqa: D
+    set_engine_env_variables(test_configuration)
+
+    if test_configuration.engine is SqlEngine.DUCKDB or test_configuration.engine is SqlEngine.POSTGRES:
+        # DuckDB & Postgres don't use persistent source schema
+        return None
+    elif (
+        test_configuration.engine is SqlEngine.SNOWFLAKE
+        or test_configuration.engine is SqlEngine.BIGQUERY
+        or test_configuration.engine is SqlEngine.DATABRICKS
+        or test_configuration.engine is SqlEngine.REDSHIFT
+    ):
+        engine_name = test_configuration.engine.value.lower()
+        os.environ["MF_TEST_ADAPTER_TYPE"] = engine_name
+        hatch_env = f"{engine_name}-env"
+        run_command(
+            f"hatch -v run {hatch_env}:pytest -vv --use-persistent-source-schema "
+            "metricflow/test/source_schema_tools.py::populate_source_schema"
+        )
+    else:
+        assert_values_exhausted(test_configuration.engine)
+
+
+if __name__ == "__main__":
+    run_cli(populate_schemas)