diff --git a/src/spark/utils.py b/src/spark/utils.py index 108696b..c050b4e 100644 --- a/src/spark/utils.py +++ b/src/spark/utils.py @@ -3,57 +3,95 @@ from pyspark.conf import SparkConf from pyspark.sql import SparkSession +# Default directory for JAR files in the Bitnami Spark image +JAR_DIR = '/opt/bitnami/spark/jars' +HADOOP_AWS_VER = os.getenv('HADOOP_AWS_VER') +DELTA_SPARK_VER = os.getenv('DELTA_SPARK_VER') +SCALA_VER = os.getenv('SCALA_VER') -def get_spark_session(app_name: str, - local: bool = False, - delta_lake: bool = False) -> SparkSession: + +def _get_jars(jar_names: list) -> str: + """ + Helper function to get the required JAR files as a comma-separated string. + + :param jar_names: List of JAR file names + + :return: A comma-separated string of JAR file paths + """ + jars = [os.path.join(JAR_DIR, jar) for jar in jar_names] + + missing_jars = [jar for jar in jars if not os.path.exists(jar)] + if missing_jars: + raise FileNotFoundError(f"Some required jars are not found: {missing_jars}") + + return ", ".join(jars) + + +def _get_delta_lake_conf(jars_str: str) -> dict: """ - Helper to get and manage the `SparkSession` and keep all of our spark configuration params in one place. + Helper function to get Delta Lake specific Spark configuration. + + :param jars_str: A comma-separated string of JAR file paths + + :return: A dictionary of Delta Lake specific Spark configuration + + reference: https://blog.min.io/delta-lake-minio-multi-cloud/ + """ + return { + "spark.jars": jars_str, + "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension", + "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.delta.catalog.DeltaCatalog", + "spark.databricks.delta.retentionDurationCheck.enabled": "false", + "spark.hadoop.fs.s3a.endpoint": os.environ.get("MINIO_URL"), + "spark.hadoop.fs.s3a.access.key": os.environ.get("MINIO_ACCESS_KEY"), + "spark.hadoop.fs.s3a.secret.key": os.environ.get("MINIO_SECRET_KEY"), + "spark.hadoop.fs.s3a.path.style.access": "true", + "spark.hadoop.fs.s3a.impl": "org.apache.hadoop.fs.s3a.S3AFileSystem", + "spark.sql.catalogImplementation": "hive" + } + + +def get_base_spark_conf(app_name: str) -> SparkConf: + """ + Helper function to get the base Spark configuration. :param app_name: The name of the application - :param local: Whether to run the spark session locally or doesn't - :param delta_lake: build the spark session with delta lake support - :return: A `SparkSession` object + :return: A SparkConf object with the base configuration """ + return SparkConf().setAll([ + ("spark.master", os.environ.get("SPARK_MASTER_URL", "spark://spark-master:7077")), + ("spark.app.name", app_name), + ]) + +def get_spark_session( + app_name: str, + local: bool = False, + delta_lake: bool = False) -> SparkSession: + """ + Helper to get and manage the SparkSession and keep all of our spark configuration params in one place. + + :param app_name: The name of the application + :param local: Whether to run the spark session locally or not + :param delta_lake: Build the spark session with Delta Lake support + + :return: A SparkSession object + """ if local: return SparkSession.builder.appName(app_name).getOrCreate() - spark_conf = SparkConf() + spark_conf = get_base_spark_conf(app_name) if delta_lake: - jars_dir = "/opt/bitnami/spark/jars/" - jar_files = [os.path.join(jars_dir, f) for f in os.listdir(jars_dir) if f.endswith(".jar")] - jars = ",".join(jar_files) - - spark_conf.setAll( - [ - ( - "spark.master", - os.environ.get("SPARK_MASTER_URL", "spark://spark-master:7077"), - ), - ("spark.app.name", app_name), - ("spark.hadoop.fs.s3a.endpoint", os.environ.get("MINIO_URL")), - ("spark.hadoop.fs.s3a.access.key", os.environ.get("MINIO_ACCESS_KEY")), - ("spark.hadoop.fs.s3a.secret.key", os.environ.get("MINIO_SECRET_KEY")), - ("spark.jars", jars), - ("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"), - ("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog"), - ("spark.hadoop.fs.s3a.path.style.access", "true"), - ("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem"), - ] - ) - else: - spark_conf.setAll( - [ - ( - "spark.master", - os.environ.get("SPARK_MASTER_URL", "spark://spark-master:7077"), - ), - ("spark.app.name", app_name), - ] - ) - - return SparkSession.builder.config(conf=spark_conf).getOrCreate() + # Just to include the necessary jars for Delta Lake + jar_names = [f"delta-spark_{SCALA_VER}-{DELTA_SPARK_VER}.jar", + # f"delta-storage-{DELTA_SPARK_VER}.jar", + f"hadoop-aws-{HADOOP_AWS_VER}.jar"] + jars_str = _get_jars(jar_names) + delta_conf = _get_delta_lake_conf(jars_str) + for key, value in delta_conf.items(): + spark_conf.set(key, value) + + return SparkSession.builder.config(conf=spark_conf).enableHiveSupport().getOrCreate() diff --git a/test/src/spark/utils_test.py b/test/src/spark/utils_test.py index 4a47a95..a9658b7 100644 --- a/test/src/spark/utils_test.py +++ b/test/src/spark/utils_test.py @@ -1,10 +1,11 @@ import socket -from unittest.mock import patch +from unittest import mock import pytest +from pyspark import SparkConf from pyspark.sql import SparkSession -from src.spark.utils import get_spark_session +from src.spark.utils import get_spark_session, _get_jars, get_base_spark_conf, JAR_DIR @pytest.fixture(scope="session") @@ -26,7 +27,7 @@ def mock_spark_master(): @pytest.fixture def spark_session_local(): """Provide a local Spark session for testing.""" - with patch.dict('os.environ', {}): + with mock.patch.dict('os.environ', {}): spark_session = get_spark_session("TestApp", local=True) print("Created local Spark session.") try: @@ -43,7 +44,7 @@ def spark_session_non_local(mock_spark_master): spark_master_url = f"spark://localhost:{port}" print(f"Using Spark master URL: {spark_master_url}") - with patch.dict('os.environ', {"SPARK_MASTER_URL": spark_master_url}): + with mock.patch.dict('os.environ', {"SPARK_MASTER_URL": spark_master_url}): spark_session = get_spark_session("TestApp", local=False) print("Created non-local Spark session.") try: @@ -66,3 +67,47 @@ def test_spark_session_non_local(spark_session_non_local): assert isinstance(spark_session, SparkSession) assert spark_session.conf.get("spark.master") == f"spark://localhost:{port}" assert spark_session.conf.get("spark.app.name") == "TestApp" + + +def test_get_jars_success(): + jar_names = ["jar1.jar", "jar2.jar"] + expected = f"{JAR_DIR}/jar1.jar, {JAR_DIR}/jar2.jar" + + with mock.patch('os.path.exists', return_value=True): + result = _get_jars(jar_names) + assert result == expected + + +def test_get_jars_missing_file(): + jar_names = ["jar1.jar", "jar2.jar"] + + def side_effect(path): + return "jar1.jar" in path + + with mock.patch('os.path.exists', side_effect=side_effect): + with pytest.raises(FileNotFoundError) as excinfo: + _get_jars(jar_names) + assert "Some required jars are not found" in str(excinfo.value) + + +def test_get_base_spark_conf(): + app_name = "test_app" + expected_master_url = "spark://spark-master:7077" + expected_app_name = app_name + + with mock.patch.dict('os.environ', {}): + result = get_base_spark_conf(app_name) + assert isinstance(result, SparkConf) + assert result.get("spark.master") == expected_master_url + assert result.get("spark.app.name") == expected_app_name + + +def test_get_base_spark_conf_with_env(): + app_name = "test_app" + custom_master_url = "spark://custom-master:7077" + + with mock.patch.dict('os.environ', {"SPARK_MASTER_URL": custom_master_url}): + result = get_base_spark_conf(app_name) + assert isinstance(result, SparkConf) + assert result.get("spark.master") == custom_master_url + assert result.get("spark.app.name") == app_name