Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Script for populating all persistent source schemas #851

Merged
merged 1 commit into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ postgresql postgres:
regenerate-test-snapshots:
hatch -v run dev-env:python metricflow/test/generate_snapshots.py

# Populate persistent source schemas for all relevant SQL engines.
.PHONY: populate-persistent-source-schemas
populate-persistent-source-schemas:
hatch -v run dev-env:python metricflow/test/populate_persistent_source_schemas.py

# Re-generate snapshots for the default SQL engine.
.PHONY: test-snap
test-snap:
Expand Down
18 changes: 13 additions & 5 deletions metricflow/test/generate_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import logging
import os
from dataclasses import dataclass
from typing import Optional, Sequence
from typing import Callable, Optional, Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.implementations.base import FrozenBaseModel
Expand Down Expand Up @@ -107,7 +107,11 @@ def run_command(command: str) -> None: # noqa: D
raise RuntimeError(f"Error running command: {command}")


def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa: D
def set_engine_env_variables(test_configuration: MetricFlowTestConfiguration) -> None:
"""Set connection env variables dynamically for the engine being used.

Requires MF_TEST_ENGINE_CREDENTIALS env variable to be set with creds for all engines.
"""
if test_configuration.credential_set.engine_url is None:
if "MF_SQL_ENGINE_URL" in os.environ:
del os.environ["MF_SQL_ENGINE_URL"]
Expand All @@ -120,6 +124,10 @@ def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa:
else:
os.environ["MF_SQL_ENGINE_PASSWORD"] = test_configuration.credential_set.engine_password


def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa: D
set_engine_env_variables(test_configuration)

if test_configuration.engine is SqlEngine.DUCKDB:
# DuckDB is fast, so generate all snapshots, including the engine-agnostic ones
run_command(f"pytest -x -vv -n 4 --overwrite-snapshots -k 'not itest' {TEST_DIRECTORY}")
Expand All @@ -145,7 +153,7 @@ def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa:
assert_values_exhausted(test_configuration.engine)


def run_cli() -> None: # noqa: D
def run_cli(function_to_run: Callable) -> None: # noqa: D
# Setup logging.
dev_format = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d [%(threadName)s] - %(message)s"
logging.basicConfig(level=logging.INFO, format=dev_format)
Expand All @@ -165,8 +173,8 @@ def run_cli() -> None: # noqa: D
logger.info(
f"Running tests for {test_configuration.engine} with URL: {test_configuration.credential_set.engine_url}"
)
run_tests(test_configuration)
function_to_run(test_configuration)


if __name__ == "__main__":
run_cli()
run_cli(run_tests)
45 changes: 45 additions & 0 deletions metricflow/test/populate_persistent_source_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Script to help generate persistent source schemas with test data for all relevant engines."""

from __future__ import annotations

import logging
import os

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted

from metricflow.protocols.sql_client import SqlEngine
from metricflow.test.generate_snapshots import (
MetricFlowTestConfiguration,
run_cli,
run_command,
set_engine_env_variables,
)

logger = logging.getLogger(__name__)


def populate_schemas(test_configuration: MetricFlowTestConfiguration) -> None:
    """Populate the persistent source schema for the engine in the given configuration.

    The engine connection env variables are set first. For engines that are handled
    below, the `populate_source_schema` pytest target is then run inside the hatch
    environment named after the engine. DuckDB and Postgres are skipped.
    """
    set_engine_env_variables(test_configuration)
    engine = test_configuration.engine

    # DuckDB & Postgres don't use persistent source schema
    if engine is SqlEngine.DUCKDB or engine is SqlEngine.POSTGRES:
        return

    if (
        engine is SqlEngine.SNOWFLAKE
        or engine is SqlEngine.BIGQUERY
        or engine is SqlEngine.DATABRICKS
        or engine is SqlEngine.REDSHIFT
    ):
        # The lower-cased enum value doubles as the adapter type and the hatch env prefix.
        adapter_type = engine.value.lower()
        os.environ["MF_TEST_ADAPTER_TYPE"] = adapter_type
        command = (
            f"hatch -v run {adapter_type}-env:pytest -vv --use-persistent-source-schema "
            "metricflow/test/source_schema_tools.py::populate_source_schema"
        )
        run_command(command)
        return

    # Reached only if a new SqlEngine member is added without being handled above.
    assert_values_exhausted(engine)


# Script entry point: hand `populate_schemas` to the shared CLI runner imported from
# generate_snapshots, which calls it with a test configuration.
if __name__ == "__main__":
    run_cli(populate_schemas)