Merge pull request #11 from kbase/dev_using_hive
using hive for metastore
Tianhao-Gu authored May 28, 2024
2 parents 2120545 + 0a42335 commit 720d734
Showing 2 changed files with 125 additions and 43 deletions.
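
For orientation before the diff: the functional change is that the Delta Lake configuration now sets spark.sql.catalogImplementation to hive, so table metadata is kept in a Hive metastore rather than Spark's default in-memory catalog. A minimal, illustrative sketch of that one setting in isolation follows; the app name and query are placeholders, not taken from this commit.

from pyspark.sql import SparkSession

# Illustrative only: the single setting this PR adds to the Delta Lake config.
# With "hive" as the catalog implementation, databases and tables created
# through this session are registered in the configured Hive metastore.
spark = (
    SparkSession.builder
    .appName("hive-catalog-demo")  # placeholder app name
    .config("spark.sql.catalogImplementation", "hive")
    .getOrCreate()
)

spark.sql("SHOW DATABASES").show()  # served by the Hive-backed catalog
spark.stop()
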
115 changes: 76 additions & 39 deletions src/spark/utils.py
@@ -3,57 +3,94 @@
 from pyspark.conf import SparkConf
 from pyspark.sql import SparkSession
 
+# Default directory for JAR files in the Bitnami Spark image
+JAR_DIR = '/opt/bitnami/spark/jars'
+HADOOP_AWS_VER = os.getenv('HADOOP_AWS_VER')
+DELTA_SPARK_VER = os.getenv('DELTA_SPARK_VER')
+SCALA_VER = os.getenv('SCALA_VER')
 
-def get_spark_session(app_name: str,
-                      local: bool = False,
-                      delta_lake: bool = False) -> SparkSession:
+
+def _get_jars(jar_names: list) -> str:
+    """
+    Helper function to get the required JAR files as a comma-separated string.
+
+    :param jar_names: List of JAR file names
+    :return: A comma-separated string of JAR file paths
+    """
+    jars = [os.path.join(JAR_DIR, jar) for jar in jar_names]
+
+    missing_jars = [jar for jar in jars if not os.path.exists(jar)]
+    if missing_jars:
+        raise FileNotFoundError(f"Some required jars are not found: {missing_jars}")
+
+    return ", ".join(jars)
+
+
+def _get_delta_lake_conf(jars_str: str) -> dict:
+    """
+    Helper function to get Delta Lake specific Spark configuration.
+
+    :param jars_str: A comma-separated string of JAR file paths
+    :return: A dictionary of Delta Lake specific Spark configuration
+
+    reference: https://blog.min.io/delta-lake-minio-multi-cloud/
+    """
+    return {
+        "spark.jars": jars_str,
+        "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension",
+        "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.delta.catalog.DeltaCatalog",
+        "spark.databricks.delta.retentionDurationCheck.enabled": "false",
+        "spark.hadoop.fs.s3a.endpoint": os.environ.get("MINIO_URL"),
+        "spark.hadoop.fs.s3a.access.key": os.environ.get("MINIO_ACCESS_KEY"),
+        "spark.hadoop.fs.s3a.secret.key": os.environ.get("MINIO_SECRET_KEY"),
+        "spark.hadoop.fs.s3a.path.style.access": "true",
+        "spark.hadoop.fs.s3a.impl": "org.apache.hadoop.fs.s3a.S3AFileSystem",
+        "spark.sql.catalogImplementation": "hive",
+    }
+
+
+def get_base_spark_conf(app_name: str) -> SparkConf:
     """
-    Helper to get and manage the `SparkSession` and keep all of our spark configuration params in one place.
+    Helper function to get the base Spark configuration.
 
     :param app_name: The name of the application
-    :param local: Whether to run the spark session locally or doesn't
-    :param delta_lake: build the spark session with delta lake support
-
-    :return: A `SparkSession` object
+    :return: A SparkConf object with the base configuration
+    """
+    return SparkConf().setAll([
+        ("spark.master", os.environ.get("SPARK_MASTER_URL", "spark://spark-master:7077")),
+        ("spark.app.name", app_name),
+    ])
+
+
+def get_spark_session(
+        app_name: str,
+        local: bool = False,
+        delta_lake: bool = False) -> SparkSession:
+    """
+    Helper to get and manage the SparkSession and keep all of our spark configuration params in one place.
+
+    :param app_name: The name of the application
+    :param local: Whether to run the spark session locally or not
+    :param delta_lake: Build the spark session with Delta Lake support
+
+    :return: A SparkSession object
     """
     if local:
         return SparkSession.builder.appName(app_name).getOrCreate()
 
-    spark_conf = SparkConf()
+    spark_conf = get_base_spark_conf(app_name)
 
     if delta_lake:
 
-        jars_dir = "/opt/bitnami/spark/jars/"
-        jar_files = [os.path.join(jars_dir, f) for f in os.listdir(jars_dir) if f.endswith(".jar")]
-        jars = ",".join(jar_files)
-
-        spark_conf.setAll(
-            [
-                (
-                    "spark.master",
-                    os.environ.get("SPARK_MASTER_URL", "spark://spark-master:7077"),
-                ),
-                ("spark.app.name", app_name),
-                ("spark.hadoop.fs.s3a.endpoint", os.environ.get("MINIO_URL")),
-                ("spark.hadoop.fs.s3a.access.key", os.environ.get("MINIO_ACCESS_KEY")),
-                ("spark.hadoop.fs.s3a.secret.key", os.environ.get("MINIO_SECRET_KEY")),
-                ("spark.jars", jars),
-                ("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"),
-                ("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog"),
-                ("spark.hadoop.fs.s3a.path.style.access", "true"),
-                ("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem"),
-            ]
-        )
-    else:
-        spark_conf.setAll(
-            [
-                (
-                    "spark.master",
-                    os.environ.get("SPARK_MASTER_URL", "spark://spark-master:7077"),
-                ),
-                ("spark.app.name", app_name),
-            ]
-        )
+        # Just to include the necessary jars for Delta Lake
+        jar_names = [f"delta-spark_{SCALA_VER}-{DELTA_SPARK_VER}.jar",
+                     f"hadoop-aws-{HADOOP_AWS_VER}.jar"]
+        jars_str = _get_jars(jar_names)
+        delta_conf = _get_delta_lake_conf(jars_str)
+        for key, value in delta_conf.items():
+            spark_conf.set(key, value)
 
     return SparkSession.builder.config(conf=spark_conf).getOrCreate()
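
A usage sketch for the refactored helper as it would be called from notebook or job code. It assumes the Bitnami Spark container this repo targets, with SCALA_VER, DELTA_SPARK_VER, HADOOP_AWS_VER and the MINIO_* variables set and the corresponding JARs present under /opt/bitnami/spark/jars; the app name, bucket, path and table name below are placeholders, not part of the commit.

from src.spark.utils import get_spark_session

# Cluster session with Delta Lake support; _get_delta_lake_conf also switches the
# SQL catalog to Hive, so saveAsTable registers the table in the Hive metastore.
spark = get_spark_session("delta_demo", delta_lake=True)

df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

# Write Delta files to MinIO through the s3a filesystem (placeholder bucket/path).
df.write.format("delta").mode("overwrite").save("s3a://example-bucket/demo_table")

# Register a managed table through the Hive-backed catalog and query it back.
df.write.format("delta").mode("overwrite").saveAsTable("demo_table")
spark.sql("SELECT COUNT(*) FROM demo_table").show()

spark.stop()
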
53 changes: 49 additions & 4 deletions test/src/spark/utils_test.py
@@ -1,10 +1,11 @@
 import socket
-from unittest.mock import patch
+from unittest import mock
 
 import pytest
+from pyspark import SparkConf
 from pyspark.sql import SparkSession
 
-from src.spark.utils import get_spark_session
+from src.spark.utils import get_spark_session, _get_jars, get_base_spark_conf, JAR_DIR
 
 
 @pytest.fixture(scope="session")
@@ -26,7 +27,7 @@ def mock_spark_master():
 @pytest.fixture
 def spark_session_local():
     """Provide a local Spark session for testing."""
-    with patch.dict('os.environ', {}):
+    with mock.patch.dict('os.environ', {}):
         spark_session = get_spark_session("TestApp", local=True)
         print("Created local Spark session.")
         try:
@@ -43,7 +44,7 @@ def spark_session_non_local(mock_spark_master):
     spark_master_url = f"spark://localhost:{port}"
     print(f"Using Spark master URL: {spark_master_url}")
 
-    with patch.dict('os.environ', {"SPARK_MASTER_URL": spark_master_url}):
+    with mock.patch.dict('os.environ', {"SPARK_MASTER_URL": spark_master_url}):
         spark_session = get_spark_session("TestApp", local=False)
         print("Created non-local Spark session.")
         try:
@@ -66,3 +67,47 @@ def test_spark_session_non_local(spark_session_non_local):
     assert isinstance(spark_session, SparkSession)
     assert spark_session.conf.get("spark.master") == f"spark://localhost:{port}"
     assert spark_session.conf.get("spark.app.name") == "TestApp"
+
+
+def test_get_jars_success():
+    jar_names = ["jar1.jar", "jar2.jar"]
+    expected = f"{JAR_DIR}/jar1.jar, {JAR_DIR}/jar2.jar"
+
+    with mock.patch('os.path.exists', return_value=True):
+        result = _get_jars(jar_names)
+        assert result == expected
+
+
+def test_get_jars_missing_file():
+    jar_names = ["jar1.jar", "jar2.jar"]
+
+    def side_effect(path):
+        return "jar1.jar" in path
+
+    with mock.patch('os.path.exists', side_effect=side_effect):
+        with pytest.raises(FileNotFoundError) as excinfo:
+            _get_jars(jar_names)
+        assert "Some required jars are not found" in str(excinfo.value)
+
+
+def test_get_base_spark_conf():
+    app_name = "test_app"
+    expected_master_url = "spark://spark-master:7077"
+    expected_app_name = app_name
+
+    with mock.patch.dict('os.environ', {}):
+        result = get_base_spark_conf(app_name)
+        assert isinstance(result, SparkConf)
+        assert result.get("spark.master") == expected_master_url
+        assert result.get("spark.app.name") == expected_app_name
+
+
+def test_get_base_spark_conf_with_env():
+    app_name = "test_app"
+    custom_master_url = "spark://custom-master:7077"
+
+    with mock.patch.dict('os.environ', {"SPARK_MASTER_URL": custom_master_url}):
+        result = get_base_spark_conf(app_name)
+        assert isinstance(result, SparkConf)
+        assert result.get("spark.master") == custom_master_url
+        assert result.get("spark.app.name") == app_name

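The new tests exercise _get_jars and get_base_spark_conf but not _get_delta_lake_conf. A possible follow-up test, not part of this commit, could pin the Hive metastore setting the PR introduces; the MinIO values below are arbitrary placeholders.

from unittest import mock

from src.spark.utils import _get_delta_lake_conf


def test_get_delta_lake_conf_uses_hive_catalog():
    env = {
        "MINIO_URL": "http://minio:9000",
        "MINIO_ACCESS_KEY": "access",
        "MINIO_SECRET_KEY": "secret",
    }
    # _get_delta_lake_conf reads the MinIO settings from os.environ at call time.
    with mock.patch.dict('os.environ', env):
        conf = _get_delta_lake_conf("jar1.jar, jar2.jar")

    assert conf["spark.sql.catalogImplementation"] == "hive"
    assert conf["spark.jars"] == "jar1.jar, jar2.jar"
    assert conf["spark.hadoop.fs.s3a.endpoint"] == "http://minio:9000"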