Commit

update spark container to be more clear about what is happening when
mikealfare committed Dec 19, 2024
1 parent b6acf21 commit dc41792
Showing 1 changed file with 29 additions and 41 deletions.
dagger/run_dbt_spark_tests.py (70 changes: 29 additions & 41 deletions)
@@ -87,76 +87,64 @@ def get_spark_container(client: dagger.Client) -> (dagger.Service, str):

 async def test_spark(test_args):
     async with dagger.Connection(dagger.Config(log_output=sys.stderr)) as client:
-        test_profile = test_args.profile
-
-        # create cache volumes, these are persisted between runs saving time when developing locally
-        os_reqs_cache = client.cache_volume("os_reqs")
-        pip_cache = client.cache_volume("pip")
-
-        # setup directories as we don't want to copy the whole repo into the container
-        req_files = client.host().directory(
-            "./", include=["test.env", "hatch.toml", "pyproject.toml", "README.md", "License.md"]
-        )
-        dbt_spark_dir = client.host().directory("./dbt")
-        test_dir = client.host().directory("./tests")
-        scripts = client.host().directory("./dagger/scripts")
 
-        platform = dagger.Platform("linux/amd64")
         tst_container = (
-            client.container(platform=platform)
+            client.container(platform=dagger.Platform("linux/amd64"))
             .from_("python:3.9-slim")
-            .with_mounted_cache("/var/cache/apt/archives", os_reqs_cache)
-            .with_mounted_cache("/root/.cache/pip", pip_cache)
-            # install OS deps first so any local changes don't invalidate the cache
-            .with_directory("/scripts", scripts)
-            .with_exec(["./scripts/install_os_reqs.sh"])
-            # install dbt-spark + python deps
-            .with_directory("/src", req_files)
-            .with_exec(["pip", "install", "-U", "pip", "hatch"])
+            .with_mounted_cache("/var/cache/apt/archives", client.cache_volume("os_reqs"))
+            .with_mounted_cache("/root/.cache/pip", client.cache_volume("pip"))
         )
 
-        # install local dbt-spark changes
+        # install system dependencies first so any local changes don't invalidate the cache
         tst_container = (
             tst_container.with_workdir("/")
-            .with_directory("src/dbt", dbt_spark_dir)
-            .with_workdir("/src")
-            .with_exec(["hatch", "shell"])
+            .with_directory("/scripts", client.host().directory("./dagger/scripts"))
+            .with_exec(["./scripts/install_os_reqs.sh"])
+            .with_exec(["pip", "install", "-U", "pip", "hatch"])
+            .with_(env_variables(TESTING_ENV_VARS))
         )
 
-        # install local test changes
+        # copy project files into image
         tst_container = (
             tst_container.with_workdir("/")
-            .with_directory("src/tests", test_dir)
             .with_workdir("/src")
+            .with_directory("/src/dbt", client.host().directory("./dbt"))
+            .with_directory("/src/tests", client.host().directory("./tests"))
+            .with_file("/src/hatch.toml", client.host().file("./hatch.toml"))
+            .with_file("/src/License.md", client.host().file("./License.md"))
+            .with_file("/src/pyproject.toml", client.host().file("./pyproject.toml"))
+            .with_file("/src/README.md", client.host().file("./README.md"))
+            .with_file("/src/test.env", client.host().file("./test.env"))
         )
 
-        if test_profile == "apache_spark":
+        # install profile-specific system dependencies last since tests usually rotate through profiles
+        if test_args.profile == "apache_spark":
             spark_ctr, spark_host = get_spark_container(client)
             tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr)
 
-        elif test_profile in ["databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc"]:
-            tst_container = (
-                tst_container.with_workdir("/")
-                .with_exec(["./scripts/configure_odbc.sh"])
-                .with_workdir("/src")
+        elif test_args.profile in [
+            "databricks_cluster",
+            "databricks_sql_endpoint",
+            "spark_http_odbc",
+        ]:
+            tst_container = tst_container.with_workdir("/").with_exec(
+                ["./scripts/configure_odbc.sh"]
             )
 
-        elif test_profile == "spark_session":
-            tst_container = tst_container.with_exec(["pip", "install", "pyspark"])
+        elif test_args.profile == "spark_session":
             tst_container = tst_container.with_exec(["apt-get", "install", "openjdk-17-jre", "-y"])
 
-        tst_container = tst_container.with_(env_variables(TESTING_ENV_VARS))
-        test_path = test_args.test_path if test_args.test_path else "tests/functional/adapter"
+        # run the tests
         result = await tst_container.with_exec(
-            ["hatch", "run", "pytest", "--profile", test_profile, test_path]
+            ["hatch", "run", "pytest", "--profile", test_args.profile, test_args.test_path]
         ).stdout()
 
         return result
 
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--profile", required=True, type=str)
-parser.add_argument("--test-path", required=False, type=str)
+parser.add_argument("--test-path", required=False, type=str, default="tests/functional/adapter")
 args = parser.parse_args()
 
 anyio.run(test_spark, args)
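
For reference, the flags defined above mean the script takes a required --profile and an optional --test-path, which after this change defaults to tests/functional/adapter; a typical invocation would look something like:

    python dagger/run_dbt_spark_tests.py --profile apache_spark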
