Skip to content

Commit

Permalink
Script for populating all persistent source schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Nov 7, 2023
1 parent 043bbbb commit c95b3b7
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 5 deletions.
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ populate-persistent-source-schema-snowflake:
hatch -v run snowflake-env:pytest -vv $(ADDITIONAL_PYTEST_OPTIONS) $(USE_PERSISTENT_SOURCE_SCHEMA) $(POPULATE_PERSISTENT_SOURCE_SCHEMA)



# Run every pre-commit hook against all files, inside the hatch `dev-env` environment.
.PHONY: lint
lint:
hatch -v run dev-env:pre-commit run --all-files
Expand All @@ -74,6 +75,11 @@ postgresql postgres:
regenerate-test-snapshots:
hatch -v run dev-env:python metricflow/test/generate_snapshots.py

# Populate persistent source schemas for all relevant SQL engines.
# Runs metricflow/test/populate_persistent_source_schemas.py with the Python
# interpreter of the hatch `dev-env` environment.
.PHONY: populate-persistent-source-schemas
populate-persistent-source-schemas:
hatch -v run dev-env:python metricflow/test/populate_persistent_source_schemas.py

# Re-generate snapshots for the default SQL engine.
.PHONY: test-snap
test-snap:
Expand Down
18 changes: 13 additions & 5 deletions metricflow/test/generate_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import logging
import os
from dataclasses import dataclass
from typing import Optional, Sequence
from typing import Callable, Optional, Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.implementations.base import FrozenBaseModel
Expand Down Expand Up @@ -107,7 +107,11 @@ def run_command(command: str) -> None: # noqa: D
raise RuntimeError(f"Error running command: {command}")


def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa: D
def set_engine_env_variables(test_configuration: MetricFlowTestConfiguration) -> None:
"""Set connection env variables dynamically for the engine being used.
Requires MF_TEST_ENGINE_CREDENTIALS env variable to be set with creds for all engines.
"""
if test_configuration.credential_set.engine_url is None:
if "MF_SQL_ENGINE_URL" in os.environ:
del os.environ["MF_SQL_ENGINE_URL"]
Expand All @@ -120,6 +124,10 @@ def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa:
else:
os.environ["MF_SQL_ENGINE_PASSWORD"] = test_configuration.credential_set.engine_password


def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa: D
set_engine_env_variables(test_configuration)

if test_configuration.engine is SqlEngine.DUCKDB:
# DuckDB is fast, so generate all snapshots, including the engine-agnostic ones
run_command(f"pytest -x -vv -n 4 --overwrite-snapshots -k 'not itest' {TEST_DIRECTORY}")
Expand All @@ -145,7 +153,7 @@ def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa:
assert_values_exhausted(test_configuration.engine)


def run_cli() -> None: # noqa: D
def run_cli(function_to_run: Callable) -> None: # noqa: D
# Setup logging.
dev_format = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d [%(threadName)s] - %(message)s"
logging.basicConfig(level=logging.INFO, format=dev_format)
Expand All @@ -165,8 +173,8 @@ def run_cli() -> None: # noqa: D
logger.info(
f"Running tests for {test_configuration.engine} with URL: {test_configuration.credential_set.engine_url}"
)
run_tests(test_configuration)
function_to_run(test_configuration)


if __name__ == "__main__":
run_cli()
run_cli(run_tests)
45 changes: 45 additions & 0 deletions metricflow/test/populate_persistent_source_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Script to help generate persistent source schemas with test data for all relevant engines."""

from __future__ import annotations

import logging
import os

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted

from metricflow.protocols.sql_client import SqlEngine
from metricflow.test.generate_snapshots import (
MetricFlowTestConfiguration,
run_cli,
run_command,
set_engine_env_variables,
)

logger = logging.getLogger(__name__)


def populate_schemas(test_configuration: MetricFlowTestConfiguration) -> None:  # noqa: D
    """Populate the persistent source schema for the engine in `test_configuration`.

    The engine's connection env variables are set first. For engines that use a
    persistent source schema (Snowflake, BigQuery, Databricks, Redshift), the
    `populate_source_schema` pytest target is then run inside that engine's hatch
    environment. For DuckDB and Postgres nothing further is done.
    """
    set_engine_env_variables(test_configuration)

    engine = test_configuration.engine

    # Guard clause: DuckDB & Postgres don't use a persistent source schema.
    if engine is SqlEngine.DUCKDB or engine is SqlEngine.POSTGRES:
        return None

    if (
        engine is SqlEngine.SNOWFLAKE
        or engine is SqlEngine.BIGQUERY
        or engine is SqlEngine.DATABRICKS
        or engine is SqlEngine.REDSHIFT
    ):
        adapter_type = engine.value.lower()
        # The same lower-cased engine name is both the adapter type and the
        # prefix of the hatch environment name.
        os.environ["MF_TEST_ADAPTER_TYPE"] = adapter_type
        pytest_target = "metricflow/test/source_schema_tools.py::populate_source_schema"
        run_command(
            f"hatch -v run {adapter_type}-env:pytest -vv --use-persistent-source-schema {pytest_target}"
        )
        return None

    # Every known engine is handled above; reaching this point means a new
    # SqlEngine value was added without updating this function.
    assert_values_exhausted(engine)


# Script entry point: pass populate_schemas to run_cli as the function to run.
if __name__ == "__main__":
run_cli(populate_schemas)

0 comments on commit c95b3b7

Please sign in to comment.