diff --git a/dagger/run_dbt_spark_tests.py b/dagger/run_dbt_spark_tests.py index 68dc12105..3c3fb935d 100644 --- a/dagger/run_dbt_spark_tests.py +++ b/dagger/run_dbt_spark_tests.py @@ -87,68 +87,56 @@ def get_spark_container(client: dagger.Client) -> (dagger.Service, str): async def test_spark(test_args): async with dagger.Connection(dagger.Config(log_output=sys.stderr)) as client: - test_profile = test_args.profile # create cache volumes, these are persisted between runs saving time when developing locally - os_reqs_cache = client.cache_volume("os_reqs") - pip_cache = client.cache_volume("pip") - - # setup directories as we don't want to copy the whole repo into the container - req_files = client.host().directory( - "./", include=["test.env", "hatch.toml", "pyproject.toml", "README.md", "License.md"] - ) - dbt_spark_dir = client.host().directory("./dbt") - test_dir = client.host().directory("./tests") - scripts = client.host().directory("./dagger/scripts") - - platform = dagger.Platform("linux/amd64") tst_container = ( - client.container(platform=platform) + client.container(platform=dagger.Platform("linux/amd64")) .from_("python:3.9-slim") - .with_mounted_cache("/var/cache/apt/archives", os_reqs_cache) - .with_mounted_cache("/root/.cache/pip", pip_cache) - # install OS deps first so any local changes don't invalidate the cache - .with_directory("/scripts", scripts) - .with_exec(["./scripts/install_os_reqs.sh"]) - # install dbt-spark + python deps - .with_directory("/src", req_files) - .with_exec(["pip", "install", "-U", "pip", "hatch"]) + .with_mounted_cache("/var/cache/apt/archives", client.cache_volume("os_reqs")) + .with_mounted_cache("/root/.cache/pip", client.cache_volume("pip")) ) - # install local dbt-spark changes + # install system dependencies first so any local changes don't invalidate the cache tst_container = ( tst_container.with_workdir("/") - .with_directory("src/dbt", dbt_spark_dir) - .with_workdir("/src") - .with_exec(["hatch", "shell"]) + .with_directory("/scripts", client.host().directory("./dagger/scripts")) + .with_exec(["./scripts/install_os_reqs.sh"]) + .with_exec(["pip", "install", "-U", "pip", "hatch"]) + .with_(env_variables(TESTING_ENV_VARS)) ) - # install local test changes + # copy project files into image tst_container = ( tst_container.with_workdir("/") - .with_directory("src/tests", test_dir) - .with_workdir("/src") + .with_directory("/src/dbt", client.host().directory("./dbt")) + .with_directory("/src/tests", client.host().directory("./tests")) + .with_file("/src/hatch.toml", client.host().file("./hatch.toml")) + .with_file("/src/License.md", client.host().file("./License.md")) + .with_file("/src/pyproject.toml", client.host().file("./pyproject.toml")) + .with_file("/src/README.md", client.host().file("./README.md")) + .with_file("/src/test.env", client.host().file("./test.env")) ) - if test_profile == "apache_spark": + # install profile-specific system dependencies last since tests usually rotate through profiles + if test_args.profile == "apache_spark": spark_ctr, spark_host = get_spark_container(client) tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr) - elif test_profile in ["databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc"]: - tst_container = ( - tst_container.with_workdir("/") - .with_exec(["./scripts/configure_odbc.sh"]) - .with_workdir("/src") + elif test_args.profile in [ + "databricks_cluster", + "databricks_sql_endpoint", + "spark_http_odbc", + ]: + tst_container = tst_container.with_workdir("/").with_exec( + ["./scripts/configure_odbc.sh"] ) - elif test_profile == "spark_session": - tst_container = tst_container.with_exec(["pip", "install", "pyspark"]) + elif test_args.profile == "spark_session": tst_container = tst_container.with_exec(["apt-get", "install", "openjdk-17-jre", "-y"]) - tst_container = tst_container.with_(env_variables(TESTING_ENV_VARS)) - test_path = test_args.test_path if test_args.test_path else "tests/functional/adapter" + # run the tests result = await tst_container.with_exec( - ["hatch", "run", "pytest", "--profile", test_profile, test_path] + ["hatch", "run", "pytest", "--profile", test_args.profile, test_args.test_path] ).stdout() return result @@ -156,7 +144,7 @@ async def test_spark(test_args): parser = argparse.ArgumentParser() parser.add_argument("--profile", required=True, type=str) -parser.add_argument("--test-path", required=False, type=str) +parser.add_argument("--test-path", required=False, type=str, default="tests/functional/adapter") args = parser.parse_args() anyio.run(test_spark, args)